Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/unit/vertexai/genai/test_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3162,6 +3162,28 @@ def test_query_agent_engine_async(self):
None,
)

def test_cancel_query_job_agent_engine(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(body="{}")

result = self.client.agent_engines.cancel_query_job(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
operation_name=_TEST_AGENT_ENGINE_OPERATION_NAME,
)

assert isinstance(result, _genai_types.CancelQueryJobResult)
request_mock.assert_called_with(
"post",
f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:cancelAsyncQuery",
{
"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME},
"operationName": _TEST_AGENT_ENGINE_OPERATION_NAME,
},
None,
)

def test_check_query_job_agent_engine(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
Expand Down
131 changes: 128 additions & 3 deletions vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,47 @@
logger.setLevel(logging.INFO)


def _CancelQueryJobAgentEngineConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}

return to_object


def _CancelQueryJobAgentEngineRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["name"]) is not None:
setv(to_object, ["_url", "name"], getv(from_object, ["name"]))

if getv(from_object, ["operation_name"]) is not None:
setv(to_object, ["operationName"], getv(from_object, ["operation_name"]))

if getv(from_object, ["config"]) is not None:
setv(
to_object,
["config"],
_CancelQueryJobAgentEngineConfig_to_vertex(
getv(from_object, ["config"]), to_object
),
)

return to_object


def _CancelQueryJobResult_from_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}

return to_object


def _CheckQueryJobAgentEngineConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -396,15 +437,88 @@ def _UpdateAgentEngineRequestParameters_to_vertex(

class AgentEngines(_api_module.BaseModule):

def _cancel_query_job(
self,
*,
name: str,
operation_name: str,
config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None,
) -> types.CancelQueryJobResult:
"""Cancels a long-running query job on an Agent Engine."""

parameter_model = types._CancelQueryJobAgentEngineRequestParameters(
name=name,
operation_name=operation_name,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _CancelQueryJobAgentEngineRequestParameters_to_vertex(
parameter_model
)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{name}:cancelAsyncQuery".format_map(request_url_dict)
else:
path = "{name}:cancelAsyncQuery"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request("post", path, request_dict, http_options)

response_dict = {} if not response.body else json.loads(response.body)

if self._api_client.vertexai:
response_dict = _CancelQueryJobResult_from_vertex(response_dict)

return_value = types.CancelQueryJobResult._from_response(
response=response_dict,
kwargs=(
{
"config": {
"response_schema": getattr(
parameter_model.config, "response_schema", None
),
"response_json_schema": getattr(
parameter_model.config, "response_json_schema", None
),
"include_all_fields": getattr(
parameter_model.config, "include_all_fields", None
),
}
}
if getattr(parameter_model, "config", None)
else {}
),
)

self._api_client._verify_response(return_value)
return return_value

def _check_query_job(
self,
*,
name: str,
config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None,
) -> types.CheckQueryJobResult:
"""
Query an Agent Engine asynchronously.
"""
"""Query an Agent Engine asynchronously."""

parameter_model = types._CheckQueryJobAgentEngineRequestParameters(
name=name,
Expand Down Expand Up @@ -1108,6 +1222,17 @@ def _list_pager(
config,
)

def cancel_query_job(
self,
*,
name: str,
operation_name: str,
config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None,
) -> types.CancelQueryJobResult:
return self._cancel_query_job(
name=name, operation_name=operation_name, config=config
)

def check_query_job(
self,
*,
Expand Down
13 changes: 13 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .common import _AppendAgentEngineTaskEventRequestParameters
from .common import _AssembleDatasetParameters
from .common import _AssessDatasetParameters
from .common import _CancelQueryJobAgentEngineRequestParameters
from .common import _CheckQueryJobAgentEngineRequestParameters
from .common import _CreateAgentEngineMemoryRequestParameters
from .common import _CreateAgentEngineRequestParameters
Expand Down Expand Up @@ -191,6 +192,12 @@
from .common import BleuResults
from .common import BleuResultsDict
from .common import BleuResultsOrDict
from .common import CancelQueryJobAgentEngineConfig
from .common import CancelQueryJobAgentEngineConfigDict
from .common import CancelQueryJobAgentEngineConfigOrDict
from .common import CancelQueryJobResult
from .common import CancelQueryJobResultDict
from .common import CancelQueryJobResultOrDict
from .common import CandidateResponse
from .common import CandidateResponseDict
from .common import CandidateResponseOrDict
Expand Down Expand Up @@ -1647,6 +1654,12 @@
"CheckQueryJobResult",
"CheckQueryJobResultDict",
"CheckQueryJobResultOrDict",
"CancelQueryJobAgentEngineConfig",
"CancelQueryJobAgentEngineConfigDict",
"CancelQueryJobAgentEngineConfigOrDict",
"CancelQueryJobResult",
"CancelQueryJobResultDict",
"CancelQueryJobResultOrDict",
"_RunQueryJobAgentEngineConfig",
"_RunQueryJobAgentEngineConfigDict",
"_RunQueryJobAgentEngineConfigOrDict",
Expand Down
72 changes: 72 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6277,6 +6277,78 @@ class CheckQueryJobResultDict(TypedDict, total=False):
CheckQueryJobResultOrDict = Union[CheckQueryJobResult, CheckQueryJobResultDict]


class CancelQueryJobAgentEngineConfig(_common.BaseModel):
"""Config for canceling a query job on an agent engine."""

http_options: Optional[genai_types.HttpOptions] = Field(
default=None, description="""Used to override HTTP request options."""
)


class CancelQueryJobAgentEngineConfigDict(TypedDict, total=False):
"""Config for canceling a query job on an agent engine."""

http_options: Optional[genai_types.HttpOptionsDict]
"""Used to override HTTP request options."""


CancelQueryJobAgentEngineConfigOrDict = Union[
CancelQueryJobAgentEngineConfig, CancelQueryJobAgentEngineConfigDict
]


class _CancelQueryJobAgentEngineRequestParameters(_common.BaseModel):
"""Parameters for canceling a query job on an agent engine."""

name: Optional[str] = Field(
default=None, description="""Name of the reasoning engine resource."""
)
operation_name: Optional[str] = Field(
default=None,
description="""Name of the longrunning operation returned from run_query_job.""",
)
config: Optional[CancelQueryJobAgentEngineConfig] = Field(
default=None, description=""""""
)


class _CancelQueryJobAgentEngineRequestParametersDict(TypedDict, total=False):
"""Parameters for canceling a query job on an agent engine."""

name: Optional[str]
"""Name of the reasoning engine resource."""

operation_name: Optional[str]
"""Name of the longrunning operation returned from run_query_job."""

config: Optional[CancelQueryJobAgentEngineConfigDict]
""""""


_CancelQueryJobAgentEngineRequestParametersOrDict = Union[
_CancelQueryJobAgentEngineRequestParameters,
_CancelQueryJobAgentEngineRequestParametersDict,
]


class CancelQueryJobResult(_common.BaseModel):
"""Result of canceling a query job."""

http_options: Optional[genai_types.HttpOptions] = Field(
default=None, description="""Used to override HTTP request options."""
)


class CancelQueryJobResultDict(TypedDict, total=False):
"""Result of canceling a query job."""

http_options: Optional[genai_types.HttpOptionsDict]
"""Used to override HTTP request options."""


CancelQueryJobResultOrDict = Union[CancelQueryJobResult, CancelQueryJobResultDict]


class _RunQueryJobAgentEngineConfig(_common.BaseModel):
"""Config for running a query job on an agent engine."""

Expand Down
Loading