Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,25 +106,6 @@ def test_create_dataset_from_bigquery(client):
)


def test_create_dataset_from_bigquery_without_bq_prefix(client):
dataset = client.datasets.create_from_bigquery(
multimodal_dataset={
"display_name": "test-from-bigquery",
"description": "test-description-from-bigquery",
"metadata": {
"inputConfig": {
"bigquerySource": {"uri": BIGQUERY_TABLE_NAME},
},
},
},
)
assert isinstance(dataset, types.MultimodalDataset)
assert dataset.display_name == "test-from-bigquery"
assert dataset.metadata.input_config.bigquery_source.uri == (
f"bq://{BIGQUERY_TABLE_NAME}"
)


@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_pandas(client, is_replay_mode):
dataframe = pd.DataFrame(
Expand Down Expand Up @@ -270,26 +251,6 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client):
)


@pytest.mark.asyncio
async def test_create_dataset_from_bigquery_async_without_bq_prefix(client):
dataset = await client.aio.datasets.create_from_bigquery(
multimodal_dataset={
"display_name": "test-from-bigquery",
"description": "test-description-from-bigquery",
"metadata": {
"inputConfig": {
"bigquerySource": {"uri": BIGQUERY_TABLE_NAME},
},
},
},
)
assert isinstance(dataset, types.MultimodalDataset)
assert dataset.display_name == "test-from-bigquery"
assert dataset.metadata.input_config.bigquery_source.uri == (
f"bq://{BIGQUERY_TABLE_NAME}"
)


