Skip to content

Commit 584cffc

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: GenAI SDK client - Add get session call to create session sdk if an immediate success is returned
PiperOrigin-RevId: 881525769
1 parent 1ecaa9b commit 584cffc

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_create_session_with_ttl(client):
3232
config=types.CreateAgentEngineSessionConfig(
3333
display_name="my_session",
3434
session_state={"foo": "bar"},
35-
ttl="120s",
35+
ttl="1200000s",
3636
labels={"label_key": "label_value"},
3737
),
3838
)
@@ -42,12 +42,13 @@ def test_create_session_with_ttl(client):
4242
assert operation.response.user_id == "test-user-123"
4343
assert operation.response.labels == {"label_key": "label_value"}
4444
assert operation.response.name.startswith(agent_engine.api_resource.name)
45+
assert operation.done
4546
# Expire time is calculated by the server, so we only check that it is
4647
# within a reasonable range to avoid flakiness.
4748
assert (
48-
operation.response.create_time + datetime.timedelta(seconds=119.5)
49+
operation.response.create_time + datetime.timedelta(seconds=1199999.5)
4950
<= operation.response.expire_time
50-
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
51+
<= operation.response.create_time + datetime.timedelta(seconds=1200000.5)
5152
)
5253
finally:
5354
# Clean up resources.
@@ -60,7 +61,7 @@ def test_create_session_with_expire_time(client):
6061
assert isinstance(agent_engine, types.AgentEngine)
6162
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
6263
expire_time = datetime.datetime(
63-
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
64+
2028, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
6465
)
6566

6667
operation = client.agent_engines.sessions.create(
@@ -78,6 +79,7 @@ def test_create_session_with_expire_time(client):
7879
assert operation.response.user_id == "test-user-123"
7980
assert operation.response.name.startswith(agent_engine.api_resource.name)
8081
assert operation.response.expire_time == expire_time
82+
assert operation.done
8183
finally:
8284
# Clean up resources.
8385
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)

vertexai/_genai/sessions.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -651,12 +651,15 @@ def create(
651651
user_id=user_id,
652652
config=config,
653653
)
654-
if config.wait_for_completion and not operation.done:
655-
operation = _agent_engines_utils._await_operation(
656-
operation_name=operation.name,
657-
get_operation_fn=self._get_session_operation,
658-
poll_interval_seconds=0.5,
659-
)
654+
if config.wait_for_completion:
655+
if not operation.done:
656+
operation = _agent_engines_utils._await_operation(
657+
operation_name=operation.name,
658+
get_operation_fn=self._get_session_operation,
659+
poll_interval_seconds=0.5,
660+
)
661+
# We need to make a call to get the session because the operation
662+
# response might not contain the relevant fields.
660663
if operation.response:
661664
operation.response = self.get(name=operation.response.name)
662665
elif operation.error:
@@ -1133,12 +1136,15 @@ async def create(
11331136
user_id=user_id,
11341137
config=config,
11351138
)
1136-
if config.wait_for_completion and not operation.done:
1137-
operation = await _agent_engines_utils._await_async_operation(
1138-
operation_name=operation.name,
1139-
get_operation_fn=self._get_session_operation,
1140-
poll_interval_seconds=0.5,
1141-
)
1139+
if config.wait_for_completion:
1140+
if not operation.done:
1141+
operation = await _agent_engines_utils._await_async_operation(
1142+
operation_name=operation.name,
1143+
get_operation_fn=self._get_session_operation,
1144+
poll_interval_seconds=0.5,
1145+
)
1146+
# We need to make a call to get the session because the operation
1147+
# response might not contain the relevant fields.
11421148
if operation.response:
11431149
operation.response = await self.get(name=operation.response.name)
11441150
elif operation.error:

0 commit comments

Comments
 (0)