Skip to content

Commit e94fc99

Browse files
cleop-google authored and copybara-github committed
chore: GenAI SDK client(multimodal) - Move to_bigframes method to MultimodalDataset class.
PiperOrigin-RevId: 889731146
1 parent da663c0 commit e94fc99

7 files changed

Lines changed: 247 additions & 184 deletions

File tree

tests/unit/vertexai/genai/__init__.py

Whitespace-only changes.

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -106,25 +106,6 @@ def test_create_dataset_from_bigquery(client):
106106
)
107107

108108

109-
def test_create_dataset_from_bigquery_without_bq_prefix(client):
110-
dataset = client.datasets.create_from_bigquery(
111-
multimodal_dataset={
112-
"display_name": "test-from-bigquery",
113-
"description": "test-description-from-bigquery",
114-
"metadata": {
115-
"inputConfig": {
116-
"bigquerySource": {"uri": BIGQUERY_TABLE_NAME},
117-
},
118-
},
119-
},
120-
)
121-
assert isinstance(dataset, types.MultimodalDataset)
122-
assert dataset.display_name == "test-from-bigquery"
123-
assert dataset.metadata.input_config.bigquery_source.uri == (
124-
f"bq://{BIGQUERY_TABLE_NAME}"
125-
)
126-
127-
128109
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
129110
def test_create_dataset_from_pandas(client, is_replay_mode):
130111
dataframe = pd.DataFrame(
@@ -270,26 +251,6 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client):
270251
)
271252

272253

273-
@pytest.mark.asyncio
274-
async def test_create_dataset_from_bigquery_async_without_bq_prefix(client):
275-
dataset = await client.aio.datasets.create_from_bigquery(
276-
multimodal_dataset={
277-
"display_name": "test-from-bigquery",
278-
"description": "test-description-from-bigquery",
279-
"metadata": {
280-
"inputConfig": {
281-
"bigquerySource": {"uri": BIGQUERY_TABLE_NAME},
282-
},
283-
},
284-
},
285-
)
286-
assert isinstance(dataset, types.MultimodalDataset)
287-
assert dataset.display_name == "test-from-bigquery"
288-
assert dataset.metadata.input_config.bigquery_source.uri == (
289-
f"bq://{BIGQUERY_TABLE_NAME}"
290-
)
291-
292-
293254
@pytest.mark.asyncio
294255
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
295256
async def test_create_dataset_from_pandas_async(client, is_replay_mode):

tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,14 @@
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

1717
from tests.unit.vertexai.genai.replays import pytest_helper
18-
from vertexai._genai import _datasets_utils
1918
from vertexai._genai import types
2019

21-
from unittest import mock
2220
import pytest
2321

2422
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2523
DATASET = "8810841321427173376"
2624

2725

