Skip to content

Commit bb9f6e5

Browse files
cleop-google authored and copybara-github committed
feat: Add Vertex Dataset input and output options for batch jobs
PiperOrigin-RevId: 898998803
1 parent 2fb714b commit bb9f6e5

5 files changed

Lines changed: 289 additions & 5 deletions

File tree

google/genai/_transformers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,11 @@ def t_batch_job_source(
10121012
src = types.BatchJobSource(**src)
10131013
if is_duck_type_of(src, types.BatchJobSource):
10141014
vertex_sources = sum(
1015-
[src.gcs_uri is not None, src.bigquery_uri is not None] # type: ignore[union-attr]
1015+
[
1016+
src.gcs_uri is not None, # type: ignore[union-attr]
1017+
src.bigquery_uri is not None, # type: ignore[union-attr]
1018+
src.vertex_dataset_name is not None, # type: ignore[union-attr]
1019+
]
10161020
)
10171021
mldev_sources = sum([
10181022
src.inlined_requests is not None, # type: ignore[union-attr]
@@ -1021,7 +1025,7 @@ def t_batch_job_source(
10211025
if client.vertexai:
10221026
if mldev_sources or vertex_sources != 1:
10231027
raise ValueError(
1024-
'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
1028+
'Exactly one of `gcs_uri`, `bigquery_uri`, or `vertex_dataset_name` must be set, other '
10251029
'sources are not supported in Vertex AI.'
10261030
)
10271031
else:
@@ -1046,6 +1050,11 @@ def t_batch_job_source(
10461050
format='bigquery',
10471051
bigquery_uri=src,
10481052
)
1053+
elif re.match(r'^projects/[^/]+/locations/[^/]+/datasets/[^/]+$', src):
1054+
return types.BatchJobSource(
1055+
format='vertex-dataset',
1056+
vertex_dataset_name=src,
1057+
)
10491058
elif src.startswith('files/'):
10501059
return types.BatchJobSource(
10511060
file_name=src,

google/genai/batches.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ def _BatchJobDestination_from_vertex(
130130
getv(from_object, ['bigqueryDestination', 'outputUri']),
131131
)
132132

133+
if getv(from_object, ['vertexMultimodalDatasetDestination']) is not None:
134+
setv(
135+
to_object,
136+
['vertex_dataset'],
137+
_VertexMultimodalDatasetDestination_from_vertex(
138+
getv(from_object, ['vertexMultimodalDatasetDestination']), to_object
139+
),
140+
)
141+
133142
return to_object
134143

135144

@@ -169,6 +178,15 @@ def _BatchJobDestination_to_vertex(
169178
' Vertex AI.'
170179
)
171180

181+
if getv(from_object, ['vertex_dataset']) is not None:
182+
setv(
183+
to_object,
184+
['vertexMultimodalDatasetDestination'],
185+
_VertexMultimodalDatasetDestination_to_vertex(
186+
getv(from_object, ['vertex_dataset']), to_object
187+
),
188+
)
189+
172190
return to_object
173191

174192

@@ -190,6 +208,16 @@ def _BatchJobSource_from_vertex(
190208
getv(from_object, ['bigquerySource', 'inputUri']),
191209
)
192210

211+
if (
212+
getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName'])
213+
is not None
214+
):
215+
setv(
216+
to_object,
217+
['vertex_dataset_name'],
218+
getv(from_object, ['vertexMultimodalDatasetSource', 'datasetName']),
219+
)
220+
193221
return to_object
194222

195223

@@ -221,6 +249,11 @@ def _BatchJobSource_to_mldev(
221249
],
222250
)
223251

252+
if getv(from_object, ['vertex_dataset_name']) is not None:
253+
raise ValueError(
254+
'vertex_dataset_name parameter is not supported in Gemini API.'
255+
)
256+
224257
return to_object
225258

226259

@@ -250,6 +283,13 @@ def _BatchJobSource_to_vertex(
250283
'inlined_requests parameter is not supported in Vertex AI.'
251284
)
252285

286+
if getv(from_object, ['vertex_dataset_name']) is not None:
287+
setv(
288+
to_object,
289+
['vertexMultimodalDatasetSource', 'datasetName'],
290+
getv(from_object, ['vertex_dataset_name']),
291+
)
292+
253293
return to_object
254294

255295

@@ -1603,6 +1643,42 @@ def _Tool_to_mldev(
16031643
return to_object
16041644

16051645

1646+
def _VertexMultimodalDatasetDestination_from_vertex(
1647+
from_object: Union[dict[str, Any], object],
1648+
parent_object: Optional[dict[str, Any]] = None,
1649+
) -> dict[str, Any]:
1650+
to_object: dict[str, Any] = {}
1651+
if getv(from_object, ['bigqueryDestination', 'outputUri']) is not None:
1652+
setv(
1653+
to_object,
1654+
['bigquery_destination'],
1655+
getv(from_object, ['bigqueryDestination', 'outputUri']),
1656+
)
1657+
1658+
if getv(from_object, ['displayName']) is not None:
1659+
setv(to_object, ['display_name'], getv(from_object, ['displayName']))
1660+
1661+
return to_object
1662+
1663+
1664+
def _VertexMultimodalDatasetDestination_to_vertex(
1665+
from_object: Union[dict[str, Any], object],
1666+
parent_object: Optional[dict[str, Any]] = None,
1667+
) -> dict[str, Any]:
1668+
to_object: dict[str, Any] = {}
1669+
if getv(from_object, ['bigquery_destination']) is not None:
1670+
setv(
1671+
to_object,
1672+
['bigqueryDestination', 'outputUri'],
1673+
getv(from_object, ['bigquery_destination']),
1674+
)
1675+
1676+
if getv(from_object, ['display_name']) is not None:
1677+
setv(to_object, ['displayName'], getv(from_object, ['display_name']))
1678+
1679+
return to_object
1680+
1681+
16061682
class Batches(_api_module.BaseModule):
16071683

16081684
def _create(
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
"""Tests for batches.create() with Vertex dataset source."""
18+
19+
import re
20+
21+
import pytest
22+
23+
from .. import pytest_helper
24+
from ... import types
25+
26+
27+
_GEMINI_MODEL = 'gemini-2.5-flash'
28+
_GEMINI_MODEL_FULL_NAME = 'publishers/google/models/gemini-2.5-flash'
29+
_OUTPUT_VERTEX_DATASET_DISPLAY_NAME = 'test_batch_output'
30+
_VERTEX_DATASET_INPUT_NAME = (
31+
'projects/vertex-sdk-dev/locations/us-central1/datasets/7857316250517504000'
32+
)
33+
_DISPLAY_NAME = 'test_batch'
34+
35+
_BQ_OUTPUT_PREFIX = (
36+
'bq://vertex-sdk-dev.unified_genai_tests_batches.generate_content_output'
37+
)
38+
_VERTEX_DATASET_DESTINATION = types.VertexMultimodalDatasetDestination(
39+
bigquery_destination=_BQ_OUTPUT_PREFIX,
40+
display_name=_OUTPUT_VERTEX_DATASET_DISPLAY_NAME,
41+
)
42+
43+
44+
# All tests will be run for both Vertex and MLDev.
45+
test_table: list[pytest_helper.TestTableItem] = [
46+
pytest_helper.TestTableItem(
47+
name='test_union_generate_content_with_vertex_dataset_name',
48+
parameters=types._CreateBatchJobParameters(
49+
model=_GEMINI_MODEL_FULL_NAME,
50+
src=_VERTEX_DATASET_INPUT_NAME,
51+
config={
52+
'display_name': _DISPLAY_NAME,
53+
'dest': {
54+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
55+
'format': 'vertex-dataset',
56+
},
57+
},
58+
),
59+
exception_if_mldev='not supported in Gemini API',
60+
has_union=True,
61+
),
62+
pytest_helper.TestTableItem(
63+
name='test_generate_content_with_vertex_dataset_source',
64+
parameters=types._CreateBatchJobParameters(
65+
model=_GEMINI_MODEL_FULL_NAME,
66+
src=types.BatchJobSource(
67+
vertex_dataset_name=_VERTEX_DATASET_INPUT_NAME,
68+
format='vertex-dataset',
69+
),
70+
config={
71+
'display_name': _DISPLAY_NAME,
72+
'dest': {
73+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
74+
'format': 'vertex-dataset',
75+
},
76+
},
77+
),
78+
exception_if_mldev='one of',
79+
),
80+
pytest_helper.TestTableItem(
81+
name='test_generate_content_with_vertex_dataset_source_dict',
82+
parameters=types._CreateBatchJobParameters(
83+
model=_GEMINI_MODEL_FULL_NAME,
84+
src={
85+
'vertex_dataset_name': _VERTEX_DATASET_INPUT_NAME,
86+
'format': 'vertex-dataset',
87+
},
88+
config={
89+
'display_name': _DISPLAY_NAME,
90+
'dest': {
91+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
92+
'format': 'vertex-dataset',
93+
},
94+
},
95+
),
96+
exception_if_mldev='one of',
97+
),
98+
]
99+
100+
pytestmark = [
101+
pytest.mark.usefixtures('mock_timestamped_unique_name'),
102+
pytest_helper.setup(
103+
file=__file__,
104+
globals_for_file=globals(),
105+
test_method='batches.create',
106+
test_table=test_table,
107+
),
108+
]
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_async_create(client):
113+
with pytest_helper.exception_if_mldev(client, ValueError):
114+
batch_job = await client.aio.batches.create(
115+
model=_GEMINI_MODEL,
116+
src=_VERTEX_DATASET_INPUT_NAME,
117+
config={
118+
'dest': {
119+
'vertex_dataset': _VERTEX_DATASET_DESTINATION,
120+
'format': 'vertex-dataset',
121+
},
122+
},
123+
)
124+
125+
assert batch_job.name.startswith('projects/')
126+
assert (
127+
batch_job.model == _GEMINI_MODEL_FULL_NAME
128+
) # Converted to Vertex full name.
129+
assert batch_job.src.vertex_dataset_name == _VERTEX_DATASET_INPUT_NAME
130+
assert batch_job.src.format == 'vertex-dataset'

google/genai/tests/transformers/test_t_batch.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,19 +172,36 @@ def test_batch_job_source_vertexai_valid_bigquery(self, vertex_client):
172172
result = t.t_batch_job_source(vertex_client, src_obj)
173173
assert result is src_obj
174174

175-
def test_batch_job_source_vertexai_valid_both(self, vertex_client):
175+
def test_batch_job_source_vertexai_valid_all(self, vertex_client):
176176
src_obj = types.BatchJobSource(
177177
gcs_uri=['gs://vertex-bucket/data.jsonl'],
178178
bigquery_uri='bq://project.dataset.table',
179+
vertex_dataset_name='projects/123/locations/us-central1/datasets/456',
179180
)
180-
with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'):
181+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
182+
t.t_batch_job_source(vertex_client, src_obj)
183+
184+
def test_batch_job_source_vertexai_valid_gcs_and_bigquery(self, vertex_client):
185+
src_obj = types.BatchJobSource(
186+
gcs_uri=['gs://vertex-bucket/data.jsonl'],
187+
bigquery_uri='bq://project.dataset.table',
188+
)
189+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
190+
t.t_batch_job_source(vertex_client, src_obj)
191+
192+
def test_batch_job_source_vertexai_valid_bigquery_and_vertex_dataset(self, vertex_client):
193+
src_obj = types.BatchJobSource(
194+
bigquery_uri='bq://project.dataset.table',
195+
vertex_dataset_name='projects/123/locations/us-central1/datasets/456',
196+
)
197+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
181198
t.t_batch_job_source(vertex_client, src_obj)
182199

183200
def test_batch_job_source_vertexai_invalid_neither_set(self, vertex_client):
184201
src_obj = types.BatchJobSource(
185202
file_name='files/data.csv'
186203
)
187-
with pytest.raises(ValueError, match='`gcs_uri` or `bigquery_uri`'):
204+
with pytest.raises(ValueError, match='`gcs_uri`, `bigquery_uri`, or `vertex_dataset_name`'):
188205
t.t_batch_job_source(vertex_client, src_obj)
189206

190207

0 commit comments

Comments (0)