diff --git a/tests/unit/vertex_langchain/test_reasoning_engines.py b/tests/unit/vertex_langchain/test_reasoning_engines.py index 019dc214d9..d9e785b8d6 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engines.py +++ b/tests/unit/vertex_langchain/test_reasoning_engines.py @@ -42,6 +42,7 @@ from vertexai.preview import reasoning_engines from vertexai.reasoning_engines import _reasoning_engines from vertexai.reasoning_engines import _utils +from google.iam.v1 import policy_pb2 from google.api import httpbody_pb2 from google.protobuf import field_mask_pb2 from google.protobuf import struct_pb2 @@ -794,6 +795,56 @@ def test_create_reasoning_engine( retry=_TEST_RETRY, ) + def test_get_iam_policy(self): + """Tests that `get_iam_policy` method correctly calls the underlying API client. + + It verifies that the `get_iam_policy` method is called with the expected + resource name and returns the policy as provided by the mocked API client. + """ + with mock.patch.object( + base.VertexAiResourceNoun, "_get_gca_resource" + ) as mock_get_gca_resource: + mock_get_gca_resource.return_value = types.ReasoningEngine( + name=_TEST_REASONING_ENGINE_RESOURCE_NAME + ) + reasoning_engine = reasoning_engines.ReasoningEngine( + _TEST_REASONING_ENGINE_RESOURCE_NAME + ) + + test_policy = policy_pb2.Policy(version=1) + with mock.patch.object( + reasoning_engine.api_client, "get_iam_policy" + ) as mock_get_iam_policy: + mock_get_iam_policy.return_value = test_policy + policy = reasoning_engine.get_iam_policy(policy_version=1) + mock_get_iam_policy.assert_called_once() + assert policy == test_policy + + def test_set_iam_policy(self): + """Tests that `set_iam_policy` method correctly calls the underlying API client. + + It verifies that the `set_iam_policy` method is called with the expected + policy and returns the policy as provided by the mocked API client. + """ + with mock.patch.object( + base.VertexAiResourceNoun, "_get_gca_resource" + ) as mock_get_gca_resource: + mock_get_gca_resource.return_value = types.ReasoningEngine( + name=_TEST_REASONING_ENGINE_RESOURCE_NAME + ) + reasoning_engine = reasoning_engines.ReasoningEngine( + _TEST_REASONING_ENGINE_RESOURCE_NAME + ) + + test_policy = policy_pb2.Policy(version=1) + with mock.patch.object( + reasoning_engine.api_client, "set_iam_policy" + ) as mock_set_iam_policy: + mock_set_iam_policy.return_value = test_policy + policy = reasoning_engine.set_iam_policy(test_policy) + mock_set_iam_policy.assert_called_once() + assert policy == test_policy + @pytest.mark.usefixtures("caplog") def test_create_reasoning_engine_warn_resource_name( self, diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 322bf2a2d4..9c36be6f1a 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -44,6 +44,9 @@ from google.cloud.aiplatform_v1beta1 import types as aip_types from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service from vertexai.reasoning_engines import _utils +from google.iam.v1 import iam_policy_pb2 +from google.iam.v1 import options_pb2 +from google.iam.v1 import policy_pb2 from google.protobuf import field_mask_pb2 @@ -499,6 +502,41 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]: self._operation_schemas = spec.get("classMethods", []) return self._operation_schemas + def get_iam_policy( + self, policy_version: Optional[int] = None + ) -> policy_pb2.Policy: + """Gets the access control policy for this ReasoningEngine. + + Args: + policy_version: Optional. The maximum policy version that will be used + to format the policy. Valid values are 0, 1, 3. + + Returns: + The IAM policy. + """ + request = iam_policy_pb2.GetIamPolicyRequest( + resource=self.resource_name, + options=options_pb2.GetPolicyOptions( + requested_policy_version=policy_version + ), + ) + return self.api_client.get_iam_policy(request=request) + + def set_iam_policy(self, policy: policy_pb2.Policy) -> policy_pb2.Policy: + """Sets the access control policy on this ReasoningEngine. + + Args: + policy: The complete policy to be applied to the resource. + + Returns: + The new IAM policy. + """ + request = iam_policy_pb2.SetIamPolicyRequest( + resource=self.resource_name, + policy=policy, + ) + return self.api_client.set_iam_policy(request=request) + def _validate_sys_version_or_raise(sys_version: str) -> None: """Tries to validate the python system version."""