Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/unit/vertexai/genai/test_multimodal_datasets_genai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from vertexai._genai import _datasets_utils
from vertexai._genai import types
from google.genai import types as genai_types
import pytest


Expand Down Expand Up @@ -155,3 +156,63 @@ def test_to_bigframes(self, mock_import_bigframes):
mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with(
"project.dataset.table"
)


class TestGeminiRequestReadConfig:
def test_single_turn_template(self):
read_config = types.GeminiRequestReadConfig.single_turn_template(
model="gemini-1.5-flash",
prompt="test_prompt",
response="test_response",
system_instruction="test_system_instruction",
cached_content="test_cached_content",
tools=[{"function_declarations": [{"name": "test_tool"}]}],
tool_config={"function_calling_config": {"mode": "ANY"}},
safety_settings=[{"category": "HARM_CATEGORY_DANGEROUS_CONTENT"}],
generation_config={"temperature": 0.5},
field_mapping={"test_placeholder": "test_column"},
)

expected_read_config = types.GeminiRequestReadConfig(
template_config=types.GeminiTemplateConfig(
gemini_example=types.GeminiExample(
model="gemini-1.5-flash",
contents=[
genai_types.Content(
role="user",
parts=[genai_types.Part.from_text(text="test_prompt")],
),
genai_types.Content(
role="model",
parts=[genai_types.Part.from_text(text="test_response")],
),
],
system_instruction=genai_types.Content(
parts=[
genai_types.Part.from_text(text="test_system_instruction")
],
),
cached_content="test_cached_content",
tools=[
genai_types.Tool(
function_declarations=[
genai_types.FunctionDeclaration(name="test_tool")
]
)
],
tool_config=genai_types.ToolConfig(
function_calling_config=genai_types.FunctionCallingConfig(
mode="ANY"
)
),
safety_settings=[
genai_types.SafetySetting(
category="HARM_CATEGORY_DANGEROUS_CONTENT"
)
],
generation_config=genai_types.GenerationConfig(temperature=0.5),
),
field_mapping={"test_placeholder": "test_column"},
),
)
assert read_config == expected_read_config
86 changes: 86 additions & 0 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12583,6 +12583,92 @@ class GeminiRequestReadConfig(_common.BaseModel):
description="""Column name in the underlying BigQuery table that contains already fully assembled Gemini requests.""",
)

@classmethod
def single_turn_template(
cls,
*,
prompt: str,
response: Optional[str] = None,
system_instruction: Optional[str] = None,
model: Optional[str] = None,
cached_content: Optional[str] = None,
tools: Optional[list[Union[genai_types.Tool, dict[str, Any]]]] = None,
tool_config: Optional[Union[genai_types.ToolConfig, dict[str, Any]]] = None,
safety_settings: Optional[
list[Union[genai_types.SafetySetting, dict[str, Any]]]
] = None,
generation_config: Optional[
Union[genai_types.GenerationConfig, dict[str, Any]]
] = None,
field_mapping: Optional[dict[str, str]] = None,
) -> "GeminiRequestReadConfig":
"""Constructs a GeminiRequestReadConfig object for single-turn cases.

Example:
read_config = GeminiRequestReadConfig.single_turn_template(
prompt="Which flower is this {flower_image}?",
response="This is a {label}.",
system_instruction="You are a botanical classifier."
)

Args:
prompt: Required. User input.
response: Optional. Model response to user input.
system_instruction: Optional. System instructions for the model.
model: Optional. The model to use for the GeminiExample.
cached_content: Optional. The cached content to use for the GeminiExample.
tools: Optional. The tools to use for the GeminiExample.
tool_config: Optional. The tool config to use for the GeminiExample.
safety_settings: Optional. The safety settings to use for the GeminiExample.
generation_config: Optional. The generation config to use for the GeminiExample.
field_mapping: Optional. Mapping of placeholders to dataset columns.

Returns:
A GeminiRequestReadConfig object.
"""
contents = []
contents.append(
genai_types.Content(
role="user",
parts=[
genai_types.Part.from_text(text=prompt),
],
)
)
if response:
contents.append(
genai_types.Content(
role="model",
parts=[
genai_types.Part.from_text(text=response),
],
)
)

system_instruction_content = None
if system_instruction:
system_instruction_content = genai_types.Content(
parts=[
genai_types.Part.from_text(text=system_instruction),
],
)

return cls(
template_config=GeminiTemplateConfig(
gemini_example=GeminiExample(
model=model,
contents=contents,
system_instruction=system_instruction_content,
cached_content=cached_content,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
generation_config=generation_config,
),
field_mapping=field_mapping,
),
)


class GeminiRequestReadConfigDict(TypedDict, total=False):
"""Represents the config for reading Gemini requests."""
Expand Down
Loading