diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py index 245be86f88..78582bb8fa 100644 --- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py @@ -106,25 +106,6 @@ def test_create_dataset_from_bigquery(client): ) -def test_create_dataset_from_bigquery_without_bq_prefix(client): - dataset = client.datasets.create_from_bigquery( - multimodal_dataset={ - "display_name": "test-from-bigquery", - "description": "test-description-from-bigquery", - "metadata": { - "inputConfig": { - "bigquerySource": {"uri": BIGQUERY_TABLE_NAME}, - }, - }, - }, - ) - assert isinstance(dataset, types.MultimodalDataset) - assert dataset.display_name == "test-from-bigquery" - assert dataset.metadata.input_config.bigquery_source.uri == ( - f"bq://{BIGQUERY_TABLE_NAME}" - ) - - @pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") def test_create_dataset_from_pandas(client, is_replay_mode): dataframe = pd.DataFrame( @@ -270,26 +251,6 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client): ) -@pytest.mark.asyncio -async def test_create_dataset_from_bigquery_async_without_bq_prefix(client): - dataset = await client.aio.datasets.create_from_bigquery( - multimodal_dataset={ - "display_name": "test-from-bigquery", - "description": "test-description-from-bigquery", - "metadata": { - "inputConfig": { - "bigquerySource": {"uri": BIGQUERY_TABLE_NAME}, - }, - }, - }, - ) - assert isinstance(dataset, types.MultimodalDataset) - assert dataset.display_name == "test-from-bigquery" - assert dataset.metadata.input_config.bigquery_source.uri == ( - f"bq://{BIGQUERY_TABLE_NAME}" - ) - - @pytest.mark.asyncio @pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") async def test_create_dataset_from_pandas_async(client, is_replay_mode): diff --git a/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py b/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py new file mode 100644 index 0000000000..48707f1a9c --- /dev/null +++ b/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py @@ -0,0 +1,128 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for multimodal datasets.""" + +from vertexai._genai import types + + +class TestMultimodalDataset: + + def test_read_config(self): + dataset = types.MultimodalDataset( + metadata={ + "gemini_request_read_config": { + "assembled_request_column_name": "test_column", + }, + }, + ) + + assert isinstance(dataset.read_config, types.GeminiRequestReadConfig) + assert dataset.read_config.assembled_request_column_name == "test_column" + + def test_read_config_empty(self): + dataset = types.MultimodalDataset() + assert dataset.read_config is None + + def test_set_read_config(self): + dataset = types.MultimodalDataset() + + dataset.set_read_config( + read_config={ + "assembled_request_column_name": "test_column", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + + def test_set_read_config_preserves_other_fields(self): + dataset = types.MultimodalDataset( + metadata={ + "inputConfig": { + "bigquerySource": {"uri": "bq://test_table"}, + }, + }, + ) + + dataset.set_read_config( + read_config={ + "assembled_request_column_name": "test_column", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table" + + def test_bigquery_uri(self): + dataset = types.MultimodalDataset( + metadata={ + "inputConfig": { + "bigquerySource": {"uri": "bq://project.dataset.table"}, + }, + }, + ) + + assert dataset.bigquery_uri == "bq://project.dataset.table" + + def test_bigquery_uri_empty(self): + dataset = types.MultimodalDataset() + assert dataset.bigquery_uri is None + + def test_set_bigquery_uri(self): + dataset = types.MultimodalDataset() + + dataset.set_bigquery_uri("bq://project.dataset.table") + + assert isinstance(dataset, types.MultimodalDataset) + assert ( + dataset.metadata.input_config.bigquery_source.uri + == "bq://project.dataset.table" + ) + + def test_set_bigquery_uri_without_prefix(self): + dataset = types.MultimodalDataset() + + dataset.set_bigquery_uri("project.dataset.table") + + assert isinstance(dataset, types.MultimodalDataset) + assert ( + dataset.metadata.input_config.bigquery_source.uri + == "bq://project.dataset.table" + ) + + def test_set_bigquery_uri_preserves_other_fields(self): + dataset = types.MultimodalDataset( + metadata={ + "gemini_request_read_config": { + "assembled_request_column_name": "test_column", + }, + }, + ) + + dataset.set_bigquery_uri("bq://test_table") + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index f225a185c7..e49280f200 100644 --- a/vertexai/_genai/_datasets_utils.py +++ b/vertexai/_genai/_datasets_utils.py @@ -44,10 +44,10 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T: return model_type(**filtered_response) -def multimodal_dataset_get_bigquery_uri( +def validate_multimodal_dataset_bigquery_uri( multimodal_dataset: common.MultimodalDataset, -) -> str: - """Gets the bigquery uri from a multimodal dataset or raises ValueError.""" +) -> None: + """Validates that a multimodal dataset has a bigquery uri or raises ValueError.""" if ( not hasattr(multimodal_dataset, "metadata") or multimodal_dataset.metadata is None @@ -70,33 +70,12 @@ def multimodal_dataset_get_bigquery_uri( raise ValueError( "Multimodal dataset input config bigquery source uri is required." ) - return str(multimodal_dataset.metadata.input_config.bigquery_source.uri) - - -def multimodal_dataset_set_bigquery_uri( - multimodal_dataset: common.MultimodalDataset, - bigquery_uri: str, -) -> None: - """Sets the bigquery uri from a multimodal dataset or raises ValueError.""" - metadata = ( - common.SchemaTablesDatasetMetadata() - if multimodal_dataset.metadata is None - else multimodal_dataset.metadata - ) - input_config = ( - common.SchemaTablesDatasetMetadataInputConfig() - if metadata.input_config is None - else metadata.input_config - ) - bigquery_source = ( - common.SchemaTablesDatasetMetadataBigQuerySource() - if input_config.bigquery_source is None - else input_config.bigquery_source - ) - bigquery_source.uri = bigquery_uri - input_config.bigquery_source = bigquery_source - metadata.input_config = input_config - multimodal_dataset.metadata = metadata + if not str(multimodal_dataset.metadata.input_config.bigquery_source.uri).startswith( + "bq://" + ): + raise ValueError( + "Multimodal dataset bigquery source uri must start with 'bq://'." + ) def _try_import_bigframes() -> Any: diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 0a9aa6e48e..ab4f9dc2d9 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -790,12 +790,8 @@ def create_from_bigquery( """ if isinstance(multimodal_dataset, dict): multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) - uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) - if not uri.startswith("bq://"): - _datasets_utils.multimodal_dataset_set_bigquery_uri( - multimodal_dataset, f"bq://{uri}" - ) if isinstance(config, dict): config = types.CreateMultimodalDatasetConfig(**config) elif not config: @@ -998,8 +994,11 @@ def to_bigframes( elif not multimodal_dataset: multimodal_dataset = types.MultimodalDataset() - uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) - return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://")) + if multimodal_dataset.bigquery_uri is None: + raise ValueError("Multimodal dataset bigquery source uri is not set.") + return bigframes.pandas.read_gbq_table( + multimodal_dataset.bigquery_uri.removeprefix("bq://") + ) def update_multimodal_dataset( self, @@ -1026,12 +1025,8 @@ def update_multimodal_dataset( """ if isinstance(multimodal_dataset, dict): multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) - uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) - if not uri.startswith("bq://"): - _datasets_utils.multimodal_dataset_set_bigquery_uri( - multimodal_dataset, f"bq://{uri}" - ) if isinstance(config, dict): config = types.CreateMultimodalDatasetConfig(**config) elif not config: @@ -1936,12 +1931,8 @@ async def create_from_bigquery( """ if isinstance(multimodal_dataset, dict): multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) - uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) - if not uri.startswith("bq://"): - _datasets_utils.multimodal_dataset_set_bigquery_uri( - multimodal_dataset, f"bq://{uri}" - ) if isinstance(config, dict): config = types.CreateMultimodalDatasetConfig(**config) elif not config: @@ -2148,9 +2139,11 @@ async def to_bigframes( elif not multimodal_dataset: multimodal_dataset = types.MultimodalDataset() - uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) + if multimodal_dataset.bigquery_uri is None: + raise ValueError("Multimodal dataset bigquery source uri is missing.") return await asyncio.to_thread( - bigframes.pandas.read_gbq_table, uri.removeprefix("bq://") + bigframes.pandas.read_gbq_table, + multimodal_dataset.bigquery_uri.removeprefix("bq://"), ) async def update_multimodal_dataset( @@ -2174,12 +2167,8 @@ async def update_multimodal_dataset( """ if isinstance(multimodal_dataset, dict): multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) - uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset) - if not uri.startswith("bq://"): - _datasets_utils.multimodal_dataset_set_bigquery_uri( - multimodal_dataset, f"bq://{uri}" - ) if isinstance(config, dict): config = types.CreateMultimodalDatasetConfig(**config) elif not config: diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 437da8bbb9..e6a7ea4b3c 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -12366,6 +12366,64 @@ class MultimodalDataset(_common.BaseModel): default=None, description="""The description of the multimodal dataset.""" ) + @property + def read_config(self) -> Optional[GeminiRequestReadConfig]: + """Gets the read config from the dataset metadata. Returns None if it's not set.""" + if self.metadata is None or self.metadata.gemini_request_read_config is None: + return None + return self.metadata.gemini_request_read_config + + def set_read_config( + self, + *, + read_config: GeminiRequestReadConfigOrDict, + ) -> None: + """Sets the read config in the dataset metadata.""" + if isinstance(read_config, dict): + read_config = GeminiRequestReadConfig(**read_config) + + if self.metadata is None: + self.metadata = SchemaTablesDatasetMetadata() + self.metadata.gemini_request_read_config = read_config + + @property + def bigquery_uri( + self, + ) -> Optional[str]: + """Gets the bigquery uri from the dataset metadata. Returns None if it's not set.""" + if ( + self.metadata is None + or self.metadata.input_config is None + or self.metadata.input_config.bigquery_source is None + ): + return None + return str(self.metadata.input_config.bigquery_source.uri) + + def set_bigquery_uri( + self, + bigquery_uri: str, + ) -> None: + """Sets the bigquery uri in the dataset metadata. Prepends 'bq://' if it's not already present.""" + if not bigquery_uri.startswith("bq://"): + bigquery_uri = f"bq://{bigquery_uri}" + metadata = ( + SchemaTablesDatasetMetadata() if self.metadata is None else self.metadata + ) + input_config = ( + SchemaTablesDatasetMetadataInputConfig() + if metadata.input_config is None + else metadata.input_config + ) + bigquery_source = ( + SchemaTablesDatasetMetadataBigQuerySource() + if input_config.bigquery_source is None + else input_config.bigquery_source + ) + bigquery_source.uri = bigquery_uri + input_config.bigquery_source = bigquery_source + metadata.input_config = input_config + self.metadata = metadata + class MultimodalDatasetDict(TypedDict, total=False): """Represents a multimodal dataset."""