Skip to content

Commit 179953e

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI Client(evals): Update to generate_conversation_scenarios from generate_user_scenarios
PiperOrigin-RevId: 885882534
1 parent 0cff2d8 commit 179953e

7 files changed

Lines changed: 232 additions & 209 deletions

File tree

tests/unit/vertexai/genai/replays/test_generate_user_scenarios.py renamed to tests/unit/vertexai/genai/replays/test_generate_conversation_scenarios.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import pytest
2020

2121

22-
def test_gen_user_scenarios(client):
23-
"""Tests that generate_user_scenarios() correctly calls the API and parses the response."""
24-
eval_dataset = client.evals.generate_user_scenarios(
22+
def test_gen_conversation_scenarios(client):
23+
"""Tests that generate_conversation_scenarios() correctly calls the API and parses the response."""
24+
eval_dataset = client.evals.generate_conversation_scenarios(
2525
agents={
2626
"booking-agent": types.evals.AgentConfig(
2727
agent_id="booking-agent",
@@ -40,13 +40,13 @@ def test_gen_user_scenarios(client):
4040
],
4141
)
4242
},
43-
user_scenario_generation_config=types.evals.UserScenarioGenerationConfig(
44-
user_scenario_count=2,
45-
simulation_instruction=(
43+
config=types.evals.UserScenarioGenerationConfig(
44+
count=2,
45+
generation_instruction=(
4646
"Generate scenarios where the user tries to book a flight but"
4747
" changes their mind about the destination."
4848
),
49-
environment_data="Today is Monday. Flights to Paris are available.",
49+
environment_context="Today is Monday. Flights to Paris are available.",
5050
model_name="gemini-2.5-flash",
5151
),
5252
root_agent_id="booking-agent",
@@ -67,9 +67,9 @@ def test_gen_user_scenarios(client):
6767

6868

6969
@pytest.mark.asyncio
70-
async def test_gen_user_scenarios_async(client):
71-
"""Tests that generate_user_scenarios() async correctly calls the API and parses the response."""
72-
eval_dataset = await client.aio.evals.generate_user_scenarios(
70+
async def test_gen_conversation_scenarios_async(client):
71+
"""Tests that generate_conversation_scenarios() async correctly calls the API and parses the response."""
72+
eval_dataset = await client.aio.evals.generate_conversation_scenarios(
7373
agents={
7474
"booking-agent": types.evals.AgentConfig(
7575
agent_id="booking-agent",
@@ -88,13 +88,13 @@ async def test_gen_user_scenarios_async(client):
8888
],
8989
)
9090
},
91-
user_scenario_generation_config=types.evals.UserScenarioGenerationConfig(
92-
user_scenario_count=2,
93-
simulation_instruction=(
91+
config=types.evals.UserScenarioGenerationConfig(
92+
count=2,
93+
generation_instruction=(
9494
"Generate scenarios where the user tries to book a flight but"
9595
" changes their mind about the destination."
9696
),
97-
environment_data="Today is Monday. Flights to Paris are available.",
97+
environment_context="Today is Monday. Flights to Paris are available.",
9898
model_name="gemini-2.5-flash",
9999
),
100100
root_agent_id="booking-agent",
@@ -114,5 +114,5 @@ async def test_gen_user_scenarios_async(client):
114114
pytestmark = pytest_helper.setup(
115115
file=__file__,
116116
globals_for_file=globals(),
117-
test_method="evals.generate_user_scenarios",
117+
test_method="evals.generate_conversation_scenarios",
118118
)

tests/unit/vertexai/genai/test_evals.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6140,8 +6140,8 @@ def read_file_contents_side_effect(src: str) -> str:
61406140
)
61416141

61426142

6143-
class TestEvalsGenerateUserScenarios(unittest.TestCase):
6144-
"""Unit tests for the Evals generate_user_scenarios method."""
6143+
class TestEvalsGenerateConversationScenarios(unittest.TestCase):
6144+
"""Unit tests for the Evals generate_conversation_scenarios method."""
61456145

61466146
def setUp(self):
61476147
self.addCleanup(mock.patch.stopall)
@@ -6161,13 +6161,13 @@ def setUp(self):
61616161
)
61626162
self.mock_api_client.request.return_value = self.mock_response
61636163

6164-
def test_generate_user_scenarios(self):
6165-
"""Tests that generate_user_scenarios correctly calls the API and parses the response."""
6164+
def test_generate_conversation_scenarios(self):
6165+
"""Tests that generate_conversation_scenarios correctly calls the API and parses the response."""
61666166
evals_module = evals.Evals(api_client_=self.mock_api_client)
61676167

6168-
eval_dataset = evals_module.generate_user_scenarios(
6168+
eval_dataset = evals_module.generate_conversation_scenarios(
61696169
agents={"agent_1": {}},
6170-
user_scenario_generation_config={"user_scenario_count": 2},
6170+
config={"count": 2},
61716171
root_agent_id="agent_1",
61726172
)
61736173
assert isinstance(eval_dataset, vertexai_genai_types.EvaluationDataset)
@@ -6187,17 +6187,17 @@ def test_generate_user_scenarios(self):
61876187
self.mock_api_client.request.assert_called_once()
61886188

61896189
@pytest.mark.asyncio
6190-
async def test_async_generate_user_scenarios(self):
6191-
"""Tests that async generate_user_scenarios correctly calls the API and parses the response."""
6190+
async def test_async_generate_conversation_scenarios(self):
6191+
"""Tests that async generate_conversation_scenarios correctly calls the API and parses the response."""
61926192

61936193
self.mock_api_client.async_request = mock.AsyncMock(
61946194
return_value=self.mock_response
61956195
)
61966196
async_evals_module = evals.AsyncEvals(api_client_=self.mock_api_client)
61976197

6198-
eval_dataset = await async_evals_module.generate_user_scenarios(
6198+
eval_dataset = await async_evals_module.generate_conversation_scenarios(
61996199
agents={"agent_1": {}},
6200-
user_scenario_generation_config={"user_scenario_count": 2},
6200+
config={"count": 2},
62016201
root_agent_id="agent_1",
62026202
)
62036203
assert isinstance(eval_dataset, vertexai_genai_types.EvaluationDataset)

vertexai/_genai/_transformers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]:
125125
return sources_payload
126126

127127

128+
def t_user_scenario_generation_config(
129+
config: "types.evals.UserScenarioGenerationConfigOrDict",
130+
) -> dict[str, Any]:
131+
"""Transforms UserScenarioGenerationConfig to Vertex AI format."""
132+
payload: dict[str, Any] = {}
133+
config_dict = config if isinstance(config, dict) else config.model_dump()
134+
135+
if getv(config_dict, ["count"]) is not None:
136+
payload["user_scenario_count"] = getv(config_dict, ["count"])
137+
if getv(config_dict, ["generation_instruction"]) is not None:
138+
payload["simulation_instruction"] = getv(
139+
config_dict, ["generation_instruction"]
140+
)
141+
if getv(config_dict, ["environment_context"]) is not None:
142+
payload["environment_data"] = getv(config_dict, ["environment_context"])
143+
if getv(config_dict, ["model_name"]) is not None:
144+
payload["model_name"] = getv(config_dict, ["model_name"])
145+
146+
return payload
147+
148+
128149
def t_metric_for_registry(
129150
metric: "types.Metric",
130151
) -> dict[str, Any]:

0 commit comments

Comments
 (0)