diff --git a/tests/unit/vertexai/genai/test_agent_engines.py b/tests/unit/vertexai/genai/test_agent_engines.py index 2f954e709a..5db086e0ad 100644 --- a/tests/unit/vertexai/genai/test_agent_engines.py +++ b/tests/unit/vertexai/genai/test_agent_engines.py @@ -3162,6 +3162,28 @@ def test_query_agent_engine_async(self): None, ) + def test_cancel_query_job_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="{}") + + result = self.client.agent_engines.cancel_query_job( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + operation_name=_TEST_AGENT_ENGINE_OPERATION_NAME, + ) + + assert isinstance(result, _genai_types.CancelQueryJobResult) + request_mock.assert_called_with( + "post", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:cancelAsyncQuery", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "operationName": _TEST_AGENT_ENGINE_OPERATION_NAME, + }, + None, + ) + def test_check_query_job_agent_engine(self): with mock.patch.object( self.client.agent_engines._api_client, "request" diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py index 4272e5b34e..3310bdff7f 100644 --- a/vertexai/_genai/agent_engines.py +++ b/vertexai/_genai/agent_engines.py @@ -49,6 +49,47 @@ logger.setLevel(logging.INFO) +def _CancelQueryJobAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _CancelQueryJobAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["operation_name"]) is not None: + setv(to_object, ["operationName"], getv(from_object, ["operation_name"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _CancelQueryJobAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _CancelQueryJobResult_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + def _CheckQueryJobAgentEngineConfig_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -396,130 +437,205 @@ def _UpdateAgentEngineRequestParameters_to_vertex( class AgentEngines(_api_module.BaseModule): - def _check_query_job( - self, - *, - name: str, - config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, - ) -> types.CheckQueryJobResult: - """ - Query an Agent Engine asynchronously. - """ - - parameter_model = types._CheckQueryJobAgentEngineRequestParameters( - name=name, - config=config, - ) - - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _CheckQueryJobAgentEngineRequestParameters_to_vertex( - parameter_model - ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}:checkQueryJob".format_map(request_url_dict) - else: - path = "{name}:checkQueryJob" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( - parameter_model.config is not None - and parameter_model.config.http_options is not None - ): - http_options = parameter_model.config.http_options - - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) - - response = self._api_client.request("post", path, request_dict, http_options) - - response_dict = {} if not response.body else json.loads(response.body) - - if self._api_client.vertexai: - response_dict = _CheckQueryJobResult_from_vertex(response_dict) - - return_value = types.CheckQueryJobResult._from_response( - response=response_dict, - kwargs=( - { - "config": { - "response_schema": getattr( - parameter_model.config, "response_schema", None - ), - "response_json_schema": getattr( - parameter_model.config, "response_json_schema", None - ), - "include_all_fields": getattr( - parameter_model.config, "include_all_fields", None - ), - } + def _cancel_query_job( + self, + *, + name: str, + operation_name: str, + config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CancelQueryJobResult: + """Cancels a long-running query job on an Agent Engine.""" + + parameter_model = types._CancelQueryJobAgentEngineRequestParameters( + name=name, + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _CancelQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:cancelAsyncQuery".format_map(request_url_dict) + else: + path = "{name}:cancelAsyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CancelQueryJobResult_from_vertex(response_dict) + + return_value = types.CancelQueryJobResult._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), } - if getattr(parameter_model, "config", None) - else {} - ), - ) - - self._api_client._verify_response(return_value) - return return_value - - def _run_query_job( - self, - *, - name: str, - config: Optional[types._RunQueryJobAgentEngineConfigOrDict] = None, - ) -> types.AgentEngineOperation: - """ - Run a query job on an agent engine. - """ - - parameter_model = types._RunQueryJobAgentEngineRequestParameters( - name=name, - config=config, - ) - - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _RunQueryJobAgentEngineRequestParameters_to_vertex( + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _check_query_job( + self, + *, + name: str, + config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CheckQueryJobResult: + """Query an Agent Engine asynchronously.""" + + parameter_model = types._CheckQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _CheckQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:checkQueryJob".format_map(request_url_dict) + else: + path = "{name}:checkQueryJob" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CheckQueryJobResult_from_vertex(response_dict) + + return_value = types.CheckQueryJobResult._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _run_query_job( + self, + *, + name: str, + config: Optional[types._RunQueryJobAgentEngineConfigOrDict] = None, + ) -> types.AgentEngineOperation: + """Run a query job on an agent engine.""" + + parameter_model = types._RunQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _RunQueryJobAgentEngineRequestParameters_to_vertex( parameter_model ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}:asyncQuery".format_map(request_url_dict) - else: - path = "{name}:asyncQuery" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:asyncQuery".format_map(request_url_dict) + else: + path = "{name}:asyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("post", path, request_dict, http_options) + response = self._api_client.request("post", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.AgentEngineOperation._from_response( + return_value = types.AgentEngineOperation._from_response( response=response_dict, kwargs=( { @@ -540,54 +656,54 @@ def _run_query_job( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _create( + def _create( self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None ) -> types.AgentEngineOperation: - """ + """ Creates a new Agent Engine. """ - parameter_model = types._CreateAgentEngineRequestParameters( + parameter_model = types._CreateAgentEngineRequestParameters( config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _CreateAgentEngineRequestParameters_to_vertex( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _CreateAgentEngineRequestParameters_to_vertex( parameter_model ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "reasoningEngines".format_map(request_url_dict) - else: - path = "reasoningEngines" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("post", path, request_dict, http_options) + response = self._api_client.request("post", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.AgentEngineOperation._from_response( + return_value = types.AgentEngineOperation._from_response( response=response_dict, kwargs=( { @@ -608,17 +724,17 @@ def _create( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _delete( + def _delete( self, *, name: str, force: Optional[bool] = None, config: Optional[types.DeleteAgentEngineConfigOrDict] = None, ) -> types.DeleteAgentEngineOperation: - """ + """ Delete an Agent Engine resource. Args: @@ -635,46 +751,46 @@ def _delete( """ - parameter_model = types._DeleteAgentEngineRequestParameters( + parameter_model = types._DeleteAgentEngineRequestParameters( name=name, force=force, config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _DeleteAgentEngineRequestParameters_to_vertex( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _DeleteAgentEngineRequestParameters_to_vertex( parameter_model ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}".format_map(request_url_dict) - else: - path = "{name}" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("delete", path, request_dict, http_options) + response = self._api_client.request("delete", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.DeleteAgentEngineOperation._from_response( + return_value = types.DeleteAgentEngineOperation._from_response( response=response_dict, kwargs=( { @@ -695,53 +811,53 @@ def _delete( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _get( + def _get( self, *, name: str, config: Optional[types.GetAgentEngineConfigOrDict] = None ) -> types.ReasoningEngine: - """ + """ Get an Agent Engine instance. """ - parameter_model = types._GetAgentEngineRequestParameters( + parameter_model = types._GetAgentEngineRequestParameters( name=name, config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _GetAgentEngineRequestParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}".format_map(request_url_dict) - else: - path = "{name}" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _GetAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("get", path, request_dict, http_options) + response = self._api_client.request("get", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.ReasoningEngine._from_response( + return_value = types.ReasoningEngine._from_response( response=response_dict, kwargs=( { @@ -762,52 +878,52 @@ def _get( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _list( + def _list( self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None ) -> types.ListReasoningEnginesResponse: - """ + """ Lists Agent Engines. """ - parameter_model = types._ListAgentEngineRequestParameters( + parameter_model = types._ListAgentEngineRequestParameters( config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _ListAgentEngineRequestParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "reasoningEngines".format_map(request_url_dict) - else: - path = "reasoningEngines" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _ListAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("get", path, request_dict, http_options) + response = self._api_client.request("get", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.ListReasoningEnginesResponse._from_response( + return_value = types.ListReasoningEnginesResponse._from_response( response=response_dict, kwargs=( { @@ -828,52 +944,52 @@ def _list( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _get_agent_operation( + def _get_agent_operation( self, *, operation_name: str, config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, ) -> types.AgentEngineOperation: - parameter_model = types._GetAgentEngineOperationParameters( + parameter_model = types._GetAgentEngineOperationParameters( operation_name=operation_name, config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _GetAgentEngineOperationParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{operationName}".format_map(request_url_dict) - else: - path = "{operationName}" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _GetAgentEngineOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("get", path, request_dict, http_options) + response = self._api_client.request("get", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.AgentEngineOperation._from_response( + return_value = types.AgentEngineOperation._from_response( response=response_dict, kwargs=( { @@ -894,53 +1010,53 @@ def _get_agent_operation( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _query( + def _query( self, *, name: str, config: Optional[types.QueryAgentEngineConfigOrDict] = None ) -> types.QueryReasoningEngineResponse: - """ + """ Query an Agent Engine. """ - parameter_model = types._QueryAgentEngineRequestParameters( + parameter_model = types._QueryAgentEngineRequestParameters( name=name, config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}:query".format_map(request_url_dict) - else: - path = "{name}:query" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("post", path, request_dict, http_options) + response = self._api_client.request("post", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.QueryReasoningEngineResponse._from_response( + return_value = types.QueryReasoningEngineResponse._from_response( response=response_dict, kwargs=( { @@ -961,55 +1077,55 @@ def _query( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - def _update( + def _update( self, *, name: str, config: Optional[types.UpdateAgentEngineConfigOrDict] = None ) -> types.AgentEngineOperation: - """ + """ Updates an Agent Engine. """ - parameter_model = types._UpdateAgentEngineRequestParameters( + parameter_model = types._UpdateAgentEngineRequestParameters( name=name, config=config, ) - request_url_dict: Optional[dict[str, str]] - if not self._api_client.vertexai: - raise ValueError("This method is only supported in the Vertex AI client.") - else: - request_dict = _UpdateAgentEngineRequestParameters_to_vertex( + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _UpdateAgentEngineRequestParameters_to_vertex( parameter_model ) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}".format_map(request_url_dict) - else: - path = "{name}" - - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - - http_options: Optional[types.HttpOptions] = None - if ( + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) - response = self._api_client.request("patch", path, request_dict, http_options) + response = self._api_client.request("patch", path, request_dict, http_options) - response_dict = {} if not response.body else json.loads(response.body) + response_dict = {} if not response.body else json.loads(response.body) - return_value = types.AgentEngineOperation._from_response( + return_value = types.AgentEngineOperation._from_response( response=response_dict, kwargs=( { @@ -1030,91 +1146,102 @@ def _update( ), ) - self._api_client._verify_response(return_value) - return return_value + self._api_client._verify_response(return_value) + return return_value - _a2a_tasks = None - _memories = None - _sandboxes = None - _sessions = None + _a2a_tasks = None + _memories = None + _sandboxes = None + _sessions = None - @property - def a2a_tasks(self) -> "a2a_tasks_module.A2aTasks": - if self._a2a_tasks is None: - try: - # We need to lazy load the a2a_tasks module to handle the - # possibility of ImportError when dependencies are not installed. - self._a2a_tasks = importlib.import_module(".a2a_tasks", __package__) - except ImportError as e: - raise ImportError( + @property + def a2a_tasks(self) -> "a2a_tasks_module.A2aTasks": + if self._a2a_tasks is None: + try: + # We need to lazy load the a2a_tasks module to handle the + # possibility of ImportError when dependencies are not installed. + self._a2a_tasks = importlib.import_module(".a2a_tasks", __package__) + except ImportError as e: + raise ImportError( "The 'agent_engines.a2a_tasks' module requires additional " "packages. Please install them using pip install " "google-cloud-aiplatform[agent_engines]" ) from e - return self._a2a_tasks.A2aTasks(self._api_client) # type: ignore[no-any-return] - - @property - def memories(self) -> "memories_module.Memories": - if self._memories is None: - try: - # We need to lazy load the memories module to handle the - # possibility of ImportError when dependencies are not installed. - self._memories = importlib.import_module(".memories", __package__) - except ImportError as e: - raise ImportError( + return self._a2a_tasks.A2aTasks(self._api_client) # type: ignore[no-any-return] + + @property + def memories(self) -> "memories_module.Memories": + if self._memories is None: + try: + # We need to lazy load the memories module to handle the + # possibility of ImportError when dependencies are not installed. + self._memories = importlib.import_module(".memories", __package__) + except ImportError as e: + raise ImportError( "The 'agent_engines.memories' module requires additional " "packages. Please install them using pip install " "google-cloud-aiplatform[agent_engines]" ) from e - return self._memories.Memories(self._api_client) # type: ignore[no-any-return] - - @property - def sandboxes(self) -> Any: - if self._sandboxes is None: - try: - # We need to lazy load the sandboxes module to handle the - # possibility of ImportError when dependencies are not installed. - self._sandboxes = importlib.import_module(".sandboxes", __package__) - except ImportError as e: - raise ImportError( + return self._memories.Memories(self._api_client) # type: ignore[no-any-return] + + @property + def sandboxes(self) -> Any: + if self._sandboxes is None: + try: + # We need to lazy load the sandboxes module to handle the + # possibility of ImportError when dependencies are not installed. + self._sandboxes = importlib.import_module(".sandboxes", __package__) + except ImportError as e: + raise ImportError( "The agent_engines.sandboxes module requires additional packages. " "Please install them using pip install " "google-cloud-aiplatform[agent_engines]" ) from e - return self._sandboxes.Sandboxes(self._api_client) - - @property - def sessions(self) -> "sessions_module.Sessions": - if self._sessions is None: - try: - # We need to lazy load the sessions module to handle the - # possibility of ImportError when dependencies are not installed. - self._sessions = importlib.import_module(".sessions", __package__) - except ImportError as e: - raise ImportError( + return self._sandboxes.Sandboxes(self._api_client) + + @property + def sessions(self) -> "sessions_module.Sessions": + if self._sessions is None: + try: + # We need to lazy load the sessions module to handle the + # possibility of ImportError when dependencies are not installed. + self._sessions = importlib.import_module(".sessions", __package__) + except ImportError as e: + raise ImportError( "The agent_engines.sessions module requires additional packages. " "Please install them using pip install " "google-cloud-aiplatform[agent_engines]" ) from e - return self._sessions.Sessions(self._api_client) # type: ignore[no-any-return] + return self._sessions.Sessions(self._api_client) # type: ignore[no-any-return] - def _list_pager( + def _list_pager( self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None ) -> Pager[types.ReasoningEngine]: - return Pager( - "reasoning_engines", - self._list, - self._list(config=config), - config, - ) - - def check_query_job( - self, - *, - name: str, - config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, - ) -> types.CheckQueryJobResult: - """Checks a query job on an agent engine and optionally returns the results. + return Pager( + "reasoning_engines", + self._list, + self._list(config=config), + config, + ) + + def cancel_query_job( + self, + *, + name: str, + operation_name: str, + config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CancelQueryJobResult: + return self._cancel_query_job( + name=name, operation_name=operation_name, config=config + ) + + def check_query_job( + self, + *, + name: str, + config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CheckQueryJobResult: + """Checks a query job on an agent engine and optionally returns the results. Args: name (str): @@ -1125,45 +1252,45 @@ def check_query_job( the following fields: - retrieve_result: Whether to retrieve the results of the query job. """ - from google.cloud import storage # type: ignore[attr-defined] - import json + from google.cloud import storage # type: ignore[attr-defined] + import json - if config is None: - config = types.CheckQueryJobAgentEngineConfig() - elif isinstance(config, dict): - config = types.CheckQueryJobAgentEngineConfig(**config) + if config is None: + config = types.CheckQueryJobAgentEngineConfig() + elif isinstance(config, dict): + config = types.CheckQueryJobAgentEngineConfig(**config) - raw_response = self._api_client.request("get", name, {}) - if hasattr(raw_response, "body"): - operation = ( + raw_response = self._api_client.request("get", name, {}) + if hasattr(raw_response, "body"): + operation = ( json.loads(raw_response.body) if isinstance(raw_response.body, str) else raw_response.body ) - else: - operation = raw_response + else: + operation = raw_response - status = "RUNNING" - if isinstance(operation, dict): - if operation.get("done"): - status = "FAILED" if operation.get("error") else "SUCCESS" + status = "RUNNING" + if isinstance(operation, dict): + if operation.get("done"): + status = "FAILED" if operation.get("error") else "SUCCESS" - response_dict = operation.get("response", {}) - output_gcs_uri = response_dict.get("outputGcsUri") or response_dict.get( + response_dict = operation.get("response", {}) + output_gcs_uri = response_dict.get("outputGcsUri") or response_dict.get( "output_gcs_uri" ) - error = operation.get("error") - else: - if getattr(operation, "done", False): - status = "FAILED" if getattr(operation, "error", None) else "SUCCESS" - - response_obj = getattr(operation, "response", None) - if isinstance(response_obj, dict): - output_gcs_uri = response_obj.get("outputGcsUri") or response_obj.get( + error = operation.get("error") + else: + if getattr(operation, "done", False): + status = "FAILED" if getattr(operation, "error", None) else "SUCCESS" + + response_obj = getattr(operation, "response", None) + if isinstance(response_obj, dict): + output_gcs_uri = response_obj.get("outputGcsUri") or response_obj.get( "output_gcs_uri" ) - else: - output_gcs_uri = ( + else: + output_gcs_uri = ( getattr( response_obj, "output_gcs_uri", @@ -1172,55 +1299,55 @@ def check_query_job( if response_obj else None ) - error = getattr(operation, "error", None) + error = getattr(operation, "error", None) - result_str = None - if status == "SUCCESS" and config.retrieve_result and output_gcs_uri: - storage_client = storage.Client( + result_str = None + if status == "SUCCESS" and config.retrieve_result and output_gcs_uri: + storage_client = storage.Client( project=self._api_client.project, credentials=self._api_client._credentials, ) - bucket_name = output_gcs_uri.replace("gs://", "").split("/")[0] - blob_name = output_gcs_uri.replace(f"gs://{bucket_name}/", "") - bucket = storage_client.bucket(bucket_name) - blob = bucket.blob(blob_name) - if blob.exists(): - result_str = blob.download_as_string().decode("utf-8") - else: - raise ValueError( + bucket_name = output_gcs_uri.replace("gs://", "").split("/")[0] + blob_name = output_gcs_uri.replace(f"gs://{bucket_name}/", "") + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + if blob.exists(): + result_str = blob.download_as_string().decode("utf-8") + else: + raise ValueError( f"Failed to retrieve blob results for {output_gcs_uri}" ) - elif status == "FAILED" and error: - result_str = str(error) + elif status == "FAILED" and error: + result_str = str(error) - return types.CheckQueryJobResult( + return types.CheckQueryJobResult( operation_name=name, output_gcs_uri=output_gcs_uri, status=status, result=result_str, ) - def _is_lightweight_creation( + def _is_lightweight_creation( self, agent: Any, config: types.AgentEngineConfig ) -> bool: - if ( + if ( agent or config.source_packages or config.developer_connect_source or config.agent_config_source or config.container_spec ): - return False - return True + return False + return True - def run_query_job( + def run_query_job( self, *, name: str, config: Optional[types.RunQueryJobAgentEngineConfigOrDict] = None, ) -> types.RunQueryJobResult: - """Launches a long-running query job on an Agent Engine + """Launches a long-running query job on an Agent Engine Args: name (str): @@ -1232,95 +1359,95 @@ def run_query_job( - query: The query to send to the agent engine. - output_gcs_uri: The GCS URI to use for the output. """ - from google.cloud import storage # type: ignore[attr-defined] - from google.api_core import exceptions - import uuid - - if config is None: - config = types.RunQueryJobAgentEngineConfig() - elif isinstance(config, dict): - config = types.RunQueryJobAgentEngineConfig(**config) - - if not config.query: - raise ValueError("`query` is required in the config object.") - if not config.output_gcs_uri: - raise ValueError("`output_gcs_uri` is required in the config object.") - - output_gcs_uri = config.output_gcs_uri - is_file = False - last_part = "" - if not output_gcs_uri.endswith("/"): - last_part = output_gcs_uri.split("/")[-1] - if "." in last_part: - is_file = True - - if is_file: - path_parts = output_gcs_uri.split("/") - file_name = path_parts[-1] - base_uri = "/".join(path_parts[:-1]) - name_parts = file_name.rsplit(".", 1) - if len(name_parts) == 2: - name_part, ext = name_parts[0], "." + name_parts[1] - else: - name_part = name_parts[0] - ext = "" - input_gcs_uri = f"{base_uri}/{name_part}_input{ext}" - else: - job_uuid = uuid.uuid4().hex - gcs_path = output_gcs_uri.rstrip("/") - input_gcs_uri = f"{gcs_path}/{job_uuid}_input.json" - output_gcs_uri = f"{gcs_path}/{job_uuid}_output.json" - - storage_client = storage.Client( + from google.cloud import storage # type: ignore[attr-defined] + from google.api_core import exceptions + import uuid + + if config is None: + config = types.RunQueryJobAgentEngineConfig() + elif isinstance(config, dict): + config = types.RunQueryJobAgentEngineConfig(**config) + + if not config.query: + raise ValueError("`query` is required in the config object.") + if not config.output_gcs_uri: + raise ValueError("`output_gcs_uri` is required in the config object.") + + output_gcs_uri = config.output_gcs_uri + is_file = False + last_part = "" + if not output_gcs_uri.endswith("/"): + last_part = output_gcs_uri.split("/")[-1] + if "." in last_part: + is_file = True + + if is_file: + path_parts = output_gcs_uri.split("/") + file_name = path_parts[-1] + base_uri = "/".join(path_parts[:-1]) + name_parts = file_name.rsplit(".", 1) + if len(name_parts) == 2: + name_part, ext = name_parts[0], "." + name_parts[1] + else: + name_part = name_parts[0] + ext = "" + input_gcs_uri = f"{base_uri}/{name_part}_input{ext}" + else: + job_uuid = uuid.uuid4().hex + gcs_path = output_gcs_uri.rstrip("/") + input_gcs_uri = f"{gcs_path}/{job_uuid}_input.json" + output_gcs_uri = f"{gcs_path}/{job_uuid}_output.json" + + storage_client = storage.Client( project=self._api_client.project, credentials=self._api_client._credentials ) - # Handle creating the bucket if it does not exist - bucket_name = config.output_gcs_uri.replace("gs://", "").split("/")[0] - bucket = storage_client.bucket(bucket_name) + # Handle creating the bucket if it does not exist + bucket_name = config.output_gcs_uri.replace("gs://", "").split("/")[0] + bucket = storage_client.bucket(bucket_name) - try: - bucket_exists = bucket.exists() - except exceptions.Forbidden as e: - raise ValueError( + try: + bucket_exists = bucket.exists() + except exceptions.Forbidden as e: + raise ValueError( f"Permission denied to check existence of bucket '{bucket_name}'. " "The service account may lack 'storage.buckets.get' permission." ) from e - if not bucket_exists: - try: - bucket.create() - except exceptions.Forbidden as e: - raise ValueError( + if not bucket_exists: + try: + bucket.create() + except exceptions.Forbidden as e: + raise ValueError( f"Permission denied to create bucket '{bucket_name}'. " "The service account may lack 'storage.buckets.create' permission." ) from e - input_blob_name = input_gcs_uri.replace(f"gs://{bucket_name}/", "") - blob = bucket.blob(input_blob_name) - blob.upload_from_string(config.query) + input_blob_name = input_gcs_uri.replace(f"gs://{bucket_name}/", "") + blob = bucket.blob(input_blob_name) + blob.upload_from_string(config.query) - new_config = types._RunQueryJobAgentEngineConfig( + new_config = types._RunQueryJobAgentEngineConfig( input_gcs_uri=input_gcs_uri, output_gcs_uri=output_gcs_uri, ) - # Proceed with sending the async query via the auto-generated method - operation = self._run_query_job(name=name, config=new_config) + # Proceed with sending the async query via the auto-generated method + operation = self._run_query_job(name=name, config=new_config) - return types.RunQueryJobResult( + return types.RunQueryJobResult( job_name=operation.name, input_gcs_uri=input_gcs_uri, output_gcs_uri=output_gcs_uri, ) - def get( + def get( self, *, name: str, config: Optional[types.GetAgentEngineConfigOrDict] = None, ) -> types.AgentEngine: - """Gets an agent engine. + """Gets an agent engine. Args: name (str): @@ -1328,24 +1455,24 @@ def get( "projects/123/locations/us-central1/reasoningEngines/456" or a shortened name such as "reasoningEngines/456". """ - api_resource = self._get(name=name, config=config) - agent_engine = types.AgentEngine( + api_resource = self._get(name=name, config=config) + agent_engine = types.AgentEngine( api_client=self, api_async_client=AsyncAgentEngines(api_client_=self._api_client), api_resource=api_resource, ) - if api_resource.spec: - self._register_api_methods(agent_engine=agent_engine) - return agent_engine + if api_resource.spec: + self._register_api_methods(agent_engine=agent_engine) + return agent_engine - def delete( + def delete( self, *, name: str, force: Optional[bool] = None, config: Optional[types.DeleteAgentEngineConfigOrDict] = None, ) -> types.DeleteAgentEngineOperation: - """ + """ Delete an Agent Engine resource. Args: @@ -1361,19 +1488,19 @@ def delete( Optional. Additional configurations for deleting the Agent Engine. """ - logger.info(f"Deleting AgentEngine resource: {name}") - operation = self._delete(name=name, force=force, config=config) - logger.info(f"Started AgentEngine delete operation: {operation.name}") - return operation + logger.info(f"Deleting AgentEngine resource: {name}") + operation = self._delete(name=name, force=force, config=config) + logger.info(f"Started AgentEngine delete operation: {operation.name}") + return operation - def create( + def create( self, *, agent_engine: Any = None, agent: Any = None, config: Optional[types.AgentEngineConfigOrDict] = None, ) -> types.AgentEngine: - """Creates an agent engine. + """Creates an agent engine. The Agent Engine will be an instance of the `agent_engine` that was passed in, running remotely on Vertex AI. @@ -1444,34 +1571,34 @@ def create( IOError: If ``config.requirements` is a string that corresponds to a nonexistent file. """ - if config is None: - config = {} - if isinstance(config, dict): - config = types.AgentEngineConfig.model_validate(config) - elif not isinstance(config, types.AgentEngineConfig): - raise TypeError( + if config is None: + config = {} + if isinstance(config, dict): + config = types.AgentEngineConfig.model_validate(config) + elif not isinstance(config, types.AgentEngineConfig): + raise TypeError( f"config must be a dict or AgentEngineConfig, but got {type(config)}." ) - context_spec = config.context_spec - if context_spec is not None: - # Conversion to a dict for _create_config - context_spec = json.loads(context_spec.model_dump_json()) - developer_connect_source = config.developer_connect_source - if developer_connect_source is not None: - developer_connect_source = json.loads( + context_spec = config.context_spec + if context_spec is not None: + # Conversion to a dict for _create_config + context_spec = json.loads(context_spec.model_dump_json()) + developer_connect_source = config.developer_connect_source + if developer_connect_source is not None: + developer_connect_source = json.loads( developer_connect_source.model_dump_json() ) - agent_config_source = config.agent_config_source - if agent_config_source is not None: - agent_config_source = json.loads(agent_config_source.model_dump_json()) - if agent and agent_engine: - raise ValueError("Please specify only one of `agent` or `agent_engine`.") - elif agent_engine: - raise DeprecationWarning( + agent_config_source = config.agent_config_source + if agent_config_source is not None: + agent_config_source = json.loads(agent_config_source.model_dump_json()) + if agent and agent_engine: + raise ValueError("Please specify only one of `agent` or `agent_engine`.") + elif agent_engine: + raise DeprecationWarning( "The `agent_engine` argument is deprecated. Please use `agent` instead." ) - agent = agent or agent_engine - api_config = self._create_config( + agent = agent or agent_engine + api_config = self._create_config( mode="create", agent=agent, identity_type=config.identity_type, @@ -1505,47 +1632,47 @@ def create( agent_config_source=agent_config_source, container_spec=config.container_spec, ) - operation = self._create(config=api_config) - reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( + operation = self._create(config=api_config) + reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( operation_name=operation.name ) - logger.info( + logger.info( "View progress and logs at https://console.cloud.google.com/logs/query?" f"project={self._api_client.project}" "&query=resource.type%3D%22aiplatform.googleapis.com%2FReasoningEngine%22%0A" f"resource.labels.reasoning_engine_id%3D%22{reasoning_engine_id}%22." ) - if not self._is_lightweight_creation(agent, config): - poll_interval_seconds = 10 - else: - poll_interval_seconds = 1 # Lightweight agent engine resource creation. - operation = _agent_engines_utils._await_operation( + if not self._is_lightweight_creation(agent, config): + poll_interval_seconds = 10 + else: + poll_interval_seconds = 1 # Lightweight agent engine resource creation. + operation = _agent_engines_utils._await_operation( operation_name=operation.name, get_operation_fn=self._get_agent_operation, poll_interval_seconds=poll_interval_seconds, ) - agent_engine = types.AgentEngine( + agent_engine = types.AgentEngine( api_client=self, api_async_client=AsyncAgentEngines(api_client_=self._api_client), api_resource=operation.response, ) - if agent_engine.api_resource: - logger.info("Agent Engine created. To use it in another session:") - logger.info( + if agent_engine.api_resource: + logger.info("Agent Engine created. To use it in another session:") + logger.info( f"agent_engine=client.agent_engines.get(name='{agent_engine.api_resource.name}')" ) - elif operation.error: - raise RuntimeError(f"Failed to create Agent Engine: {operation.error}") - else: - logger.warning("The operation returned an empty response.") - if not self._is_lightweight_creation(agent, config): - # If the user did not provide an agent_engine (e.g. lightweight - # provisioning), it will not have any API methods registered. - agent_engine = self._register_api_methods(agent_engine=agent_engine) - return agent_engine # type: ignore[no-any-return] - - def _set_source_code_spec( + elif operation.error: + raise RuntimeError(f"Failed to create Agent Engine: {operation.error}") + else: + logger.warning("The operation returned an empty response.") + if not self._is_lightweight_creation(agent, config): + # If the user did not provide an agent_engine (e.g. lightweight + # provisioning), it will not have any API methods registered. + agent_engine = self._register_api_methods(agent_engine=agent_engine) + return agent_engine # type: ignore[no-any-return] + + def _set_source_code_spec( self, *, spec: types.ReasoningEngineSpecDict, @@ -1567,118 +1694,118 @@ def _set_source_code_spec( types.ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict ] = None, ) -> None: - """Sets source_code_spec for agent engine inside the `spec`.""" - source_code_spec = types.ReasoningEngineSpecSourceCodeSpecDict() - if source_packages and not agent_config_source: - source_packages = _agent_engines_utils._validate_packages_or_raise( + """Sets source_code_spec for agent engine inside the `spec`.""" + source_code_spec = types.ReasoningEngineSpecSourceCodeSpecDict() + if source_packages and not agent_config_source: + source_packages = _agent_engines_utils._validate_packages_or_raise( packages=source_packages, build_options=build_options, ) - update_masks.append("spec.source_code_spec.inline_source.source_archive") - source_code_spec["inline_source"] = { # type: ignore[typeddict-item] + update_masks.append("spec.source_code_spec.inline_source.source_archive") + source_code_spec["inline_source"] = { # type: ignore[typeddict-item] "source_archive": _agent_engines_utils._create_base64_encoded_tarball( source_packages=source_packages ) } - elif developer_connect_source: - update_masks.append("spec.source_code_spec.developer_connect_source") - source_code_spec["developer_connect_source"] = { + elif developer_connect_source: + update_masks.append("spec.source_code_spec.developer_connect_source") + source_code_spec["developer_connect_source"] = { "config": developer_connect_source } - elif not agent_config_source: - raise ValueError( + elif not agent_config_source: + raise ValueError( "Please specify one of `source_packages`, `developer_connect_source`, " "or `agent_config_source`." ) - if class_methods is not None: - update_masks.append("spec.class_methods") - class_methods_spec_list = ( + if class_methods is not None: + update_masks.append("spec.class_methods") + class_methods_spec_list = ( _agent_engines_utils._class_methods_to_class_methods_spec( class_methods=class_methods ) ) - spec["class_methods"] = [ + spec["class_methods"] = [ _agent_engines_utils._to_dict(class_method_spec) for class_method_spec in class_methods_spec_list ] - elif image_spec is None: - raise ValueError( + elif image_spec is None: + raise ValueError( "`class_methods` must be specified if `source_packages`, " "`developer_connect_source`, or `agent_config_source` is " "specified without a Dockerfile or `image_spec`." ) - if image_spec is not None: - if entrypoint_module or entrypoint_object or requirements_file: - raise ValueError( + if image_spec is not None: + if entrypoint_module or entrypoint_object or requirements_file: + raise ValueError( "`image_spec` cannot be specified alongside `entrypoint_module`, " "`entrypoint_object`, or `requirements_file`, as they are " "mutually exclusive." ) - if agent_config_source: - raise ValueError( + if agent_config_source: + raise ValueError( "`image_spec` cannot be specified alongside `agent_config_source`, " "as they are mutually exclusive." ) - update_masks.append("spec.source_code_spec.image_spec") - source_code_spec["image_spec"] = image_spec - spec["source_code_spec"] = source_code_spec - return + update_masks.append("spec.source_code_spec.image_spec") + source_code_spec["image_spec"] = image_spec + spec["source_code_spec"] = source_code_spec + return - update_masks.append("spec.source_code_spec.python_spec.version") - python_spec: types.ReasoningEngineSpecSourceCodeSpecPythonSpecDict = { + update_masks.append("spec.source_code_spec.python_spec.version") + python_spec: types.ReasoningEngineSpecSourceCodeSpecPythonSpecDict = { "version": sys_version, } - if agent_config_source is not None: - if entrypoint_module or entrypoint_object: - logger.warning( + if agent_config_source is not None: + if entrypoint_module or entrypoint_object: + logger.warning( "`entrypoint_module` and `entrypoint_object` are ignored when " "`agent_config_source` is specified, as they are pre-defined." ) - if source_packages: - source_packages = _agent_engines_utils._validate_packages_or_raise( + if source_packages: + source_packages = _agent_engines_utils._validate_packages_or_raise( packages=source_packages, build_options=build_options, ) - update_masks.append( + update_masks.append( "spec.source_code_spec.agent_config_source.inline_source.source_archive" ) - agent_config_source["inline_source"] = { # type: ignore[typeddict-item] + agent_config_source["inline_source"] = { # type: ignore[typeddict-item] "source_archive": _agent_engines_utils._create_base64_encoded_tarball( source_packages=source_packages ) } - update_masks.append("spec.source_code_spec.agent_config_source") - source_code_spec["agent_config_source"] = agent_config_source + update_masks.append("spec.source_code_spec.agent_config_source") + source_code_spec["agent_config_source"] = agent_config_source - if requirements_file is not None: - update_masks.append( + if requirements_file is not None: + update_masks.append( "spec.source_code_spec.python_spec.requirements_file" ) - python_spec["requirements_file"] = requirements_file - source_code_spec["python_spec"] = python_spec + python_spec["requirements_file"] = requirements_file + source_code_spec["python_spec"] = python_spec - spec["source_code_spec"] = source_code_spec - return + spec["source_code_spec"] = source_code_spec + return - if not entrypoint_module: - raise ValueError( + if not entrypoint_module: + raise ValueError( "`entrypoint_module` must be specified if `source_packages` or `developer_connect_source` is specified." ) - update_masks.append("spec.source_code_spec.python_spec.entrypoint_module") - python_spec["entrypoint_module"] = entrypoint_module - if not entrypoint_object: - raise ValueError( + update_masks.append("spec.source_code_spec.python_spec.entrypoint_module") + python_spec["entrypoint_module"] = entrypoint_module + if not entrypoint_object: + raise ValueError( "`entrypoint_object` must be specified if `source_packages` or `developer_connect_source` is specified." ) - update_masks.append("spec.source_code_spec.python_spec.entrypoint_object") - python_spec["entrypoint_object"] = entrypoint_object - if requirements_file is not None: - update_masks.append("spec.source_code_spec.python_spec.requirements_file") - python_spec["requirements_file"] = requirements_file - source_code_spec["python_spec"] = python_spec - spec["source_code_spec"] = source_code_spec - - def _set_package_spec( + update_masks.append("spec.source_code_spec.python_spec.entrypoint_object") + python_spec["entrypoint_object"] = entrypoint_object + if requirements_file is not None: + update_masks.append("spec.source_code_spec.python_spec.requirements_file") + python_spec["requirements_file"] = requirements_file + source_code_spec["python_spec"] = python_spec + spec["source_code_spec"] = source_code_spec + + def _set_package_spec( self, *, spec: types.ReasoningEngineSpecDict, @@ -1692,29 +1819,29 @@ def _set_package_spec( sys_version: str, build_options: Optional[dict[str, list[str]]] = None, ) -> None: - """Sets package spec for agent engine.""" - project = self._api_client.project - if project is None: - raise ValueError("project must be set using `vertexai.Client`.") - location = self._api_client.location - if location is None: - raise ValueError("location must be set using `vertexai.Client`.") - gcs_dir_name = gcs_dir_name or _agent_engines_utils._DEFAULT_GCS_DIR_NAME - staging_bucket = _agent_engines_utils._validate_staging_bucket_or_raise( + """Sets package spec for agent engine.""" + project = self._api_client.project + if project is None: + raise ValueError("project must be set using `vertexai.Client`.") + location = self._api_client.location + if location is None: + raise ValueError("location must be set using `vertexai.Client`.") + gcs_dir_name = gcs_dir_name or _agent_engines_utils._DEFAULT_GCS_DIR_NAME + staging_bucket = _agent_engines_utils._validate_staging_bucket_or_raise( staging_bucket=staging_bucket, ) - requirements = _agent_engines_utils._validate_requirements_or_raise( + requirements = _agent_engines_utils._validate_requirements_or_raise( agent=agent, requirements=requirements, ) - extra_packages = _agent_engines_utils._validate_packages_or_raise( + extra_packages = _agent_engines_utils._validate_packages_or_raise( packages=extra_packages, build_options=build_options, ) - # Prepares the Agent Engine for creation/update in Vertex AI. This - # involves packaging and uploading the artifacts for agent_engine, - # requirements and extra_packages to `staging_bucket/gcs_dir_name`. - _agent_engines_utils._prepare( + # Prepares the Agent Engine for creation/update in Vertex AI. This + # involves packaging and uploading the artifacts for agent_engine, + # requirements and extra_packages to `staging_bucket/gcs_dir_name`. + _agent_engines_utils._prepare( agent=agent, requirements=requirements, project=project, @@ -1724,9 +1851,9 @@ def _set_package_spec( extra_packages=extra_packages, credentials=self._api_client._credentials, ) - # Update the package spec. - update_masks.append("spec.package_spec.pickle_object_gcs_uri") - package_spec: types.ReasoningEngineSpecPackageSpecDict = { + # Update the package spec. + update_masks.append("spec.package_spec.pickle_object_gcs_uri") + package_spec: types.ReasoningEngineSpecPackageSpecDict = { "python_version": sys_version, "pickle_object_gcs_uri": "{}/{}/{}".format( staging_bucket, @@ -1734,31 +1861,31 @@ def _set_package_spec( _agent_engines_utils._BLOB_FILENAME, ), } - if extra_packages: - update_masks.append("spec.package_spec.dependency_files_gcs_uri") - package_spec["dependency_files_gcs_uri"] = "{}/{}/{}".format( + if extra_packages: + update_masks.append("spec.package_spec.dependency_files_gcs_uri") + package_spec["dependency_files_gcs_uri"] = "{}/{}/{}".format( staging_bucket, gcs_dir_name, _agent_engines_utils._EXTRA_PACKAGES_FILE, ) - if requirements: - update_masks.append("spec.package_spec.requirements_gcs_uri") - package_spec["requirements_gcs_uri"] = "{}/{}/{}".format( + if requirements: + update_masks.append("spec.package_spec.requirements_gcs_uri") + package_spec["requirements_gcs_uri"] = "{}/{}/{}".format( staging_bucket, gcs_dir_name, _agent_engines_utils._REQUIREMENTS_FILE, ) - spec["package_spec"] = package_spec + spec["package_spec"] = package_spec - update_masks.append("spec.class_methods") - if class_methods is not None: - class_methods_spec_list = ( + update_masks.append("spec.class_methods") + if class_methods is not None: + class_methods_spec_list = ( _agent_engines_utils._class_methods_to_class_methods_spec( class_methods=class_methods ) ) - else: - class_methods_spec_list = ( + else: + class_methods_spec_list = ( _agent_engines_utils._generate_class_methods_spec_or_raise( agent=agent, operations=_agent_engines_utils._get_registered_operations( @@ -1766,12 +1893,12 @@ def _set_package_spec( ), ) ) - spec["class_methods"] = [ + spec["class_methods"] = [ _agent_engines_utils._to_dict(class_method_spec) for class_method_spec in class_methods_spec_list ] - def _create_config( + def _create_config( self, *, mode: str, @@ -1813,79 +1940,79 @@ def _create_config( ] = None, container_spec: Optional[types.ReasoningEngineSpecContainerSpecDict] = None, ) -> types.UpdateAgentEngineConfigDict: - import sys - - config: types.UpdateAgentEngineConfigDict = {} - update_masks = [] - if mode not in ["create", "update"]: - raise ValueError(f"Unsupported mode: {mode}") - if agent is None: - if requirements is not None: - raise ValueError("requirements must be None if agent is None.") - if extra_packages is not None: - raise ValueError("extra_packages must be None if agent is None.") - if display_name is not None: - update_masks.append("display_name") - config["display_name"] = display_name - if description is not None: - update_masks.append("description") - config["description"] = description - if context_spec is not None: - update_masks.append("context_spec") - config["context_spec"] = context_spec - if encryption_spec is not None: - update_masks.append("encryption_spec") - config["encryption_spec"] = encryption_spec - if labels is not None: - update_masks.append("labels") - config["labels"] = labels - - if agent_framework == "google-adk": - env_vars = _agent_engines_utils._add_telemetry_enablement_env(env_vars) - - if python_version: - sys_version = python_version - else: - sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" - - if agent: - if source_packages: - raise ValueError( + import sys + + config: types.UpdateAgentEngineConfigDict = {} + update_masks = [] + if mode not in ["create", "update"]: + raise ValueError(f"Unsupported mode: {mode}") + if agent is None: + if requirements is not None: + raise ValueError("requirements must be None if agent is None.") + if extra_packages is not None: + raise ValueError("extra_packages must be None if agent is None.") + if display_name is not None: + update_masks.append("display_name") + config["display_name"] = display_name + if description is not None: + update_masks.append("description") + config["description"] = description + if context_spec is not None: + update_masks.append("context_spec") + config["context_spec"] = context_spec + if encryption_spec is not None: + update_masks.append("encryption_spec") + config["encryption_spec"] = encryption_spec + if labels is not None: + update_masks.append("labels") + config["labels"] = labels + + if agent_framework == "google-adk": + env_vars = _agent_engines_utils._add_telemetry_enablement_env(env_vars) + + if python_version: + sys_version = python_version + else: + sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + if agent: + if source_packages: + raise ValueError( "If you have provided `source_packages` in `config`, please " "do not specify `agent` in `agent_engines.create()` or " "`agent_engines.update()`." ) - if developer_connect_source: - raise ValueError( + if developer_connect_source: + raise ValueError( "If you have provided `developer_connect_source` in `config`, please " "do not specify `agent` in `agent_engines.create()` or " "`agent_engines.update()`." ) - elif source_packages and developer_connect_source: - raise ValueError( + elif source_packages and developer_connect_source: + raise ValueError( "Please specify only one of `source_packages` or `developer_connect_source` in `config`." ) - if container_spec: - if agent: - raise ValueError( + if container_spec: + if agent: + raise ValueError( "If you have provided `container_spec` in `config`, please " "do not specify `agent` in `agent_engines.create()` or " "`agent_engines.update()`." ) - if source_packages or developer_connect_source: - raise ValueError( + if source_packages or developer_connect_source: + raise ValueError( "If you have provided `container_spec` in `config`, please " "do not specify `source_packages` or `developer_connect_source` in `config`." ) - agent_engine_spec: Any = None - if agent: - agent_engine_spec = {} - agent = _agent_engines_utils._validate_agent_or_raise(agent=agent) - if _agent_engines_utils._is_adk_agent(agent): - env_vars = _agent_engines_utils._add_telemetry_enablement_env(env_vars) - self._set_package_spec( + agent_engine_spec: Any = None + if agent: + agent_engine_spec = {} + agent = _agent_engines_utils._validate_agent_or_raise(agent=agent) + if _agent_engines_utils._is_adk_agent(agent): + env_vars = _agent_engines_utils._add_telemetry_enablement_env(env_vars) + self._set_package_spec( spec=agent_engine_spec, update_masks=update_masks, agent=agent, @@ -1897,14 +2024,14 @@ def _create_config( sys_version=sys_version, build_options=build_options, ) - elif ( + elif ( source_packages or developer_connect_source or image_spec or agent_config_source ): - agent_engine_spec = {} - self._set_source_code_spec( + agent_engine_spec = {} + self._set_source_code_spec( spec=agent_engine_spec, update_masks=update_masks, source_packages=source_packages, @@ -1918,23 +2045,23 @@ def _create_config( image_spec=image_spec, agent_config_source=agent_config_source, ) - elif container_spec: - agent_engine_spec = {} - if class_methods is not None: - update_masks.append("spec.class_methods") - class_methods_spec_list = ( + elif container_spec: + agent_engine_spec = {} + if class_methods is not None: + update_masks.append("spec.class_methods") + class_methods_spec_list = ( _agent_engines_utils._class_methods_to_class_methods_spec( class_methods=class_methods ) ) - agent_engine_spec["class_methods"] = [ + agent_engine_spec["class_methods"] = [ _agent_engines_utils._to_dict(class_method_spec) for class_method_spec in class_methods_spec_list ] - update_masks.append("spec.container_spec") - agent_engine_spec["container_spec"] = container_spec + update_masks.append("spec.container_spec") + agent_engine_spec["container_spec"] = container_spec - is_deployment_spec_updated = ( + is_deployment_spec_updated = ( env_vars is not None or psc_interface_config is not None or min_instances is not None @@ -1942,8 +2069,8 @@ def _create_config( or resource_limits is not None or container_concurrency is not None ) - if agent_engine_spec is None and is_deployment_spec_updated: - raise ValueError( + if agent_engine_spec is None and is_deployment_spec_updated: + raise ValueError( "To update `env_vars`, `psc_interface_config`, `min_instances`, " "`max_instances`, `resource_limits`, or `container_concurrency`, " "you must also provide the `agent` variable or the source code " @@ -1951,9 +2078,9 @@ def _create_config( "`agent_config_source`)." ) - if agent_engine_spec is not None: - if is_deployment_spec_updated: - ( + if agent_engine_spec is not None: + if is_deployment_spec_updated: + ( deployment_spec, deployment_update_masks, ) = self._generate_deployment_spec_or_raise( @@ -1964,59 +2091,59 @@ def _create_config( resource_limits=resource_limits, container_concurrency=container_concurrency, ) - update_masks.extend(deployment_update_masks) - agent_engine_spec["deployment_spec"] = deployment_spec + update_masks.extend(deployment_update_masks) + agent_engine_spec["deployment_spec"] = deployment_spec - if agent_server_mode: - if not agent_engine_spec.get("deployment_spec"): - agent_engine_spec["deployment_spec"] = ( + if agent_server_mode: + if not agent_engine_spec.get("deployment_spec"): + agent_engine_spec["deployment_spec"] = ( types.ReasoningEngineSpecDeploymentSpecDict() ) - agent_engine_spec["deployment_spec"][ + agent_engine_spec["deployment_spec"][ "agent_server_mode" ] = agent_server_mode - agent_engine_spec["agent_framework"] = ( + agent_engine_spec["agent_framework"] = ( _agent_engines_utils._get_agent_framework( agent_framework=agent_framework, agent=agent, ) ) - if hasattr(agent, "agent_card"): - agent_card = getattr(agent, "agent_card") - if agent_card: - try: - agent_engine_spec["agent_card"] = agent_card.model_dump( + if hasattr(agent, "agent_card"): + agent_card = getattr(agent, "agent_card") + if agent_card: + try: + agent_engine_spec["agent_card"] = agent_card.model_dump( exclude_none=True ) - except TypeError as e: - raise ValueError( + except TypeError as e: + raise ValueError( f"Failed to convert agent card to dict (serialization error): {e}" ) from e - update_masks.append("spec.agent_framework") + update_masks.append("spec.agent_framework") - if identity_type is not None or service_account is not None: - if agent_engine_spec is None: - agent_engine_spec = {} + if identity_type is not None or service_account is not None: + if agent_engine_spec is None: + agent_engine_spec = {} - if identity_type is not None: - agent_engine_spec["identity_type"] = identity_type - update_masks.append("spec.identity_type") - if service_account is not None: - # Clear the field in case of empty service_account. - if service_account: - agent_engine_spec["service_account"] = service_account - update_masks.append("spec.service_account") + if identity_type is not None: + agent_engine_spec["identity_type"] = identity_type + update_masks.append("spec.identity_type") + if service_account is not None: + # Clear the field in case of empty service_account. + if service_account: + agent_engine_spec["service_account"] = service_account + update_masks.append("spec.service_account") - if agent_engine_spec is not None: - config["spec"] = agent_engine_spec + if agent_engine_spec is not None: + config["spec"] = agent_engine_spec - if update_masks and mode == "update": - config["update_mask"] = ",".join(update_masks) - return config + if update_masks and mode == "update": + config["update_mask"] = ",".join(update_masks) + return config - def _generate_deployment_spec_or_raise( + def _generate_deployment_spec_or_raise( self, *, env_vars: Optional[dict[str, Union[str, Any]]] = None, @@ -2026,83 +2153,83 @@ def _generate_deployment_spec_or_raise( resource_limits: Optional[dict[str, str]] = None, container_concurrency: Optional[int] = None, ) -> Tuple[dict[str, Any], Sequence[str]]: - deployment_spec: dict[str, Any] = {} - update_masks = [] - if env_vars: - deployment_spec["env"] = [] - deployment_spec["secret_env"] = [] - if isinstance(env_vars, dict): - self._update_deployment_spec_with_env_vars_dict_or_raise( + deployment_spec: dict[str, Any] = {} + update_masks = [] + if env_vars: + deployment_spec["env"] = [] + deployment_spec["secret_env"] = [] + if isinstance(env_vars, dict): + self._update_deployment_spec_with_env_vars_dict_or_raise( deployment_spec=deployment_spec, env_vars=env_vars, ) - else: - raise TypeError(f"env_vars must be a dict, but got {type(env_vars)}.") - if deployment_spec.get("env"): - update_masks.append("spec.deployment_spec.env") - if deployment_spec.get("secret_env"): - update_masks.append("spec.deployment_spec.secret_env") - if psc_interface_config: - deployment_spec["psc_interface_config"] = psc_interface_config - update_masks.append("spec.deployment_spec.psc_interface_config") - if min_instances is not None: - if not 0 <= min_instances <= 10: - raise ValueError( + else: + raise TypeError(f"env_vars must be a dict, but got {type(env_vars)}.") + if deployment_spec.get("env"): + update_masks.append("spec.deployment_spec.env") + if deployment_spec.get("secret_env"): + update_masks.append("spec.deployment_spec.secret_env") + if psc_interface_config: + deployment_spec["psc_interface_config"] = psc_interface_config + update_masks.append("spec.deployment_spec.psc_interface_config") + if min_instances is not None: + if not 0 <= min_instances <= 10: + raise ValueError( f"min_instances must be between 0 and 10. Got {min_instances}" ) - deployment_spec["min_instances"] = min_instances - update_masks.append("spec.deployment_spec.min_instances") - if max_instances is not None: - if psc_interface_config and not 1 <= max_instances <= 100: - raise ValueError( + deployment_spec["min_instances"] = min_instances + update_masks.append("spec.deployment_spec.min_instances") + if max_instances is not None: + if psc_interface_config and not 1 <= max_instances <= 100: + raise ValueError( f"max_instances must be between 1 and 100 when PSC-I is enabled. Got {max_instances}" ) - elif not psc_interface_config and not 1 <= max_instances <= 1000: - raise ValueError( + elif not psc_interface_config and not 1 <= max_instances <= 1000: + raise ValueError( f"max_instances must be between 1 and 1000. Got {max_instances}" ) - deployment_spec["max_instances"] = max_instances - update_masks.append("spec.deployment_spec.max_instances") - if resource_limits: - _agent_engines_utils._validate_resource_limits_or_raise( + deployment_spec["max_instances"] = max_instances + update_masks.append("spec.deployment_spec.max_instances") + if resource_limits: + _agent_engines_utils._validate_resource_limits_or_raise( resource_limits=resource_limits ) - deployment_spec["resource_limits"] = resource_limits - update_masks.append("spec.deployment_spec.resource_limits") - if container_concurrency: - deployment_spec["container_concurrency"] = container_concurrency - update_masks.append("spec.deployment_spec.container_concurrency") - return deployment_spec, update_masks - - def _update_deployment_spec_with_env_vars_dict_or_raise( + deployment_spec["resource_limits"] = resource_limits + update_masks.append("spec.deployment_spec.resource_limits") + if container_concurrency: + deployment_spec["container_concurrency"] = container_concurrency + update_masks.append("spec.deployment_spec.container_concurrency") + return deployment_spec, update_masks + + def _update_deployment_spec_with_env_vars_dict_or_raise( self, *, deployment_spec: dict[str, Any], env_vars: dict[str, Any], ) -> None: - for key, value in env_vars.items(): - if isinstance(value, dict): - if "secret_env" not in deployment_spec: - deployment_spec["secret_env"] = [] - deployment_spec["secret_env"].append({"name": key, "secret_ref": value}) - elif isinstance(value, str): - if "env" not in deployment_spec: - deployment_spec["env"] = [] - deployment_spec["env"].append({"name": key, "value": value}) - else: - raise TypeError( + for key, value in env_vars.items(): + if isinstance(value, dict): + if "secret_env" not in deployment_spec: + deployment_spec["secret_env"] = [] + deployment_spec["secret_env"].append({"name": key, "secret_ref": value}) + elif isinstance(value, str): + if "env" not in deployment_spec: + deployment_spec["env"] = [] + deployment_spec["env"].append({"name": key, "value": value}) + else: + raise TypeError( f"Unknown value type in env_vars for {key}. " f"Must be a str or SecretRef: {value}" ) - def _register_api_methods( + def _register_api_methods( self, *, agent_engine: types.AgentEngine, ) -> types.AgentEngine: - """Registers the API methods for the agent engine.""" - try: - _agent_engines_utils._register_api_methods_or_raise( + """Registers the API methods for the agent engine.""" + try: + _agent_engines_utils._register_api_methods_or_raise( agent_engine=agent_engine, wrap_operation_fn={ "": _agent_engines_utils._wrap_query_operation, # type: ignore[dict-item] @@ -2112,16 +2239,16 @@ def _register_api_methods( "a2a_extension": _agent_engines_utils._wrap_a2a_operation, }, ) - except Exception as e: - logger.warning( + except Exception as e: + logger.warning( _agent_engines_utils._FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE, e ) - return agent_engine + return agent_engine - def list( + def list( self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None ) -> Iterator[types.AgentEngine]: - """List all instances of Agent Engine matching the filter. + """List all instances of Agent Engine matching the filter. Example Usage: @@ -2142,14 +2269,14 @@ def list( Iterable[AgentEngine]: An iterable of Agent Engines matching the filter. """ - for reasoning_engine in self._list_pager(config=config): - yield types.AgentEngine( + for reasoning_engine in self._list_pager(config=config): + yield types.AgentEngine( api_client=self, api_async_client=AsyncAgentEngines(api_client_=self._api_client), api_resource=reasoning_engine, ) - def update( + def update( self, *, name: str, @@ -2157,7 +2284,7 @@ def update( agent_engine: Any = None, config: types.AgentEngineConfigOrDict, ) -> types.AgentEngine: - """Updates an existing Agent Engine. + """Updates an existing Agent Engine. This method updates the configuration of an existing Agent Engine running remotely, which is identified by its name. @@ -2195,40 +2322,40 @@ def update( IOError: If `config.requirements` is a string that corresponds to a nonexistent file. """ - if isinstance(config, dict): - config = types.AgentEngineConfig.model_validate(config) - elif not isinstance(config, types.AgentEngineConfig): - raise TypeError( + if isinstance(config, dict): + config = types.AgentEngineConfig.model_validate(config) + elif not isinstance(config, types.AgentEngineConfig): + raise TypeError( f"config must be a dict or AgentEngineConfig, but got {type(config)}." ) - context_spec = config.context_spec - if context_spec is not None: - # Conversion to a dict for _create_config - context_spec = json.loads(context_spec.model_dump_json()) - developer_connect_source = config.developer_connect_source - if developer_connect_source is not None: - developer_connect_source = json.loads( + context_spec = config.context_spec + if context_spec is not None: + # Conversion to a dict for _create_config + context_spec = json.loads(context_spec.model_dump_json()) + developer_connect_source = config.developer_connect_source + if developer_connect_source is not None: + developer_connect_source = json.loads( developer_connect_source.model_dump_json() ) - agent_config_source = config.agent_config_source - if agent_config_source is not None: - agent_config_source = json.loads(agent_config_source.model_dump_json()) - if agent and agent_engine: - raise ValueError("Please specify only one of `agent` or `agent_engine`.") - elif agent_engine: - raise DeprecationWarning( + agent_config_source = config.agent_config_source + if agent_config_source is not None: + agent_config_source = json.loads(agent_config_source.model_dump_json()) + if agent and agent_engine: + raise ValueError("Please specify only one of `agent` or `agent_engine`.") + elif agent_engine: + raise DeprecationWarning( "The `agent_engine` argument is deprecated. Please use `agent` instead." ) - image_spec = config.image_spec - if image_spec is not None: - # Conversion to a dict for _create_config - image_spec = json.loads(image_spec.model_dump_json()) - container_spec = config.container_spec - if container_spec is not None: - # Conversion to a dict for _create_config - container_spec = json.loads(container_spec.model_dump_json()) - agent = agent or agent_engine - api_config = self._create_config( + image_spec = config.image_spec + if image_spec is not None: + # Conversion to a dict for _create_config + image_spec = json.loads(image_spec.model_dump_json()) + container_spec = config.container_spec + if container_spec is not None: + # Conversion to a dict for _create_config + container_spec = json.loads(container_spec.model_dump_json()) + agent = agent or agent_engine + api_config = self._create_config( mode="update", agent=agent, identity_type=config.identity_type, @@ -2260,109 +2387,109 @@ def update( agent_config_source=agent_config_source, container_spec=container_spec, ) - operation = self._update(name=name, config=api_config) - reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( + operation = self._update(name=name, config=api_config) + reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( resource_name=name ) - logger.info( + logger.info( "View progress and logs at https://console.cloud.google.com/logs/query?" f"project={self._api_client.project}" "&query=resource.type%3D%22aiplatform.googleapis.com%2FReasoningEngine%22%0A" f"resource.labels.reasoning_engine_id%3D%22{reasoning_engine_id}%22." ) - operation = _agent_engines_utils._await_operation( + operation = _agent_engines_utils._await_operation( operation_name=operation.name, get_operation_fn=self._get_agent_operation, ) - agent_engine = types.AgentEngine( + agent_engine = types.AgentEngine( api_client=self, api_async_client=AsyncAgentEngines(api_client_=self._api_client), api_resource=operation.response, ) - if agent_engine.api_resource: - logger.info("Agent Engine updated. To use it in another session:") - logger.info( + if agent_engine.api_resource: + logger.info("Agent Engine updated. To use it in another session:") + logger.info( f"agent_engine=client.agent_engines.get(name='{agent_engine.api_resource.name}')" ) - elif operation.error: - raise RuntimeError(f"Failed to update Agent Engine: {operation.error}") - if agent_engine.api_resource.spec: - self._register_api_methods(agent_engine=agent_engine) - return agent_engine # type: ignore[no-any-return] + elif operation.error: + raise RuntimeError(f"Failed to update Agent Engine: {operation.error}") + if agent_engine.api_resource.spec: + self._register_api_methods(agent_engine=agent_engine) + return agent_engine # type: ignore[no-any-return] - def _stream_query( + def _stream_query( self, *, name: str, config: Optional[types.QueryAgentEngineConfigOrDict] = None ) -> Iterator[Any]: - """Streams the response of the agent engine.""" - parameter_model = types._QueryAgentEngineRequestParameters( + """Streams the response of the agent engine.""" + parameter_model = types._QueryAgentEngineRequestParameters( name=name, config=config, ) - request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) - else: - path = "{name}:streamQuery?alt=sse" - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - http_options = None - if ( + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) + else: + path = "{name}:streamQuery?alt=sse" + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + http_options = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) - for response in self._api_client.request_streamed( + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + for response in self._api_client.request_streamed( "post", path, request_dict, http_options ): - yield response + yield response - # TODO: b/436704146 - Replace with generated methods - # TODO: b/437129724 - Add replay test for async stream query - async def _async_stream_query( + # TODO: b/436704146 - Replace with generated methods + # TODO: b/437129724 - Add replay test for async stream query + async def _async_stream_query( self, *, name: str, config: Optional[types.QueryAgentEngineConfigOrDict] = None, ) -> AsyncIterator[Any]: - """Streams the response of the agent engine asynchronously.""" - parameter_model = types._QueryAgentEngineRequestParameters( + """Streams the response of the agent engine asynchronously.""" + parameter_model = types._QueryAgentEngineRequestParameters( name=name, config=config, ) - request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) - request_url_dict = request_dict.get("_url") - if request_url_dict: - path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) - else: - path = "{name}:streamQuery?alt=sse" - query_params = request_dict.get("_query") - if query_params: - path = f"{path}?{urlencode(query_params)}" - # TODO: remove the hack that pops config. - request_dict.pop("config", None) - http_options = None - if ( + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) + else: + path = "{name}:streamQuery?alt=sse" + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + http_options = None + if ( parameter_model.config is not None and parameter_model.config.http_options is not None ): - http_options = parameter_model.config.http_options + http_options = parameter_model.config.http_options - request_dict = _common.convert_to_dict(request_dict) - request_dict = _common.encode_unserializable_types(request_dict) - async_iterator = await self._api_client.async_request_streamed( + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + async_iterator = await self._api_client.async_request_streamed( "post", path, request_dict, http_options ) - async for response in async_iterator: - yield response + async for response in async_iterator: + yield response - def create_memory( + def create_memory( self, *, name: str, @@ -2370,8 +2497,8 @@ def create_memory( scope: dict[str, str], config: Optional[types.AgentEngineMemoryConfigOrDict] = None, ) -> types.AgentEngineMemoryOperation: - """Deprecated. Use agent_engines.memories.create instead.""" - warnings.warn( + """Deprecated. Use agent_engines.memories.create instead.""" + warnings.warn( ( "agent_engines.create_memory is deprecated. " "Use agent_engines.memories.create instead." @@ -2379,21 +2506,21 @@ def create_memory( DeprecationWarning, stacklevel=2, ) - return self.memories.create( + return self.memories.create( name=name, fact=fact, scope=scope, config=config, ) - def delete_memory( + def delete_memory( self, *, name: str, config: Optional[types.DeleteAgentEngineMemoryConfigOrDict] = None, ) -> types.DeleteAgentEngineMemoryOperation: - """Deprecated. Use agent_engines.memories.delete instead.""" - warnings.warn( + """Deprecated. Use agent_engines.memories.delete instead.""" + warnings.warn( ( "agent_engines.delete_memory is deprecated. " "Use agent_engines.memories.delete instead." @@ -2401,9 +2528,9 @@ def delete_memory( DeprecationWarning, stacklevel=2, ) - return self.memories.delete(name=name, config=config) + return self.memories.delete(name=name, config=config) - def generate_memories( + def generate_memories( self, *, name: str, @@ -2419,8 +2546,8 @@ def generate_memories( scope: Optional[dict[str, str]] = None, config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None, ) -> types.AgentEngineGenerateMemoriesOperation: - """Deprecated. Use agent_engines.memories.generate instead.""" - warnings.warn( + """Deprecated. Use agent_engines.memories.generate instead.""" + warnings.warn( ( "agent_engines.generate_memories is deprecated. " "Use agent_engines.memories.generate instead." @@ -2428,7 +2555,7 @@ def generate_memories( DeprecationWarning, stacklevel=2, ) - return self.memories.generate( + return self.memories.generate( name=name, vertex_session_source=vertex_session_source, direct_contents_source=direct_contents_source, @@ -2437,14 +2564,14 @@ def generate_memories( config=config, ) - def get_memory( + def get_memory( self, *, name: str, config: Optional[types.GetAgentEngineMemoryConfigOrDict] = None, ) -> types.Memory: - """Deprecated. Use agent_engines.memories.get instead.""" - warnings.warn( + """Deprecated. Use agent_engines.memories.get instead.""" + warnings.warn( ( "agent_engines.get_memory is deprecated. " "Use agent_engines.memories.get instead." @@ -2452,16 +2579,16 @@ def get_memory( DeprecationWarning, stacklevel=2, ) - return self.memories.get(name=name, config=config) + return self.memories.get(name=name, config=config) - def list_memories( + def list_memories( self, *, name: str, config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None, ) -> Iterator[types.Memory]: - """Deprecated. Use agent_engines.memories.list instead.""" - warnings.warn( + """Deprecated. Use agent_engines.memories.list instead.""" + warnings.warn( ( "agent_engines.list_memories is deprecated. " "Use agent_engines.memories.list instead." @@ -2469,9 +2596,9 @@ def list_memories( DeprecationWarning, stacklevel=2, ) - return self.memories.list(name=name, config=config) + return self.memories.list(name=name, config=config) - def retrieve_memories( + def retrieve_memories( self, *, name: str, @@ -2484,8 +2611,8 @@ def retrieve_memories( ] = None, config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, ) -> Iterator[types.RetrieveMemoriesResponseRetrievedMemory]: - """Deprecated. Use agent_engines.memories.retrieve instead.""" - warnings.warn( + """Deprecated. Use agent_engines.memories.retrieve instead.""" + warnings.warn( ( "agent_engines.retrieve_memories is deprecated. " "Use agent_engines.memories.retrieve instead." @@ -2493,7 +2620,7 @@ def retrieve_memories( DeprecationWarning, stacklevel=2, ) - return self.memories.retrieve( + return self.memories.retrieve( name=name, scope=scope, similarity_search_params=similarity_search_params, @@ -2501,15 +2628,15 @@ def retrieve_memories( config=config, ) - def create_session( + def create_session( self, *, name: str, user_id: str, config: Optional[types.CreateAgentEngineSessionConfigOrDict] = None, ) -> types.AgentEngineSessionOperation: - """Deprecated. Use agent_engines.sessions.create instead.""" - warnings.warn( + """Deprecated. Use agent_engines.sessions.create instead.""" + warnings.warn( ( "agent_engines.create_session is deprecated. " "Use agent_engines.sessions.create instead." @@ -2517,16 +2644,16 @@ def create_session( DeprecationWarning, stacklevel=2, ) - return self.sessions.create(name=name, user_id=user_id, config=config) + return self.sessions.create(name=name, user_id=user_id, config=config) - def delete_session( + def delete_session( self, *, name: str, config: Optional[types.DeleteAgentEngineSessionConfigOrDict] = None, ) -> types.DeleteAgentEngineSessionOperation: - """Deprecated. Use agent_engines.sessions.delete instead.""" - warnings.warn( + """Deprecated. Use agent_engines.sessions.delete instead.""" + warnings.warn( ( "agent_engines.delete_session is deprecated. " "Use agent_engines.sessions.delete instead." @@ -2534,16 +2661,16 @@ def delete_session( DeprecationWarning, stacklevel=2, ) - return self.sessions.delete(name=name, config=config) + return self.sessions.delete(name=name, config=config) - def get_session( + def get_session( self, *, name: str, config: Optional[types.GetAgentEngineSessionConfigOrDict] = None, ) -> types.Session: - """Deprecated. Use agent_engines.sessions.get instead.""" - warnings.warn( + """Deprecated. Use agent_engines.sessions.get instead.""" + warnings.warn( ( "agent_engines.get_session is deprecated. " "Use agent_engines.sessions.get instead." @@ -2551,16 +2678,16 @@ def get_session( DeprecationWarning, stacklevel=2, ) - return self.sessions.get(name=name, config=config) + return self.sessions.get(name=name, config=config) - def list_sessions( + def list_sessions( self, *, name: str, config: Optional[types.ListAgentEngineSessionsConfigOrDict] = None, ) -> Iterator[types.Session]: - """Deprecated. Use agent_engines.sessions.list instead.""" - warnings.warn( + """Deprecated. Use agent_engines.sessions.list instead.""" + warnings.warn( ( "agent_engines.list_sessions is deprecated. " "Use agent_engines.sessions.list instead." @@ -2568,9 +2695,9 @@ def list_sessions( DeprecationWarning, stacklevel=2, ) - return self.sessions.list(name=name, config=config) + return self.sessions.list(name=name, config=config) - def append_session_event( + def append_session_event( self, *, name: str, @@ -2579,8 +2706,8 @@ def append_session_event( timestamp: datetime.datetime, config: Optional[types.AppendAgentEngineSessionEventConfigOrDict] = None, ) -> types.AppendAgentEngineSessionEventResponse: - """Deprecated. Use agent_engines.sessions.events.append instead.""" - warnings.warn( + """Deprecated. Use agent_engines.sessions.events.append instead.""" + warnings.warn( ( "agent_engines.append_session_event is deprecated. " "Use agent_engines.sessions.events.append instead." @@ -2588,7 +2715,7 @@ def append_session_event( DeprecationWarning, stacklevel=2, ) - return self.sessions.events.append( + return self.sessions.events.append( name=name, author=author, invocation_id=invocation_id, @@ -2596,14 +2723,14 @@ def append_session_event( config=config, ) - def list_session_events( + def list_session_events( self, *, name: str, config: Optional[types.ListAgentEngineSessionEventsConfigOrDict] = None, ) -> Iterator[types.SessionEvent]: - """Deprecated. Use agent_engines.sessions.events.list instead.""" - warnings.warn( + """Deprecated. Use agent_engines.sessions.events.list instead.""" + warnings.warn( ( "agent_engines.list_session_events is deprecated. " "Use agent_engines.sessions.events.list instead." @@ -2611,7 +2738,7 @@ def list_session_events( DeprecationWarning, stacklevel=2, ) - return self.sessions.events.list(name=name, config=config) + return self.sessions.events.list(name=name, config=config) class AsyncAgentEngines(_api_module.BaseModule): diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index f293f22bc7..711c2b96cf 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -26,6 +26,7 @@ from .common import _AppendAgentEngineTaskEventRequestParameters from .common import _AssembleDatasetParameters from .common import _AssessDatasetParameters +from .common import _CancelQueryJobAgentEngineRequestParameters from .common import _CheckQueryJobAgentEngineRequestParameters from .common import _CreateAgentEngineMemoryRequestParameters from .common import _CreateAgentEngineRequestParameters @@ -192,6 +193,12 @@ from .common import BleuResults from .common import BleuResultsDict from .common import BleuResultsOrDict +from .common import CancelQueryJobAgentEngineConfig +from .common import CancelQueryJobAgentEngineConfigDict +from .common import CancelQueryJobAgentEngineConfigOrDict +from .common import CancelQueryJobResult +from .common import CancelQueryJobResultDict +from .common import CancelQueryJobResultOrDict from .common import CandidateResponse from .common import CandidateResponseDict from .common import CandidateResponseOrDict @@ -1654,6 +1661,12 @@ "CheckQueryJobResult", "CheckQueryJobResultDict", "CheckQueryJobResultOrDict", + "CancelQueryJobAgentEngineConfig", + "CancelQueryJobAgentEngineConfigDict", + "CancelQueryJobAgentEngineConfigOrDict", + "CancelQueryJobResult", + "CancelQueryJobResultDict", + "CancelQueryJobResultOrDict", "_RunQueryJobAgentEngineConfig", "_RunQueryJobAgentEngineConfigDict", "_RunQueryJobAgentEngineConfigOrDict", diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index d017fda1d1..a45d637cf6 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -6323,6 +6323,80 @@ class CheckQueryJobResultDict(TypedDict, total=False): CheckQueryJobResultOrDict = Union[CheckQueryJobResult, CheckQueryJobResultDict] +class CancelQueryJobAgentEngineConfig(_common.BaseModel): + """Config for canceling a query job on an agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CancelQueryJobAgentEngineConfigDict(TypedDict, total=False): + """Config for canceling a query job on an agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CancelQueryJobAgentEngineConfigOrDict = Union[ + CancelQueryJobAgentEngineConfig, CancelQueryJobAgentEngineConfigDict +] + + +class _CancelQueryJobAgentEngineRequestParameters(_common.BaseModel): + """Parameters for canceling a query job on an agent engine.""" + + name: Optional[str] = Field( + default=None, description="""Name of the reasoning engine resource.""" + ) + operation_name: Optional[str] = Field( + default=None, + description="""Name of the longrunning operation returned from run_query_job.""", + ) + config: Optional[CancelQueryJobAgentEngineConfig] = Field( + default=None, description="""""" + ) + + +class _CancelQueryJobAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for canceling a query job on an agent engine.""" + + name: Optional[str] + """Name of the reasoning engine resource.""" + + operation_name: Optional[str] + """Name of the longrunning operation returned from run_query_job.""" + + config: Optional[CancelQueryJobAgentEngineConfigDict] + """""" + + +_CancelQueryJobAgentEngineRequestParametersOrDict = Union[ + _CancelQueryJobAgentEngineRequestParameters, + _CancelQueryJobAgentEngineRequestParametersDict, +] + + +class CancelQueryJobResult(_common.BaseModel): + """Result of canceling a query job.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CancelQueryJobResultDict(TypedDict, total=False): + """Result of canceling a query job.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CancelQueryJobResultOrDict = Union[ + CancelQueryJobResult, CancelQueryJobResultDict +] + + class _RunQueryJobAgentEngineConfig(_common.BaseModel): """Config for running a query job on an agent engine."""