@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_dataset_from_pandas_async(client, is_replay_mode):
Expand Down
128 changes: 128 additions & 0 deletions tests/unit/vertexai/genai/test_multimodal_datasets_genai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for multimodal datasets."""

from vertexai._genai import types


class TestMultimodalDataset:
    """Unit tests for the MultimodalDataset read-config and BigQuery-uri helpers."""

    def test_read_config(self):
        # A metadata dict carrying a read config should surface as a typed object.
        ds = types.MultimodalDataset(
            metadata={
                "gemini_request_read_config": {
                    "assembled_request_column_name": "test_column",
                },
            },
        )
        config = ds.read_config
        assert isinstance(config, types.GeminiRequestReadConfig)
        assert config.assembled_request_column_name == "test_column"

    def test_read_config_empty(self):
        # No metadata at all -> no read config.
        assert types.MultimodalDataset().read_config is None

    def test_set_read_config(self):
        ds = types.MultimodalDataset()

        ds.set_read_config(
            read_config={
                "assembled_request_column_name": "test_column",
            },
        )

        assert isinstance(ds, types.MultimodalDataset)
        read_config = ds.metadata.gemini_request_read_config
        assert read_config.assembled_request_column_name == "test_column"

    def test_set_read_config_preserves_other_fields(self):
        # Setting the read config must not clobber an existing input config.
        ds = types.MultimodalDataset(
            metadata={
                "inputConfig": {
                    "bigquerySource": {"uri": "bq://test_table"},
                },
            },
        )

        ds.set_read_config(
            read_config={
                "assembled_request_column_name": "test_column",
            },
        )

        assert isinstance(ds, types.MultimodalDataset)
        read_config = ds.metadata.gemini_request_read_config
        assert read_config.assembled_request_column_name == "test_column"
        assert ds.metadata.input_config.bigquery_source.uri == "bq://test_table"

    def test_bigquery_uri(self):
        ds = types.MultimodalDataset(
            metadata={
                "inputConfig": {
                    "bigquerySource": {"uri": "bq://project.dataset.table"},
                },
            },
        )

        assert ds.bigquery_uri == "bq://project.dataset.table"

    def test_bigquery_uri_empty(self):
        # No metadata at all -> no uri.
        assert types.MultimodalDataset().bigquery_uri is None

    def test_set_bigquery_uri(self):
        ds = types.MultimodalDataset()

        ds.set_bigquery_uri("bq://project.dataset.table")

        assert isinstance(ds, types.MultimodalDataset)
        source = ds.metadata.input_config.bigquery_source
        assert source.uri == "bq://project.dataset.table"

    def test_set_bigquery_uri_without_prefix(self):
        # A bare table path should gain the bq:// scheme automatically.
        ds = types.MultimodalDataset()

        ds.set_bigquery_uri("project.dataset.table")

        assert isinstance(ds, types.MultimodalDataset)
        source = ds.metadata.input_config.bigquery_source
        assert source.uri == "bq://project.dataset.table"

    def test_set_bigquery_uri_preserves_other_fields(self):
        # Setting the uri must not clobber an existing read config.
        ds = types.MultimodalDataset(
            metadata={
                "gemini_request_read_config": {
                    "assembled_request_column_name": "test_column",
                },
            },
        )

        ds.set_bigquery_uri("bq://test_table")

        assert isinstance(ds, types.MultimodalDataset)
        assert ds.metadata.input_config.bigquery_source.uri == "bq://test_table"
        read_config = ds.metadata.gemini_request_read_config
        assert read_config.assembled_request_column_name == "test_column"
39 changes: 9 additions & 30 deletions vertexai/_genai/_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
return model_type(**filtered_response)


def multimodal_dataset_get_bigquery_uri(
def validate_multimodal_dataset_bigquery_uri(
multimodal_dataset: common.MultimodalDataset,
) -> str:
"""Gets the bigquery uri from a multimodal dataset or raises ValueError."""
) -> None:
"""Validates that a multimodal dataset has a bigquery uri or raises ValueError."""
if (
not hasattr(multimodal_dataset, "metadata")
or multimodal_dataset.metadata is None
Expand All @@ -70,33 +70,12 @@ def multimodal_dataset_get_bigquery_uri(
raise ValueError(
"Multimodal dataset input config bigquery source uri is required."
)
return str(multimodal_dataset.metadata.input_config.bigquery_source.uri)


def multimodal_dataset_set_bigquery_uri(
multimodal_dataset: common.MultimodalDataset,
bigquery_uri: str,
) -> None:
"""Sets the bigquery uri from a multimodal dataset or raises ValueError."""
metadata = (
common.SchemaTablesDatasetMetadata()
if multimodal_dataset.metadata is None
else multimodal_dataset.metadata
)
input_config = (
common.SchemaTablesDatasetMetadataInputConfig()
if metadata.input_config is None
else metadata.input_config
)
bigquery_source = (
common.SchemaTablesDatasetMetadataBigQuerySource()
if input_config.bigquery_source is None
else input_config.bigquery_source
)
bigquery_source.uri = bigquery_uri
input_config.bigquery_source = bigquery_source
metadata.input_config = input_config
multimodal_dataset.metadata = metadata
if not str(multimodal_dataset.metadata.input_config.bigquery_source.uri).startswith(
"bq://"
):
raise ValueError(
"Multimodal dataset bigquery source uri must start with 'bq://'."
)


def _try_import_bigframes() -> Any:
Expand Down
37 changes: 13 additions & 24 deletions vertexai/_genai/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,12 +790,8 @@ def create_from_bigquery(
"""
if isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)

uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
if not uri.startswith("bq://"):
_datasets_utils.multimodal_dataset_set_bigquery_uri(
multimodal_dataset, f"bq://{uri}"
)
if isinstance(config, dict):
config = types.CreateMultimodalDatasetConfig(**config)
elif not config:
Expand Down Expand Up @@ -998,8 +994,11 @@ def to_bigframes(
elif not multimodal_dataset:
multimodal_dataset = types.MultimodalDataset()

uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
if multimodal_dataset.bigquery_uri is None:
raise ValueError("Multimodal dataset bigquery source uri is not set.")
return bigframes.pandas.read_gbq_table(
multimodal_dataset.bigquery_uri.removeprefix("bq://")
)

def update_multimodal_dataset(
self,
Expand All @@ -1026,12 +1025,8 @@ def update_multimodal_dataset(
"""
if isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)

uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
if not uri.startswith("bq://"):
_datasets_utils.multimodal_dataset_set_bigquery_uri(
multimodal_dataset, f"bq://{uri}"
)
if isinstance(config, dict):
config = types.CreateMultimodalDatasetConfig(**config)
elif not config:
Expand Down Expand Up @@ -1936,12 +1931,8 @@ async def create_from_bigquery(
"""
if isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)

uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
if not uri.startswith("bq://"):
_datasets_utils.multimodal_dataset_set_bigquery_uri(
multimodal_dataset, f"bq://{uri}"
)
if isinstance(config, dict):
config = types.CreateMultimodalDatasetConfig(**config)
elif not config:
Expand Down Expand Up @@ -2148,9 +2139,11 @@ async def to_bigframes(
elif not multimodal_dataset:
multimodal_dataset = types.MultimodalDataset()

uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
if multimodal_dataset.bigquery_uri is None:
raise ValueError("Multimodal dataset bigquery source uri is missing.")
return await asyncio.to_thread(
bigframes.pandas.read_gbq_table, uri.removeprefix("bq://")
bigframes.pandas.read_gbq_table,
multimodal_dataset.bigquery_uri.removeprefix("bq://"),
)

async def update_multimodal_dataset(
Expand All @@ -2174,12 +2167,8 @@ async def update_multimodal_dataset(
"""
if isinstance(multimodal_dataset, dict):
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)

uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
if not uri.startswith("bq://"):
_datasets_utils.multimodal_dataset_set_bigquery_uri(
multimodal_dataset, f"bq://{uri}"
)
if isinstance(config, dict):
config = types.CreateMultimodalDatasetConfig(**config)
elif not config:
Expand Down
58 changes: 58 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12366,6 +12366,64 @@ class MultimodalDataset(_common.BaseModel):
default=None, description="""The description of the multimodal dataset."""
)

@property
def read_config(self) -> Optional[GeminiRequestReadConfig]:
    """The Gemini-request read config held in the dataset metadata, or None if unset."""
    meta = self.metadata
    if meta is None:
        return None
    # An unset field is already None, so it can be returned directly.
    return meta.gemini_request_read_config

def set_read_config(
    self,
    *,
    read_config: GeminiRequestReadConfigOrDict,
) -> None:
    """Stores *read_config* in the dataset metadata, creating the metadata if absent."""
    config = (
        GeminiRequestReadConfig(**read_config)
        if isinstance(read_config, dict)
        else read_config
    )
    metadata = (
        SchemaTablesDatasetMetadata() if self.metadata is None else self.metadata
    )
    metadata.gemini_request_read_config = config
    self.metadata = metadata

@property
def bigquery_uri(
    self,
) -> Optional[str]:
    """Gets the bigquery uri from the dataset metadata.

    Returns None when the metadata, its input config, the bigquery source,
    or the uri field itself is not set.
    """
    if (
        self.metadata is None
        or self.metadata.input_config is None
        or self.metadata.input_config.bigquery_source is None
        # Without this check, str(None) would return the literal string
        # "None", defeating callers' `bigquery_uri is None` guards.
        or self.metadata.input_config.bigquery_source.uri is None
    ):
        return None
    return str(self.metadata.input_config.bigquery_source.uri)

def set_bigquery_uri(
    self,
    bigquery_uri: str,
) -> None:
    """Sets the bigquery uri in the dataset metadata. Prepends 'bq://' if it's not already present."""
    uri = (
        bigquery_uri
        if bigquery_uri.startswith("bq://")
        else f"bq://{bigquery_uri}"
    )
    # Lazily create each missing container on the metadata -> input_config
    # -> bigquery_source path before writing the uri.
    if self.metadata is None:
        self.metadata = SchemaTablesDatasetMetadata()
    if self.metadata.input_config is None:
        self.metadata.input_config = SchemaTablesDatasetMetadataInputConfig()
    if self.metadata.input_config.bigquery_source is None:
        self.metadata.input_config.bigquery_source = (
            SchemaTablesDatasetMetadataBigQuerySource()
        )
    self.metadata.input_config.bigquery_source.uri = uri


class MultimodalDatasetDict(TypedDict, total=False):
"""Represents a multimodal dataset."""
Expand Down
Loading