Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,23 +546,23 @@ async def test_streaming_agent_run_with_events_force_flush_otel(
async def test_async_create_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = await app.async_create_session(user_id=_TEST_USER_ID)
assert session1.user_id == _TEST_USER_ID
assert session1["user_id"] == _TEST_USER_ID
session2 = await app.async_create_session(
user_id=_TEST_USER_ID, session_id="test_session_id"
)
assert session2.user_id == _TEST_USER_ID
assert session2.id == "test_session_id"
assert session2["user_id"] == _TEST_USER_ID
assert session2["id"] == "test_session_id"

@pytest.mark.asyncio
async def test_async_get_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = await app.async_create_session(user_id=_TEST_USER_ID)
session2 = await app.async_get_session(
user_id=_TEST_USER_ID,
session_id=session1.id,
session_id=session1["id"],
)
assert session2.user_id == _TEST_USER_ID
assert session1.id == session2.id
assert session2["user_id"] == _TEST_USER_ID
assert session1["id"] == session2["id"]

@pytest.mark.asyncio
async def test_async_list_sessions(self, get_project_id_mock: mock.Mock):
Expand All @@ -572,12 +572,12 @@ async def test_async_list_sessions(self, get_project_id_mock: mock.Mock):
session = await app.async_create_session(user_id=_TEST_USER_ID)
response1 = await app.async_list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
assert response1.sessions[0].id == session.id
assert response1.sessions[0].id == session["id"]
session2 = await app.async_create_session(user_id=_TEST_USER_ID)
response2 = await app.async_list_sessions(user_id=_TEST_USER_ID)
assert len(response2.sessions) == 2
assert response2.sessions[0].id == session.id
assert response2.sessions[1].id == session2.id
assert response2.sessions[0].id == session["id"]
assert response2.sessions[1].id == session2["id"]

@pytest.mark.asyncio
async def test_async_delete_session(self, get_project_id_mock: mock.Mock):
Expand All @@ -592,30 +592,30 @@ async def test_async_delete_session(self, get_project_id_mock: mock.Mock):
assert len(response1.sessions) == 1
await app.async_delete_session(
user_id=_TEST_USER_ID,
session_id=session.id,
session_id=session["id"],
)
response0 = await app.async_list_sessions(user_id=_TEST_USER_ID)
assert not response0.sessions

def test_create_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = app.create_session(user_id=_TEST_USER_ID)
assert session1.user_id == _TEST_USER_ID
assert session1["user_id"] == _TEST_USER_ID
session2 = app.create_session(
user_id=_TEST_USER_ID, session_id="test_session_id"
)
assert session2.user_id == _TEST_USER_ID
assert session2.id == "test_session_id"
assert session2["user_id"] == _TEST_USER_ID
assert session2["id"] == "test_session_id"

def test_get_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
session1 = app.create_session(user_id=_TEST_USER_ID)
session2 = app.get_session(
user_id=_TEST_USER_ID,
session_id=session1.id,
session_id=session1["id"],
)
assert session2.user_id == _TEST_USER_ID
assert session1.id == session2.id
assert session2["user_id"] == _TEST_USER_ID
assert session1["id"] == session2["id"]

def test_list_sessions(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
Expand All @@ -624,12 +624,12 @@ def test_list_sessions(self, get_project_id_mock: mock.Mock):
session = app.create_session(user_id=_TEST_USER_ID)
response1 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
assert response1.sessions[0].id == session.id
assert response1.sessions[0].id == session["id"]
session2 = app.create_session(user_id=_TEST_USER_ID)
response2 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response2.sessions) == 2
assert response2.sessions[0].id == session.id
assert response2.sessions[1].id == session2.id
assert response2.sessions[0].id == session["id"]
assert response2.sessions[1].id == session2["id"]

def test_delete_session(self, get_project_id_mock: mock.Mock):
app = agent_engines.AdkApp(agent=_TEST_AGENT)
Expand All @@ -638,7 +638,7 @@ def test_delete_session(self, get_project_id_mock: mock.Mock):
session = app.create_session(user_id=_TEST_USER_ID)
response1 = app.list_sessions(user_id=_TEST_USER_ID)
assert len(response1.sessions) == 1
app.delete_session(user_id=_TEST_USER_ID, session_id=session.id)
app.delete_session(user_id=_TEST_USER_ID, session_id=session["id"])
response0 = app.list_sessions(user_id=_TEST_USER_ID)
assert not response0.sessions

Expand Down
8 changes: 6 additions & 2 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,9 @@ async def async_get_session(
raise RuntimeError(
"Session not found. Please create it using .create_session()"
)
return session
if hasattr(session, "model_dump"):
return session.model_dump(mode="json")
return session.dict() if hasattr(session, "dict") else session

def get_session(
self,
Expand Down Expand Up @@ -1464,7 +1466,9 @@ async def async_create_session(
state=state,
**kwargs,
)
return session
if hasattr(session, "model_dump"):
return session.model_dump(mode="json")
return session.dict() if hasattr(session, "dict") else session

def create_session(
self,
Expand Down