Skip to content

Commit 2f6b86a

Browse files
cleop-google authored and copybara-github committed
feat: GenAI SDK client(multimodal) - Support getting a bigframe DataFrame from a multimodal dataset
PiperOrigin-RevId: 882583435
1 parent 3d05ffa commit 2f6b86a

5 files changed

Lines changed: 173 additions & 121 deletions

File tree

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ def replays_prefix():
123123
return "test"
124124

125125

126+
@pytest.fixture
127+
def is_replay_mode(request):
128+
return request.config.getoption("--mode") in ["replay", "tap"]
129+
130+
126131
@pytest.fixture
127132
def mock_agent_engine_create_path_exists():
128133
"""Mocks os.path.exists to return True."""

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@
2828
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2929

3030

31-
@pytest.fixture
32-
def is_replay_mode(request):
33-
return request.config.getoption("--mode") in ["replay", "tap"]
34-
35-
3631
@pytest.fixture
3732
def mock_bigquery_client(is_replay_mode):
3833
if is_replay_mode:

tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,34 @@
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

1717
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import _datasets_utils
1819
from vertexai._genai import types
1920

21+
from unittest import mock
2022
import pytest
2123

2224
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2325
DATASET = "8810841321427173376"
2426

2527

28+
@pytest.fixture
29+
def mock_import_bigframes(is_replay_mode):
30+
if is_replay_mode:
31+
with mock.patch.object(
32+
_datasets_utils, "_try_import_bigframes"
33+
) as mock_import_bigframes:
34+
mock_read_gbq_table_result = mock.MagicMock()
35+
mock_read_gbq_table_result.sql = f"SELECT * FROM `{BIGQUERY_TABLE_NAME}`"
36+
37+
bigframes = mock.MagicMock()
38+
bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result
39+
40+
mock_import_bigframes.return_value = bigframes
41+
yield mock_import_bigframes
42+
else:
43+
yield None
44+
45+
2646
def test_get_dataset(client):
2747
dataset = client.datasets._get_multimodal_dataset(
2848
name=DATASET,
@@ -41,6 +61,15 @@ def test_get_dataset_from_public_method(client):
4161
assert dataset.display_name == "test-display-name"
4262

4363

64+
@pytest.mark.usefixtures("mock_import_bigframes")
65+
def test_to_bigframes(client):
66+
dataset = client.datasets.get_multimodal_dataset(
67+
name=DATASET,
68+
)
69+
df = client.datasets.to_bigframes(multimodal_dataset=dataset)
70+
assert BIGQUERY_TABLE_NAME in df.sql
71+
72+
4473
pytestmark = pytest_helper.setup(
4574
file=__file__,
4675
globals_for_file=globals(),
@@ -67,3 +96,13 @@ async def test_get_dataset_from_public_method_async(client):
6796
assert isinstance(dataset, types.MultimodalDataset)
6897
assert dataset.name.endswith(DATASET)
6998
assert dataset.display_name == "test-display-name"
99+
100+
101+
@pytest.mark.asyncio
102+
@pytest.mark.usefixtures("mock_import_bigframes")
103+
async def test_to_bigframes_async(client):
104+
dataset = await client.aio.datasets.get_multimodal_dataset(
105+
name=DATASET,
106+
)
107+
df = await client.aio.datasets.to_bigframes(multimodal_dataset=dataset)
108+
assert BIGQUERY_TABLE_NAME in df.sql

vertexai/_genai/_datasets_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,61 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
4242
return model_type(**filtered_response)
4343

4444

45+
def multimodal_dataset_get_bigquery_uri(
46+
multimodal_dataset: common.MultimodalDataset,
47+
) -> str:
48+
"""Gets the bigquery uri from a multimodal dataset or raises ValueError."""
49+
if (
50+
not hasattr(multimodal_dataset, "metadata")
51+
or multimodal_dataset.metadata is None
52+
):
53+
raise ValueError("Multimodal dataset metadata is required.")
54+
if (
55+
not hasattr(multimodal_dataset.metadata, "input_config")
56+
or multimodal_dataset.metadata.input_config is None
57+
):
58+
raise ValueError("Multimodal dataset input config is required.")
59+
if (
60+
not hasattr(multimodal_dataset.metadata.input_config, "bigquery_source")
61+
or multimodal_dataset.metadata.input_config.bigquery_source is None
62+
):
63+
raise ValueError("Multimodal dataset input config bigquery source is required.")
64+
if (
65+
not hasattr(multimodal_dataset.metadata.input_config.bigquery_source, "uri")
66+
or multimodal_dataset.metadata.input_config.bigquery_source.uri is None
67+
):
68+
raise ValueError(
69+
"Multimodal dataset input config bigquery source uri is required."
70+
)
71+
return str(multimodal_dataset.metadata.input_config.bigquery_source.uri)
72+
73+
74+
def multimodal_dataset_set_bigquery_uri(
75+
multimodal_dataset: common.MultimodalDataset,
76+
bigquery_uri: str,
77+
) -> None:
78+
"""Sets the bigquery uri from a multimodal dataset or raises ValueError."""
79+
metadata = (
80+
common.SchemaTablesDatasetMetadata()
81+
if multimodal_dataset.metadata is None
82+
else multimodal_dataset.metadata
83+
)
84+
input_config = (
85+
common.SchemaTablesDatasetMetadataInputConfig()
86+
if metadata.input_config is None
87+
else metadata.input_config
88+
)
89+
bigquery_source = (
90+
common.SchemaTablesDatasetMetadataBigQuerySource()
91+
if input_config.bigquery_source is None
92+
else input_config.bigquery_source
93+
)
94+
bigquery_source.uri = bigquery_uri
95+
input_config.bigquery_source = bigquery_source
96+
metadata.input_config = input_config
97+
multimodal_dataset.metadata = metadata
98+
99+
45100
def _try_import_bigframes() -> Any:
46101
"""Tries to import `bigframes`."""
47102
try:

vertexai/_genai/datasets.py

Lines changed: 74 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -790,35 +790,11 @@ def create_from_bigquery(
790790
"""
791791
if isinstance(multimodal_dataset, dict):
792792
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
793-
if (
794-
not hasattr(multimodal_dataset, "metadata")
795-
or multimodal_dataset.metadata is None
796-
):
797-
raise ValueError("Multimodal dataset metadata is required.")
798-
if (
799-
not hasattr(multimodal_dataset.metadata, "input_config")
800-
or multimodal_dataset.metadata.input_config is None
801-
):
802-
raise ValueError("Multimodal dataset input config is required.")
803-
if (
804-
not hasattr(multimodal_dataset.metadata.input_config, "bigquery_source")
805-
or multimodal_dataset.metadata.input_config.bigquery_source is None
806-
):
807-
raise ValueError(
808-
"Multimodal dataset input config bigquery source is required."
809-
)
810-
if (
811-
not hasattr(multimodal_dataset.metadata.input_config.bigquery_source, "uri")
812-
or multimodal_dataset.metadata.input_config.bigquery_source.uri is None
813-
):
814-
raise ValueError(
815-
"Multimodal dataset input config bigquery source uri is required."
816-
)
817-
if not multimodal_dataset.metadata.input_config.bigquery_source.uri.startswith(
818-
"bq://"
819-
):
820-
multimodal_dataset.metadata.input_config.bigquery_source.uri = (
821-
f"bq://{multimodal_dataset.metadata.input_config.bigquery_source.uri}"
793+
794+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
795+
if not uri.startswith("bq://"):
796+
_datasets_utils.multimodal_dataset_set_bigquery_uri(
797+
multimodal_dataset, f"bq://{uri}"
822798
)
823799
if isinstance(config, dict):
824800
config = types.CreateMultimodalDatasetConfig(**config)
@@ -923,6 +899,33 @@ def create_from_pandas(
923899
config=config,
924900
)
925901

902+
def to_bigframes(
903+
self,
904+
*,
905+
multimodal_dataset: types.MultimodalDatasetOrDict,
906+
) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821
907+
"""Converts a multimodal dataset to a BigFrames dataframe.
908+
909+
This is the preferred method to inspect the multimodal dataset in a
910+
notebook.
911+
912+
Args:
913+
multimodal_dataset:
914+
Required. A representation of a multimodal dataset.
915+
916+
Returns:
917+
A BigFrames dataframe.
918+
"""
919+
bigframes = _datasets_utils._try_import_bigframes()
920+
921+
if isinstance(multimodal_dataset, dict):
922+
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
923+
elif not multimodal_dataset:
924+
multimodal_dataset = types.MultimodalDataset()
925+
926+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
927+
# NOTE(review): lstrip("bq://") strips a *character set* (b, q, :, /), not the
# literal prefix — a table id starting with any of those characters would be
# mangled. removeprefix() removes exactly the "bq://" scheme prefix.
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
928+
926929
def update_multimodal_dataset(
927930
self,
928931
*,
@@ -948,35 +951,11 @@ def update_multimodal_dataset(
948951
"""
949952
if isinstance(multimodal_dataset, dict):
950953
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
951-
if (
952-
not hasattr(multimodal_dataset, "metadata")
953-
or multimodal_dataset.metadata is None
954-
):
955-
raise ValueError("Multimodal dataset metadata is required.")
956-
if (
957-
not hasattr(multimodal_dataset.metadata, "input_config")
958-
or multimodal_dataset.metadata.input_config is None
959-
):
960-
raise ValueError("Multimodal dataset input config is required.")
961-
if (
962-
not hasattr(multimodal_dataset.metadata.input_config, "bigquery_source")
963-
or multimodal_dataset.metadata.input_config.bigquery_source is None
964-
):
965-
raise ValueError(
966-
"Multimodal dataset input config bigquery source is required."
967-
)
968-
if (
969-
not hasattr(multimodal_dataset.metadata.input_config.bigquery_source, "uri")
970-
or multimodal_dataset.metadata.input_config.bigquery_source.uri is None
971-
):
972-
raise ValueError(
973-
"Multimodal dataset input config bigquery source uri is required."
974-
)
975-
if not multimodal_dataset.metadata.input_config.bigquery_source.uri.startswith(
976-
"bq://"
977-
):
978-
multimodal_dataset.metadata.input_config.bigquery_source.uri = (
979-
f"bq://{multimodal_dataset.metadata.input_config.bigquery_source.uri}"
954+
955+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
956+
if not uri.startswith("bq://"):
957+
_datasets_utils.multimodal_dataset_set_bigquery_uri(
958+
multimodal_dataset, f"bq://{uri}"
980959
)
981960
if isinstance(config, dict):
982961
config = types.CreateMultimodalDatasetConfig(**config)
@@ -1887,35 +1866,11 @@ async def create_from_bigquery(
18871866
"""
18881867
if isinstance(multimodal_dataset, dict):
18891868
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1890-
if (
1891-
not hasattr(multimodal_dataset, "metadata")
1892-
or multimodal_dataset.metadata is None
1893-
):
1894-
raise ValueError("Multimodal dataset metadata is required.")
1895-
if (
1896-
not hasattr(multimodal_dataset.metadata, "input_config")
1897-
or multimodal_dataset.metadata.input_config is None
1898-
):
1899-
raise ValueError("Multimodal dataset input config is required.")
1900-
if (
1901-
not hasattr(multimodal_dataset.metadata.input_config, "bigquery_source")
1902-
or multimodal_dataset.metadata.input_config.bigquery_source is None
1903-
):
1904-
raise ValueError(
1905-
"Multimodal dataset input config bigquery source is required."
1906-
)
1907-
if (
1908-
not hasattr(multimodal_dataset.metadata.input_config.bigquery_source, "uri")
1909-
or multimodal_dataset.metadata.input_config.bigquery_source.uri is None
1910-
):
1911-
raise ValueError(
1912-
"Multimodal dataset input config bigquery source uri is required."
1913-
)
1914-
if not multimodal_dataset.metadata.input_config.bigquery_source.uri.startswith(
1915-
"bq://"
1916-
):
1917-
multimodal_dataset.metadata.input_config.bigquery_source.uri = (
1918-
f"bq://{multimodal_dataset.metadata.input_config.bigquery_source.uri}"
1869+
1870+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1871+
if not uri.startswith("bq://"):
1872+
_datasets_utils.multimodal_dataset_set_bigquery_uri(
1873+
multimodal_dataset, f"bq://{uri}"
19191874
)
19201875
if isinstance(config, dict):
19211876
config = types.CreateMultimodalDatasetConfig(**config)
@@ -2020,6 +1975,33 @@ async def create_from_pandas(
20201975
config=config,
20211976
)
20221977

1978+
async def to_bigframes(
1979+
self,
1980+
*,
1981+
multimodal_dataset: types.MultimodalDatasetOrDict,
1982+
) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821
1983+
"""Converts a multimodal dataset to a BigFrames dataframe.
1984+
1985+
This is the preferred method to inspect the multimodal dataset in a
1986+
notebook.
1987+
1988+
Args:
1989+
multimodal_dataset:
1990+
Required. A representation of a multimodal dataset.
1991+
1992+
Returns:
1993+
A BigFrames dataframe.
1994+
"""
1995+
bigframes = _datasets_utils._try_import_bigframes()
1996+
1997+
if isinstance(multimodal_dataset, dict):
1998+
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1999+
elif not multimodal_dataset:
2000+
multimodal_dataset = types.MultimodalDataset()
2001+
2002+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2003+
# NOTE(review): lstrip("bq://") strips a *character set* (b, q, :, /), not the
# literal prefix — a table id starting with any of those characters would be
# mangled. removeprefix() removes exactly the "bq://" scheme prefix.
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
2004+
20232005
async def update_multimodal_dataset(
20242006
self,
20252007
*,
@@ -2041,35 +2023,11 @@ async def update_multimodal_dataset(
20412023
"""
20422024
if isinstance(multimodal_dataset, dict):
20432025
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
2044-
if (
2045-
not hasattr(multimodal_dataset, "metadata")
2046-
or multimodal_dataset.metadata is None
2047-
):
2048-
raise ValueError("Multimodal dataset metadata is required.")
2049-
if (
2050-
not hasattr(multimodal_dataset.metadata, "input_config")
2051-
or multimodal_dataset.metadata.input_config is None
2052-
):
2053-
raise ValueError("Multimodal dataset input config is required.")
2054-
if (
2055-
not hasattr(multimodal_dataset.metadata.input_config, "bigquery_source")
2056-
or multimodal_dataset.metadata.input_config.bigquery_source is None
2057-
):
2058-
raise ValueError(
2059-
"Multimodal dataset input config bigquery source is required."
2060-
)
2061-
if (
2062-
not hasattr(multimodal_dataset.metadata.input_config.bigquery_source, "uri")
2063-
or multimodal_dataset.metadata.input_config.bigquery_source.uri is None
2064-
):
2065-
raise ValueError(
2066-
"Multimodal dataset input config bigquery source uri is required."
2067-
)
2068-
if not multimodal_dataset.metadata.input_config.bigquery_source.uri.startswith(
2069-
"bq://"
2070-
):
2071-
multimodal_dataset.metadata.input_config.bigquery_source.uri = (
2072-
f"bq://{multimodal_dataset.metadata.input_config.bigquery_source.uri}"
2026+
2027+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2028+
if not uri.startswith("bq://"):
2029+
_datasets_utils.multimodal_dataset_set_bigquery_uri(
2030+
multimodal_dataset, f"bq://{uri}"
20732031
)
20742032
if isinstance(config, dict):
20752033
config = types.CreateMultimodalDatasetConfig(**config)

0 commit comments

Comments
 (0)