Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/unit/vertex_langchain/test_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from vertexai.preview import reasoning_engines
from vertexai.reasoning_engines import _reasoning_engines
from vertexai.reasoning_engines import _utils
from google.iam.v1 import policy_pb2
from google.api import httpbody_pb2
from google.protobuf import field_mask_pb2
from google.protobuf import struct_pb2
Expand Down Expand Up @@ -794,6 +795,56 @@ def test_create_reasoning_engine(
retry=_TEST_RETRY,
)

def test_get_iam_policy(self):
"""Tests that `get_iam_policy` method correctly calls the underlying API client.

It verifies that the `get_iam_policy` method is called with the expected
resource name and returns the policy as provided by the mocked API client.
"""
with mock.patch.object(
base.VertexAiResourceNoun, "_get_gca_resource"
) as mock_get_gca_resource:
mock_get_gca_resource.return_value = types.ReasoningEngine(
name=_TEST_REASONING_ENGINE_RESOURCE_NAME
)
reasoning_engine = reasoning_engines.ReasoningEngine(
_TEST_REASONING_ENGINE_RESOURCE_NAME
)

test_policy = policy_pb2.Policy(version=1)
with mock.patch.object(
reasoning_engine.api_client, "get_iam_policy"
) as mock_get_iam_policy:
mock_get_iam_policy.return_value = test_policy
policy = reasoning_engine.get_iam_policy(policy_version=1)
mock_get_iam_policy.assert_called_once()
assert policy == test_policy

def test_set_iam_policy(self):
"""Tests that `set_iam_policy` method correctly calls the underlying API client.

It verifies that the `set_iam_policy` method is called with the expected
policy and returns the policy as provided by the mocked API client.
"""
with mock.patch.object(
base.VertexAiResourceNoun, "_get_gca_resource"
) as mock_get_gca_resource:
mock_get_gca_resource.return_value = types.ReasoningEngine(
name=_TEST_REASONING_ENGINE_RESOURCE_NAME
)
reasoning_engine = reasoning_engines.ReasoningEngine(
_TEST_REASONING_ENGINE_RESOURCE_NAME
)

test_policy = policy_pb2.Policy(version=1)
with mock.patch.object(
reasoning_engine.api_client, "set_iam_policy"
) as mock_set_iam_policy:
mock_set_iam_policy.return_value = test_policy
policy = reasoning_engine.set_iam_policy(test_policy)
mock_set_iam_policy.assert_called_once()
assert policy == test_policy

@pytest.mark.usefixtures("caplog")
def test_create_reasoning_engine_warn_resource_name(
self,
Expand Down
38 changes: 38 additions & 0 deletions vertexai/reasoning_engines/_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
from google.cloud.aiplatform_v1beta1 import types as aip_types
from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service
from vertexai.reasoning_engines import _utils
from google.iam.v1 import iam_policy_pb2
from google.iam.v1 import options_pb2
from google.iam.v1 import policy_pb2
from google.protobuf import field_mask_pb2


Expand Down Expand Up @@ -499,6 +502,41 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
self._operation_schemas = spec.get("classMethods", [])
return self._operation_schemas

def get_iam_policy(
self, policy_version: Optional[int] = None
) -> policy_pb2.Policy:
"""Gets the access control policy for this ReasoningEngine.

Args:
policy_version: Optional. The maximum policy version that will be used
to format the policy. Valid values are 0, 1, 3.

Returns:
The IAM policy.
"""
request = iam_policy_pb2.GetIamPolicyRequest(
resource=self.resource_name,
options=options_pb2.GetPolicyOptions(
requested_policy_version=policy_version
),
)
return self.api_client.get_iam_policy(request=request)

def set_iam_policy(self, policy: policy_pb2.Policy) -> policy_pb2.Policy:
"""Sets the access control policy on this ReasoningEngine.

Args:
policy: The complete policy to be applied to the resource.

Returns:
The new IAM policy.
"""
request = iam_policy_pb2.SetIamPolicyRequest(
resource=self.resource_name,
policy=policy,
)
return self.api_client.set_iam_policy(request=request)


def _validate_sys_version_or_raise(sys_version: str) -> None:
"""Tries to validate the python system version."""
Expand Down
Loading