Skip to content

Commit 32ef654

Browse files
cleop-googlecopybara-github
authored and committed
feat: GenAI SDK client(multimodal) - Add set_read_config to MultimodalDataset.
PiperOrigin-RevId: 885611714
1 parent c4beca0 commit 32ef654

File tree

5 files changed

+207
-54
lines changed

5 files changed

+207
-54
lines changed

tests/unit/vertexai/genai/__init__.py

Whitespace-only changes.
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests for multimodal datasets."""
16+
17+
from vertexai._genai import types
18+
19+
20+
class TestMultimodalDataset:
    """Unit tests for the read-config and BigQuery-URI helpers on MultimodalDataset."""

    def test_get_read_config(self):
        # Metadata is passed as a plain dict; the getter must hand back the typed config.
        metadata = {
            "gemini_request_read_config": {
                "assembled_request_column_name": "test_column",
            },
        }
        dataset = types.MultimodalDataset(metadata=metadata)

        result = dataset.get_read_config()

        assert isinstance(result, types.GeminiRequestReadConfig)
        assert result.assembled_request_column_name == "test_column"

    def test_get_read_config_empty(self):
        # A dataset built without metadata has no read config to return.
        assert types.MultimodalDataset().get_read_config() is None

    def test_set_read_config(self):
        dataset = types.MultimodalDataset()
        new_config = {"assembled_request_column_name": "test_column"}

        dataset.set_read_config(read_config=new_config)

        assert isinstance(dataset, types.MultimodalDataset)
        stored = dataset.metadata.gemini_request_read_config
        assert stored.assembled_request_column_name == "test_column"

    def test_set_read_config_preserves_other_fields(self):
        # Start from a dataset that already carries a BigQuery source.
        metadata = {
            "inputConfig": {
                "bigquerySource": {"uri": "bq://test_table"},
            },
        }
        dataset = types.MultimodalDataset(metadata=metadata)

        dataset.set_read_config(
            read_config={"assembled_request_column_name": "test_column"},
        )

        assert isinstance(dataset, types.MultimodalDataset)
        stored = dataset.metadata.gemini_request_read_config
        assert stored.assembled_request_column_name == "test_column"
        # The pre-existing BigQuery source must survive the update.
        assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"

    def test_get_bigquery_uri(self):
        metadata = {
            "inputConfig": {
                "bigquerySource": {"uri": "bq://project.dataset.table"},
            },
        }
        dataset = types.MultimodalDataset(metadata=metadata)

        assert dataset.get_bigquery_uri() == "bq://project.dataset.table"

    def test_get_bigquery_uri_empty(self):
        # A dataset built without metadata has no BigQuery URI to return.
        assert types.MultimodalDataset().get_bigquery_uri() is None

    def test_set_bigquery_uri(self):
        dataset = types.MultimodalDataset()

        dataset.set_bigquery_uri("bq://project.dataset.table")

        assert isinstance(dataset, types.MultimodalDataset)
        stored_uri = dataset.metadata.input_config.bigquery_source.uri
        assert stored_uri == "bq://project.dataset.table"

    def test_set_bigquery_uri_without_prefix(self):
        dataset = types.MultimodalDataset()

        # The setter is expected to add the missing "bq://" scheme.
        dataset.set_bigquery_uri("project.dataset.table")

        assert isinstance(dataset, types.MultimodalDataset)
        stored_uri = dataset.metadata.input_config.bigquery_source.uri
        assert stored_uri == "bq://project.dataset.table"

    def test_set_bigquery_uri_preserves_other_fields(self):
        # Start from a dataset that already carries a read config.
        metadata = {
            "gemini_request_read_config": {
                "assembled_request_column_name": "test_column",
            },
        }
        dataset = types.MultimodalDataset(metadata=metadata)

        dataset.set_bigquery_uri("bq://test_table")

        assert isinstance(dataset, types.MultimodalDataset)
        assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
        # The pre-existing read config must survive the update.
        stored = dataset.metadata.gemini_request_read_config
        assert stored.assembled_request_column_name == "test_column"

vertexai/_genai/_datasets_utils.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
4444
return model_type(**filtered_response)
4545

4646

47-
def multimodal_dataset_get_bigquery_uri(
47+
def validate_multimodal_dataset_has_bigquery_uri(
4848
multimodal_dataset: common.MultimodalDataset,
49-
) -> str:
50-
"""Gets the bigquery uri from a multimodal dataset or raises ValueError."""
49+
) -> None:
50+
"""Validates that a multimodal dataset has a bigquery uri or raises ValueError."""
5151
if (
5252
not hasattr(multimodal_dataset, "metadata")
5353
or multimodal_dataset.metadata is None
@@ -70,33 +70,12 @@ def multimodal_dataset_get_bigquery_uri(
7070
raise ValueError(
7171
"Multimodal dataset input config bigquery source uri is required."
7272
)
73-
return str(multimodal_dataset.metadata.input_config.bigquery_source.uri)
74-
75-
76-
def multimodal_dataset_set_bigquery_uri(
77-
multimodal_dataset: common.MultimodalDataset,
78-
bigquery_uri: str,
79-
) -> None:
80-
"""Sets the bigquery uri from a multimodal dataset or raises ValueError."""
81-
metadata = (
82-
common.SchemaTablesDatasetMetadata()
83-
if multimodal_dataset.metadata is None
84-
else multimodal_dataset.metadata
85-
)
86-
input_config = (
87-
common.SchemaTablesDatasetMetadataInputConfig()
88-
if metadata.input_config is None
89-
else metadata.input_config
90-
)
91-
bigquery_source = (
92-
common.SchemaTablesDatasetMetadataBigQuerySource()
93-
if input_config.bigquery_source is None
94-
else input_config.bigquery_source
95-
)
96-
bigquery_source.uri = bigquery_uri
97-
input_config.bigquery_source = bigquery_source
98-
metadata.input_config = input_config
99-
multimodal_dataset.metadata = metadata
73+
if not str(multimodal_dataset.metadata.input_config.bigquery_source.uri).startswith(
74+
"bq://"
75+
):
76+
raise ValueError(
77+
"Multimodal dataset bigquery source uri must start with 'bq://'."
78+
)
10079

10180

10281
def _try_import_bigframes() -> Any:

vertexai/_genai/datasets.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -790,12 +790,8 @@ def create_from_bigquery(
790790
"""
791791
if isinstance(multimodal_dataset, dict):
792792
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
793+
_datasets_utils.validate_multimodal_dataset_has_bigquery_uri(multimodal_dataset)
793794

