Skip to content

Commit f92ed74

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add get_iam_policy and set_iam_policy methods to ReasoningEngine.
PiperOrigin-RevId: 894597178
1 parent 09794ba commit f92ed74

2 files changed

Lines changed: 89 additions & 0 deletions

File tree

tests/unit/vertex_langchain/test_reasoning_engines.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from vertexai.preview import reasoning_engines
4343
from vertexai.reasoning_engines import _reasoning_engines
4444
from vertexai.reasoning_engines import _utils
45+
from google.iam.v1 import policy_pb2
4546
from google.api import httpbody_pb2
4647
from google.protobuf import field_mask_pb2
4748
from google.protobuf import struct_pb2
@@ -794,6 +795,56 @@ def test_create_reasoning_engine(
794795
retry=_TEST_RETRY,
795796
)
796797

798+
def test_get_iam_policy(self):
799+
"""Tests that `get_iam_policy` method correctly calls the underlying API client.
800+
801+
It verifies that the `get_iam_policy` method is called with the expected
802+
resource name and returns the policy as provided by the mocked API client.
803+
"""
804+
with mock.patch.object(
805+
base.VertexAiResourceNoun, "_get_gca_resource"
806+
) as mock_get_gca_resource:
807+
mock_get_gca_resource.return_value = types.ReasoningEngine(
808+
name=_TEST_REASONING_ENGINE_RESOURCE_NAME
809+
)
810+
reasoning_engine = reasoning_engines.ReasoningEngine(
811+
_TEST_REASONING_ENGINE_RESOURCE_NAME
812+
)
813+
814+
test_policy = policy_pb2.Policy(version=1)
815+
with mock.patch.object(
816+
reasoning_engine.api_client, "get_iam_policy"
817+
) as mock_get_iam_policy:
818+
mock_get_iam_policy.return_value = test_policy
819+
policy = reasoning_engine.get_iam_policy(policy_version=1)
820+
mock_get_iam_policy.assert_called_once()
821+
assert policy == test_policy
822+
823+
def test_set_iam_policy(self):
824+
"""Tests that `set_iam_policy` method correctly calls the underlying API client.
825+
826+
It verifies that the `set_iam_policy` method is called with the expected
827+
policy and returns the policy as provided by the mocked API client.
828+
"""
829+
with mock.patch.object(
830+
base.VertexAiResourceNoun, "_get_gca_resource"
831+
) as mock_get_gca_resource:
832+
mock_get_gca_resource.return_value = types.ReasoningEngine(
833+
name=_TEST_REASONING_ENGINE_RESOURCE_NAME
834+
)
835+
reasoning_engine = reasoning_engines.ReasoningEngine(
836+
_TEST_REASONING_ENGINE_RESOURCE_NAME
837+
)
838+
839+
test_policy = policy_pb2.Policy(version=1)
840+
with mock.patch.object(
841+
reasoning_engine.api_client, "set_iam_policy"
842+
) as mock_set_iam_policy:
843+
mock_set_iam_policy.return_value = test_policy
844+
policy = reasoning_engine.set_iam_policy(test_policy)
845+
mock_set_iam_policy.assert_called_once()
846+
assert policy == test_policy
847+
797848
@pytest.mark.usefixtures("caplog")
798849
def test_create_reasoning_engine_warn_resource_name(
799850
self,

vertexai/reasoning_engines/_reasoning_engines.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from google.cloud.aiplatform_v1beta1 import types as aip_types
4545
from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service
4646
from vertexai.reasoning_engines import _utils
47+
from google.iam.v1 import iam_policy_pb2
48+
from google.iam.v1 import options_pb2
49+
from google.iam.v1 import policy_pb2
4750
from google.protobuf import field_mask_pb2
4851

4952

@@ -499,6 +502,41 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
499502
self._operation_schemas = spec.get("classMethods", [])
500503
return self._operation_schemas
501504

505+
def get_iam_policy(
506+
self, policy_version: Optional[int] = None
507+
) -> policy_pb2.Policy:
508+
"""Gets the access control policy for this ReasoningEngine.
509+
510+
Args:
511+
policy_version: Optional. The maximum policy version that will be used
512+
to format the policy. Valid values are 0, 1, 3.
513+
514+
Returns:
515+
The IAM policy.
516+
"""
517+
request = iam_policy_pb2.GetIamPolicyRequest(
518+
resource=self.resource_name,
519+
options=options_pb2.GetPolicyOptions(
520+
requested_policy_version=policy_version
521+
),
522+
)
523+
return self.api_client.get_iam_policy(request=request)
524+
525+
def set_iam_policy(self, policy: policy_pb2.Policy) -> policy_pb2.Policy:
526+
"""Sets the access control policy on this ReasoningEngine.
527+
528+
Args:
529+
policy: The complete policy to be applied to the resource.
530+
531+
Returns:
532+
The new IAM policy.
533+
"""
534+
request = iam_policy_pb2.SetIamPolicyRequest(
535+
resource=self.resource_name,
536+
policy=policy,
537+
)
538+
return self.api_client.set_iam_policy(request=request)
539+
502540

503541
def _validate_sys_version_or_raise(sys_version: str) -> None:
504542
"""Tries to validate the python system version."""

0 commit comments

Comments
 (0)