diff --git a/tests/unit/vertexai/genai/test_agent_engines.py b/tests/unit/vertexai/genai/test_agent_engines.py index 2f954e709a..5db086e0ad 100644 --- a/tests/unit/vertexai/genai/test_agent_engines.py +++ b/tests/unit/vertexai/genai/test_agent_engines.py @@ -3162,6 +3162,28 @@ def test_query_agent_engine_async(self): None, ) + def test_cancel_query_job_agent_engine(self): + with mock.patch.object( + self.client.agent_engines._api_client, "request" + ) as request_mock: + request_mock.return_value = genai_types.HttpResponse(body="{}") + + result = self.client.agent_engines.cancel_query_job( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + operation_name=_TEST_AGENT_ENGINE_OPERATION_NAME, + ) + + assert isinstance(result, _genai_types.CancelQueryJobResult) + request_mock.assert_called_with( + "post", + f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:cancelAsyncQuery", + { + "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME}, + "operationName": _TEST_AGENT_ENGINE_OPERATION_NAME, + }, + None, + ) + def test_check_query_job_agent_engine(self): with mock.patch.object( self.client.agent_engines._api_client, "request" diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py index 4272e5b34e..b546410814 100644 --- a/vertexai/_genai/agent_engines.py +++ b/vertexai/_genai/agent_engines.py @@ -49,6 +49,47 @@ logger.setLevel(logging.INFO) +def _CancelQueryJobAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _CancelQueryJobAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["operation_name"]) is not None: + setv(to_object, ["operationName"], getv(from_object, ["operation_name"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _CancelQueryJobAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _CancelQueryJobResult_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + def _CheckQueryJobAgentEngineConfig_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -396,15 +437,88 @@ def _UpdateAgentEngineRequestParameters_to_vertex( class AgentEngines(_api_module.BaseModule): + def _cancel_query_job( + self, + *, + name: str, + operation_name: str, + config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CancelQueryJobResult: + """Cancels a long-running query job on an Agent Engine.""" + + parameter_model = types._CancelQueryJobAgentEngineRequestParameters( + name=name, + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _CancelQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:cancelAsyncQuery".format_map(request_url_dict) + else: + path = "{name}:cancelAsyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CancelQueryJobResult_from_vertex(response_dict) + + return_value = types.CancelQueryJobResult._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + def _check_query_job( self, *, name: str, config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, ) -> types.CheckQueryJobResult: - """ - Query an Agent Engine asynchronously. - """ + """Query an Agent Engine asynchronously.""" parameter_model = types._CheckQueryJobAgentEngineRequestParameters( name=name, @@ -1108,6 +1222,17 @@ def _list_pager( config, ) + def cancel_query_job( + self, + *, + name: str, + operation_name: str, + config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CancelQueryJobResult: + return self._cancel_query_job( + name=name, operation_name=operation_name, config=config + ) + def check_query_job( self, *, diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index 4db9b8e44e..8630b5dc4a 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -26,6 +26,7 @@ from .common import _AppendAgentEngineTaskEventRequestParameters from .common import _AssembleDatasetParameters from .common import _AssessDatasetParameters +from .common import _CancelQueryJobAgentEngineRequestParameters from .common import _CheckQueryJobAgentEngineRequestParameters from .common import _CreateAgentEngineMemoryRequestParameters from .common import _CreateAgentEngineRequestParameters @@ -191,6 +192,12 @@ from .common import BleuResults from .common import BleuResultsDict from .common import BleuResultsOrDict +from .common import CancelQueryJobAgentEngineConfig +from .common import CancelQueryJobAgentEngineConfigDict +from .common import CancelQueryJobAgentEngineConfigOrDict +from .common import CancelQueryJobResult +from .common import CancelQueryJobResultDict +from .common import CancelQueryJobResultOrDict from .common import CandidateResponse from .common import CandidateResponseDict from .common import CandidateResponseOrDict @@ -1647,6 +1654,12 @@ "CheckQueryJobResult", "CheckQueryJobResultDict", "CheckQueryJobResultOrDict", + "CancelQueryJobAgentEngineConfig", + "CancelQueryJobAgentEngineConfigDict", + "CancelQueryJobAgentEngineConfigOrDict", + "CancelQueryJobResult", + "CancelQueryJobResultDict", + "CancelQueryJobResultOrDict", "_RunQueryJobAgentEngineConfig", "_RunQueryJobAgentEngineConfigDict", "_RunQueryJobAgentEngineConfigOrDict", diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 535fc9cb26..34fddb0998 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -6277,6 +6277,78 @@ class CheckQueryJobResultDict(TypedDict, total=False): CheckQueryJobResultOrDict = Union[CheckQueryJobResult, CheckQueryJobResultDict] +class CancelQueryJobAgentEngineConfig(_common.BaseModel): + """Config for canceling a query job on an agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CancelQueryJobAgentEngineConfigDict(TypedDict, total=False): + """Config for canceling a query job on an agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CancelQueryJobAgentEngineConfigOrDict = Union[ + CancelQueryJobAgentEngineConfig, CancelQueryJobAgentEngineConfigDict +] + + +class _CancelQueryJobAgentEngineRequestParameters(_common.BaseModel): + """Parameters for canceling a query job on an agent engine.""" + + name: Optional[str] = Field( + default=None, description="""Name of the reasoning engine resource.""" + ) + operation_name: Optional[str] = Field( + default=None, + description="""Name of the longrunning operation returned from run_query_job.""", + ) + config: Optional[CancelQueryJobAgentEngineConfig] = Field( + default=None, description="""""" + ) + + +class _CancelQueryJobAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for canceling a query job on an agent engine.""" + + name: Optional[str] + """Name of the reasoning engine resource.""" + + operation_name: Optional[str] + """Name of the longrunning operation returned from run_query_job.""" + + config: Optional[CancelQueryJobAgentEngineConfigDict] + """""" + + +_CancelQueryJobAgentEngineRequestParametersOrDict = Union[ + _CancelQueryJobAgentEngineRequestParameters, + _CancelQueryJobAgentEngineRequestParametersDict, +] + + +class CancelQueryJobResult(_common.BaseModel): + """Result of canceling a query job.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CancelQueryJobResultDict(TypedDict, total=False): + """Result of canceling a query job.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CancelQueryJobResultOrDict = Union[CancelQueryJobResult, CancelQueryJobResultDict] + + class _RunQueryJobAgentEngineConfig(_common.BaseModel): """Config for running a query job on an agent engine."""