Skip to content

Commit e73fcf3

Browse files
cleop-google authored and
copybara-github committed
feat: GenAI SDK client(multimodal) - Support getting a bigframe DataFrame from a multimodal dataset
PiperOrigin-RevId: 882583435
1 parent e11ff3e commit e73fcf3

File tree

4 files changed

+98
-5
lines changed

4 files changed

+98
-5
lines changed

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,11 @@ def replays_prefix():
123123
return "test"
124124

125125

126+
@pytest.fixture
127+
def is_replay_mode(request):
128+
return request.config.getoption("--mode") in ["replay", "tap"]
129+
130+
126131
@pytest.fixture
127132
def mock_agent_engine_create_path_exists():
128133
"""Mocks os.path.exists to return True."""

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,11 +28,6 @@
2828
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2929

3030

31-
@pytest.fixture
32-
def is_replay_mode(request):
33-
return request.config.getoption("--mode") in ["replay", "tap"]
34-
35-
3631
@pytest.fixture
3732
def mock_bigquery_client(is_replay_mode):
3833
if is_replay_mode:

tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,14 +15,34 @@
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

1717
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import _datasets_utils
1819
from vertexai._genai import types
1920

21+
from unittest import mock
2022
import pytest
2123

2224
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
2325
DATASET = "8810841321427173376"
2426

2527

28+
@pytest.fixture
29+
def mock_import_bigframes(is_replay_mode):
30+
if is_replay_mode:
31+
with mock.patch.object(
32+
_datasets_utils, "_try_import_bigframes"
33+
) as mock_import_bigframes:
34+
mock_read_gbq_table_result = mock.MagicMock()
35+
mock_read_gbq_table_result.sql = f"SLECT * FROM `{BIGQUERY_TABLE_NAME}`"
36+
37+
bigframes = mock.MagicMock()
38+
bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result
39+
40+
mock_import_bigframes.return_value = bigframes
41+
yield mock_import_bigframes
42+
else:
43+
yield None
44+
45+
2646
def test_get_dataset(client):
2747
dataset = client.datasets._get_multimodal_dataset(
2848
name=DATASET,
@@ -41,6 +61,15 @@ def test_get_dataset_from_public_method(client):
4161
assert dataset.display_name == "test-display-name"
4262

4363

64+
@pytest.mark.usefixtures("mock_import_bigframes")
65+
def test_to_bigframes(client):
66+
dataset = client.datasets.get_multimodal_dataset(
67+
name=DATASET,
68+
)
69+
df = client.datasets.to_bigframes(multimodal_dataset=dataset)
70+
assert BIGQUERY_TABLE_NAME in df.sql
71+
72+
4473
pytestmark = pytest_helper.setup(
4574
file=__file__,
4675
globals_for_file=globals(),
@@ -67,3 +96,13 @@ async def test_get_dataset_from_public_method_async(client):
6796
assert isinstance(dataset, types.MultimodalDataset)
6897
assert dataset.name.endswith(DATASET)
6998
assert dataset.display_name == "test-display-name"
99+
100+
101+
@pytest.mark.asyncio
102+
@pytest.mark.usefixtures("mock_import_bigframes")
103+
async def test_to_bigframes_async(client):
104+
dataset = await client.aio.datasets.get_multimodal_dataset(
105+
name=DATASET,
106+
)
107+
df = await client.aio.datasets.to_bigframes(multimodal_dataset=dataset)
108+
assert BIGQUERY_TABLE_NAME in df.sql

vertexai/_genai/datasets.py

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -899,6 +899,33 @@ def create_from_pandas(
899899
config=config,
900900
)
901901

902+
def to_bigframes(
903+
self,
904+
*,
905+
multimodal_dataset: types.MultimodalDatasetOrDict,
906+
) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821
907+
"""Converts a multimodal dataset to a BigFrames dataframe.
908+
909+
This is the preferred method to inspect the multimodal dataset in a
910+
notebook.
911+
912+
Args:
913+
multimodal_dataset:
914+
Required. A representation of a multimodal dataset.
915+
916+
Returns:
917+
A BigFrames dataframe.
918+
"""
919+
bigframes = _datasets_utils._try_import_bigframes()
920+
921+
if isinstance(multimodal_dataset, dict):
922+
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
923+
elif not multimodal_dataset:
924+
multimodal_dataset = types.MultimodalDataset()
925+
926+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
927+
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
928+
902929
def update_multimodal_dataset(
903930
self,
904931
*,
@@ -1948,6 +1975,33 @@ async def create_from_pandas(
19481975
config=config,
19491976
)
19501977

1978+
async def to_bigframes(
1979+
self,
1980+
*,
1981+
multimodal_dataset: types.MultimodalDatasetOrDict,
1982+
) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821
1983+
"""Converts a multimodal dataset to a BigFrames dataframe.
1984+
1985+
This is the preferred method to inspect the multimodal dataset in a
1986+
notebook.
1987+
1988+
Args:
1989+
multimodal_dataset:
1990+
Required. A representation of a multimodal dataset.
1991+
1992+
Returns:
1993+
A BigFrames dataframe.
1994+
"""
1995+
bigframes = _datasets_utils._try_import_bigframes()
1996+
1997+
if isinstance(multimodal_dataset, dict):
1998+
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1999+
elif not multimodal_dataset:
2000+
multimodal_dataset = types.MultimodalDataset()
2001+
2002+
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2003+
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
2004+
19512005
async def update_multimodal_dataset(
19522006
self,
19532007
*,

0 commit comments

Comments
 (0)