28-
@pytest.fixture
29-
def mock_import_bigframes(is_replay_mode):
30-
if is_replay_mode:
31-
with mock.patch.object(
32-
_datasets_utils, "_try_import_bigframes"
33-
) as mock_import_bigframes:
34-
mock_read_gbq_table_result = mock.MagicMock()
35-
mock_read_gbq_table_result.sql = f"SLECT * FROM `{BIGQUERY_TABLE_NAME}`"
36-
37-
bigframes = mock.MagicMock()
38-
bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result
39-
40-
mock_import_bigframes.return_value = bigframes
41-
yield mock_import_bigframes
42-
else:
43-
yield None
44-
45-
4626
def test_get_dataset(client):
4727
dataset = client.datasets._get_multimodal_dataset(
4828
name=DATASET,
@@ -61,15 +41,6 @@ def test_get_dataset_from_public_method(client):
6141
assert dataset.display_name == "test-display-name"
6242

6343

64-
@pytest.mark.usefixtures("mock_import_bigframes")
65-
def test_to_bigframes(client):
66-
dataset = client.datasets.get_multimodal_dataset(
67-
name=DATASET,
68-
)
69-
df = client.datasets.to_bigframes(multimodal_dataset=dataset)
70-
assert BIGQUERY_TABLE_NAME in df.sql
71-
72-
7344
pytestmark = pytest_helper.setup(
7445
file=__file__,
7546
globals_for_file=globals(),
@@ -96,13 +67,3 @@ async def test_get_dataset_from_public_method_async(client):
9667
assert isinstance(dataset, types.MultimodalDataset)
9768
assert dataset.name.endswith(DATASET)
9869
assert dataset.display_name == "test-display-name"
99-
100-
101-
@pytest.mark.asyncio
102-
@pytest.mark.usefixtures("mock_import_bigframes")
103-
async def test_to_bigframes_async(client):
104-
dataset = await client.aio.datasets.get_multimodal_dataset(
105-
name=DATASET,
106-
)
107-
df = await client.aio.datasets.to_bigframes(multimodal_dataset=dataset)
108-
assert BIGQUERY_TABLE_NAME in df.sql
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests for multimodal datasets."""
16+
from unittest import mock
17+
18+
from vertexai._genai import _datasets_utils
19+
from vertexai._genai import types
20+
import pytest
21+
22+
23+
@pytest.fixture
24+
def mock_import_bigframes():
25+
with mock.patch.object(
26+
_datasets_utils, "_try_import_bigframes"
27+
) as mock_import_bigframes:
28+
mock_read_gbq_table_result = mock.MagicMock()
29+
mock_read_gbq_table_result.sql = "SELECT * FROM `project.dataset.table`"
30+
31+
bigframes = mock.MagicMock()
32+
bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result
33+
34+
mock_import_bigframes.return_value = bigframes
35+
yield mock_import_bigframes
36+
37+
38+
class TestMultimodalDataset:
39+
40+
def test_get_read_config(self):
41+
dataset = types.MultimodalDataset(
42+
metadata={
43+
"gemini_request_read_config": {
44+
"assembled_request_column_name": "test_column",
45+
},
46+
},
47+
)
48+
49+
read_config = dataset.get_read_config()
50+
51+
assert isinstance(read_config, types.GeminiRequestReadConfig)
52+
assert read_config.assembled_request_column_name == "test_column"
53+
54+
def test_get_read_config_empty(self):
55+
dataset = types.MultimodalDataset()
56+
assert dataset.get_read_config() is None
57+
58+
def test_set_read_config(self):
59+
dataset = types.MultimodalDataset()
60+
61+
dataset.set_read_config(
62+
read_config={
63+
"assembled_request_column_name": "test_column",
64+
},
65+
)
66+
67+
assert isinstance(dataset, types.MultimodalDataset)
68+
assert (
69+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
70+
== "test_column"
71+
)
72+
73+
def test_set_read_config_preserves_other_fields(self):
74+
dataset = types.MultimodalDataset(
75+
metadata={
76+
"inputConfig": {
77+
"bigquerySource": {"uri": "bq://test_table"},
78+
},
79+
},
80+
)
81+
82+
dataset.set_read_config(
83+
read_config={
84+
"assembled_request_column_name": "test_column",
85+
},
86+
)
87+
88+
assert isinstance(dataset, types.MultimodalDataset)
89+
assert (
90+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
91+
== "test_column"
92+
)
93+
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
94+
95+
def test_get_bigquery_uri(self):
96+
dataset = types.MultimodalDataset(
97+
metadata={
98+
"inputConfig": {
99+
"bigquerySource": {"uri": "bq://project.dataset.table"},
100+
},
101+
},
102+
)
103+
104+
assert dataset.get_bigquery_uri() == "bq://project.dataset.table"
105+
106+
def test_get_bigquery_uri_empty(self):
107+
dataset = types.MultimodalDataset()
108+
assert dataset.get_bigquery_uri() is None
109+
110+
def test_set_bigquery_uri(self):
111+
dataset = types.MultimodalDataset()
112+
113+
dataset.set_bigquery_uri("bq://project.dataset.table")
114+
115+
assert isinstance(dataset, types.MultimodalDataset)
116+
assert (
117+
dataset.metadata.input_config.bigquery_source.uri
118+
== "bq://project.dataset.table"
119+
)
120+
121+
def test_set_bigquery_uri_without_prefix(self):
122+
dataset = types.MultimodalDataset()
123+
124+
dataset.set_bigquery_uri("project.dataset.table")
125+
126+
assert isinstance(dataset, types.MultimodalDataset)
127+
assert (
128+
dataset.metadata.input_config.bigquery_source.uri
129+
== "bq://project.dataset.table"
130+
)
131+
132+
def test_set_bigquery_uri_preserves_other_fields(self):
133+
dataset = types.MultimodalDataset(
134+
metadata={
135+
"gemini_request_read_config": {
136+
"assembled_request_column_name": "test_column",
137+
},
138+
},
139+
)
140+
141+
dataset.set_bigquery_uri("bq://test_table")
142+
143+
assert isinstance(dataset, types.MultimodalDataset)
144+
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
145+
assert (
146+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
147+
== "test_column"
148+
)
149+
150+
def test_to_bigframes(self, mock_import_bigframes):
151+
dataset = types.MultimodalDataset()
152+
dataset.set_bigquery_uri("bq://project.dataset.table")
153+
154+
df = dataset.to_bigframes()
155+
156+
assert "project.dataset.table" in df.sql
157+
mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with(
158+
"project.dataset.table"
159+
)

vertexai/_genai/_datasets_utils.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
4444
return model_type(**filtered_response)
4545

4646

47-
def multimodal_dataset_get_bigquery_uri(
47+
def validate_multimodal_dataset_bigquery_uri(
4848
multimodal_dataset: common.MultimodalDataset,
49-
) -> str:
50-
"""Gets the bigquery uri from a multimodal dataset or raises ValueError."""
49+
) -> None:
50+
"""Validates that a multimodal dataset has a bigquery uri or raises ValueError."""
5151
if (
5252
not hasattr(multimodal_dataset, "metadata")
5353
or multimodal_dataset.metadata is None
@@ -70,33 +70,12 @@ def multimodal_dataset_get_bigquery_uri(
7070
raise ValueError(
7171
"Multimodal dataset input config bigquery source uri is required."
7272
)
73-
return str(multimodal_dataset.metadata.input_config.bigquery_source.uri)
74-
75-
76-
def multimodal_dataset_set_bigquery_uri(
77-
multimodal_dataset: common.MultimodalDataset,
78-
bigquery_uri: str,
79-
) -> None:
80-
"""Sets the bigquery uri from a multimodal dataset or raises ValueError."""
81-
metadata = (
82-
common.SchemaTablesDatasetMetadata()
83-
if multimodal_dataset.metadata is None
84-
else multimodal_dataset.metadata
85-
)
86-
input_config = (
87-
common.SchemaTablesDatasetMetadataInputConfig()
88-
if metadata.input_config is None
89-
else metadata.input_config
90-
)
91-
bigquery_source = (
92-
common.SchemaTablesDatasetMetadataBigQuerySource()
93-
if input_config.bigquery_source is None
94-
else input_config.bigquery_source
95-
)
96-
bigquery_source.uri = bigquery_uri
97-
input_config.bigquery_source = bigquery_source
98-
metadata.input_config = input_config
99-
multimodal_dataset.metadata = metadata
73+
if not str(multimodal_dataset.metadata.input_config.bigquery_source.uri).startswith(
74+
"bq://"
75+
):
76+
raise ValueError(
77+
"Multimodal dataset bigquery source uri must start with 'bq://'."
78+
)
10079

10180

10281
def _try_import_bigframes() -> Any:

0 commit comments

Comments
 (0)