794-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
795-
if not uri.startswith("bq://"):
796-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
797-
multimodal_dataset, f"bq://{uri}"
798-
)
799795
if isinstance(config, dict):
800796
config = types.CreateMultimodalDatasetConfig(**config)
801797
elif not config:
@@ -998,8 +994,10 @@ def to_bigframes(
998994
elif not multimodal_dataset:
999995
multimodal_dataset = types.MultimodalDataset()
1000996

1001-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1002-
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
997+
bigquery_uri = multimodal_dataset.get_bigquery_uri()
998+
if bigquery_uri is None:
999+
raise ValueError("Multimodal dataset bigquery source uri is not set.")
1000+
return bigframes.pandas.read_gbq_table(bigquery_uri.removeprefix("bq://"))
10031001

10041002
def update_multimodal_dataset(
10051003
self,
@@ -1026,12 +1024,8 @@ def update_multimodal_dataset(
10261024
"""
10271025
if isinstance(multimodal_dataset, dict):
10281026
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1027+
_datasets_utils.validate_multimodal_dataset_has_bigquery_uri(multimodal_dataset)
10291028

1030-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1031-
if not uri.startswith("bq://"):
1032-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
1033-
multimodal_dataset, f"bq://{uri}"
1034-
)
10351029
if isinstance(config, dict):
10361030
config = types.CreateMultimodalDatasetConfig(**config)
10371031
elif not config:
@@ -1936,12 +1930,8 @@ async def create_from_bigquery(
19361930
"""
19371931
if isinstance(multimodal_dataset, dict):
19381932
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1933+
_datasets_utils.validate_multimodal_dataset_has_bigquery_uri(multimodal_dataset)
19391934

1940-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1941-
if not uri.startswith("bq://"):
1942-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
1943-
multimodal_dataset, f"bq://{uri}"
1944-
)
19451935
if isinstance(config, dict):
19461936
config = types.CreateMultimodalDatasetConfig(**config)
19471937
elif not config:
@@ -2146,9 +2136,11 @@ async def to_bigframes(
21462136
elif not multimodal_dataset:
21472137
multimodal_dataset = types.MultimodalDataset()
21482138

2149-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2139+
bigquery_uri = multimodal_dataset.get_bigquery_uri()
2140+
if bigquery_uri is None:
2141+
raise ValueError("Multimodal dataset bigquery source uri is missing.")
21502142
return await asyncio.to_thread(
2151-
bigframes.pandas.read_gbq_table, uri.removeprefix("bq://")
2143+
bigframes.pandas.read_gbq_table, bigquery_uri.removeprefix("bq://")
21522144
)
21532145

21542146
async def update_multimodal_dataset(
@@ -2172,12 +2164,8 @@ async def update_multimodal_dataset(
21722164
"""
21732165
if isinstance(multimodal_dataset, dict):
21742166
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
2167+
_datasets_utils.validate_multimodal_dataset_has_bigquery_uri(multimodal_dataset)
21752168

2176-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2177-
if not uri.startswith("bq://"):
2178-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
2179-
multimodal_dataset, f"bq://{uri}"
2180-
)
21812169
if isinstance(config, dict):
21822170
config = types.CreateMultimodalDatasetConfig(**config)
21832171
elif not config:

vertexai/_genai/types/common.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12316,6 +12316,62 @@ class MultimodalDataset(_common.BaseModel):
1231612316
default=None, description="""The description of the multimodal dataset."""
1231712317
)
1231812318

12319+
def get_read_config(self) -> Optional[GeminiRequestReadConfig]:
    """Returns the Gemini request read config stored in the dataset metadata.

    Returns:
        The value of ``metadata.gemini_request_read_config``, or None when the
        dataset has no metadata or the read config is unset.
    """
    metadata = self.metadata
    if metadata is None:
        return None
    # An unset read config is already None, so it can be returned directly.
    return metadata.gemini_request_read_config
12324+
12325+
def set_read_config(
    self,
    *,
    read_config: GeminiRequestReadConfigOrDict,
) -> None:
    """Stores a Gemini request read config on the dataset metadata.

    Args:
        read_config: The read config to store. A dict is converted to a
            ``GeminiRequestReadConfig`` by keyword expansion; any other value
            is stored as given.
    """
    config = (
        GeminiRequestReadConfig(**read_config)
        if isinstance(read_config, dict)
        else read_config
    )

    # Create the metadata container on first use so existing fields are kept.
    if self.metadata is None:
        self.metadata = SchemaTablesDatasetMetadata()
    self.metadata.gemini_request_read_config = config
12337+
12338+
def get_bigquery_uri(
    self,
) -> Optional[str]:
    """Gets the bigquery uri from a multimodal dataset or returns None.

    Returns:
        The BigQuery source uri as a string, or None when the metadata, the
        input config, the BigQuery source, or the uri itself is unset.
    """
    metadata = self.metadata
    if metadata is None or metadata.input_config is None:
        return None
    bigquery_source = metadata.input_config.bigquery_source
    # A source object can exist with no uri; str(None) would yield the literal
    # "None", which callers checking `is None` would treat as a valid uri.
    if bigquery_source is None or bigquery_source.uri is None:
        return None
    return str(bigquery_source.uri)
12349+
12350+
def set_bigquery_uri(
    self,
    bigquery_uri: str,
) -> None:
    """Stores a BigQuery uri on the dataset metadata.

    Missing metadata, input config, or BigQuery source objects are created as
    needed; existing ones are reused so their other fields are kept.

    Args:
        bigquery_uri: The BigQuery table uri. "bq://" is prepended when the
            value does not already start with it.
    """
    if not bigquery_uri.startswith("bq://"):
        bigquery_uri = f"bq://{bigquery_uri}"

    # Walk down the nested containers, creating each level only when absent.
    metadata = self.metadata
    if metadata is None:
        metadata = SchemaTablesDatasetMetadata()

    input_config = metadata.input_config
    if input_config is None:
        input_config = SchemaTablesDatasetMetadataInputConfig()

    bigquery_source = input_config.bigquery_source
    if bigquery_source is None:
        bigquery_source = SchemaTablesDatasetMetadataBigQuerySource()

    # Assign from the innermost object outwards so every level is attached.
    bigquery_source.uri = bigquery_uri
    input_config.bigquery_source = bigquery_source
    metadata.input_config = input_config
    self.metadata = metadata
12374+
1231912375

1232012376
class MultimodalDatasetDict(TypedDict, total=False):
1232112377
"""Represents a multimodal dataset."""

0 commit comments

Comments
 (0)