Skip to content

Commit 8ed44b1

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: In run_query_job, rename gcs_bucket to gcs_uri and allow the case that user sets the filename for the output.
PiperOrigin-RevId: 889620568
1 parent 6f7b12c commit 8ed44b1

3 files changed

Lines changed: 1006 additions & 912 deletions

File tree

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 103 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2948,7 +2948,7 @@ def test_run_query_job_agent_engine(self, mock_uuid, get_mock, mock_storage_clie
29482948
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
29492949
config={
29502950
"query": _TEST_QUERY_PROMPT,
2951-
"gcs_bucket": "gs://my-input-bucket/",
2951+
"gcs_uri": "gs://my-input-bucket/",
29522952
},
29532953
)
29542954

@@ -2959,17 +2959,17 @@ def test_run_query_job_agent_engine(self, mock_uuid, get_mock, mock_storage_clie
29592959

29602960
assert result == _genai_types.RunQueryJobResult(
29612961
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
2962-
input_gcs_uri="gs://my-input-bucket/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2963-
output_gcs_uri="gs://my-input-bucket/output_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2962+
input_gcs_uri="gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json",
2963+
output_gcs_uri="gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_output.json",
29642964
)
29652965

29662966
request_mock.assert_called_with(
29672967
"post",
29682968
f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:asyncQuery",
29692969
{
29702970
"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME},
2971-
"inputGcsUri": "gs://my-input-bucket/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2972-
"outputGcsUri": "gs://my-input-bucket/output_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
2971+
"inputGcsUri": "gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json",
2972+
"outputGcsUri": "gs://my-input-bucket/b92b9b89-4585-4146-8ee5-22fe99802a8e_output.json",
29732973
},
29742974
None,
29752975
)
@@ -2980,12 +2980,12 @@ def test_run_query_job_agent_engine_missing_query(self):
29802980
):
29812981
self.client.agent_engines.run_query_job(
29822982
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2983-
config={"gcs_bucket": "gs://my-input-bucket/"},
2983+
config={"gcs_uri": "gs://my-input-bucket/"},
29842984
)
29852985

2986-
def test_run_query_job_agent_engine_missing_bucket(self):
2986+
def test_run_query_job_agent_engine_missing_uri(self):
29872987
with pytest.raises(
2988-
ValueError, match="`gcs_bucket` is required in the config object."
2988+
ValueError, match="`gcs_uri` is required in the config object."
29892989
):
29902990
self.client.agent_engines.run_query_job(
29912991
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
@@ -3008,7 +3008,7 @@ def test_run_query_job_agent_engine_missing_cloud_run_job(self, get_mock):
30083008
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
30093009
config={
30103010
"query": _TEST_QUERY_PROMPT,
3011-
"gcs_bucket": "gs://my-input-bucket/",
3011+
"gcs_uri": "gs://my-input-bucket/",
30123012
},
30133013
)
30143014

@@ -3053,10 +3053,103 @@ def test_run_query_job_agent_engine_bucket_creation_forbidden(
30533053
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
30543054
config={
30553055
"query": _TEST_QUERY_PROMPT,
3056-
"gcs_bucket": "gs://my-input-bucket/",
3056+
"gcs_uri": "gs://my-input-bucket/",
30573057
},
30583058
)
30593059

3060+
@mock.patch("google.cloud.storage.Client")
3061+
@mock.patch.object(agent_engines.AgentEngines, "_get")
3062+
@mock.patch("uuid.uuid4")
3063+
def test_run_query_job_agent_engine_file_uri(
3064+
self, mock_uuid, get_mock, mock_storage_client
3065+
):
3066+
with mock.patch.object(
3067+
self.client.agent_engines._api_client, "request"
3068+
) as request_mock:
3069+
request_mock.return_value = genai_types.HttpResponse(
3070+
body='{"name": "projects/123/locations/us-central1/reasoningEngines/456/operations/789"}'
3071+
)
3072+
3073+
mock_bucket = mock.Mock()
3074+
mock_bucket.exists.return_value = True
3075+
mock_blob = mock.Mock()
3076+
mock_bucket.blob.return_value = mock_blob
3077+
mock_storage_client.return_value.bucket.return_value = mock_bucket
3078+
3079+
get_mock.return_value = _genai_types.ReasoningEngine(
3080+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3081+
spec=_genai_types.ReasoningEngineSpec(
3082+
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(
3083+
env=[_genai_types.EnvVar(name="input_gcs_uri", value="")]
3084+
)
3085+
),
3086+
)
3087+
3088+
result = self.client.agent_engines.run_query_job(
3089+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3090+
config={
3091+
"query": _TEST_QUERY_PROMPT,
3092+
"gcs_uri": "gs://my-input-bucket/path/output.json",
3093+
},
3094+
)
3095+
3096+
mock_blob.upload_from_string.assert_called_once_with(_TEST_QUERY_PROMPT)
3097+
mock_bucket.blob.assert_called_with("path/output_input.json")
3098+
3099+
assert result == _genai_types.RunQueryJobResult(
3100+
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
3101+
input_gcs_uri="gs://my-input-bucket/path/output_input.json",
3102+
output_gcs_uri="gs://my-input-bucket/path/output.json",
3103+
)
3104+
3105+
@mock.patch("google.cloud.storage.Client")
3106+
@mock.patch.object(agent_engines.AgentEngines, "_get")
3107+
@mock.patch("uuid.uuid4")
3108+
def test_run_query_job_agent_engine_directory_no_slash(
3109+
self, mock_uuid, get_mock, mock_storage_client
3110+
):
3111+
with mock.patch.object(
3112+
self.client.agent_engines._api_client, "request"
3113+
) as request_mock:
3114+
request_mock.return_value = genai_types.HttpResponse(
3115+
body='{"name": "projects/123/locations/us-central1/reasoningEngines/456/operations/789"}'
3116+
)
3117+
3118+
mock_bucket = mock.Mock()
3119+
mock_bucket.exists.return_value = True
3120+
mock_blob = mock.Mock()
3121+
mock_bucket.blob.return_value = mock_blob
3122+
mock_storage_client.return_value.bucket.return_value = mock_bucket
3123+
3124+
mock_uuid.return_value.hex = "b92b9b89-4585-4146-8ee5-22fe99802a8e"
3125+
3126+
get_mock.return_value = _genai_types.ReasoningEngine(
3127+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3128+
spec=_genai_types.ReasoningEngineSpec(
3129+
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(
3130+
env=[_genai_types.EnvVar(name="input_gcs_uri", value="")]
3131+
)
3132+
),
3133+
)
3134+
3135+
result = self.client.agent_engines.run_query_job(
3136+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
3137+
config={
3138+
"query": _TEST_QUERY_PROMPT,
3139+
"gcs_uri": "gs://my-input-bucket/path",
3140+
},
3141+
)
3142+
3143+
mock_bucket.blob.assert_called_with(
3144+
"path/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json"
3145+
)
3146+
3147+
assert result == _genai_types.RunQueryJobResult(
3148+
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
3149+
input_gcs_uri="gs://my-input-bucket/path/b92b9b89-4585-4146-8ee5-22fe99802a8e_input.json",
3150+
output_gcs_uri="gs://my-input-bucket/path/b92b9b89-4585-4146-8ee5-22fe99802a8e_output.json",
3151+
)
3152+
30603153
def test_query_agent_engine_async(self):
30613154
agent = self.client.agent_engines._register_api_methods(
30623155
agent_engine=_genai_types.AgentEngine(

0 commit comments

Comments
 (0)