Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import abc
import base64
from typing import Any
from typing import Optional

Expand All @@ -25,6 +26,29 @@
from .session import Session
from .state import State

_DEFAULT_PAGE_SIZE = 20
_MAX_PAGE_SIZE = 100


def _resolve_page_size(page_size: Optional[int]) -> int:
"""Clamp *page_size* to [1, _MAX_PAGE_SIZE], defaulting to _DEFAULT_PAGE_SIZE."""
if page_size is None:
return _DEFAULT_PAGE_SIZE
return max(1, min(page_size, _MAX_PAGE_SIZE))


def _encode_page_token(offset: int) -> str:
return base64.b64encode(str(offset).encode()).decode()


def _decode_page_token(token: Optional[str]) -> int:
if not token:
return 0
try:
return max(0, int(base64.b64decode(token).decode()))
except (ValueError, TypeError):
return 0


class GetSessionConfig(BaseModel):
"""The configuration of getting a session."""
Expand All @@ -40,6 +64,7 @@ class ListSessionsResponse(BaseModel):
"""

sessions: list[Session] = Field(default_factory=list)
next_page_token: Optional[str] = None


class BaseSessionService(abc.ABC):
Expand Down Expand Up @@ -83,17 +108,27 @@ async def get_session(

@abc.abstractmethod
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
"""Lists all the sessions for a user.
"""Lists sessions, optionally filtered by user, with pagination.

Args:
app_name: The name of the app.
user_id: The ID of the user. If not provided, lists all sessions for all
users.
page_size: Maximum number of sessions to return per page. Defaults to 20,
maximum 100.
page_token: Token returned from a previous ``list_sessions`` call to
fetch the next page.

Returns:
A ListSessionsResponse containing the sessions.
A ListSessionsResponse containing the sessions and an optional
``next_page_token`` for fetching subsequent pages.
"""

@abc.abstractmethod
Expand Down
34 changes: 31 additions & 3 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .base_session_service import _decode_page_token
from .base_session_service import _encode_page_token
from .base_session_service import _resolve_page_size
from .migration import _schema_check_utils
from .schemas.v0 import Base as BaseV0
from .schemas.v0 import StorageAppState as StorageAppStateV0
Expand Down Expand Up @@ -492,19 +495,36 @@ async def get_session(

@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
await self._prepare_tables()
schema = self._get_schema_classes()

effective_page_size = _resolve_page_size(page_size)
offset = _decode_page_token(page_token)

async with self._rollback_on_exception_session() as sql_session:
stmt = select(schema.StorageSession).filter(
schema.StorageSession.app_name == app_name
)
if user_id is not None:
stmt = stmt.filter(schema.StorageSession.user_id == user_id)

stmt = stmt.order_by(schema.StorageSession.update_time.desc())
# Fetch one extra row to determine if there is a next page.
stmt = stmt.offset(offset).limit(effective_page_size + 1)

result = await sql_session.execute(stmt)
results = result.scalars().all()
results = list(result.scalars().all())

has_next_page = len(results) > effective_page_size
if has_next_page:
results = results[:effective_page_size]

# Fetch app state from storage
storage_app_state = await sql_session.get(
Expand Down Expand Up @@ -538,7 +558,15 @@ async def list_sessions(
sessions.append(
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
)
return ListSessionsResponse(sessions=sessions)

next_page_token = (
_encode_page_token(offset + effective_page_size)
if has_next_page
else None
)
return ListSessionsResponse(
sessions=sessions, next_page_token=next_page_token
)

@override
async def delete_session(
Expand Down
71 changes: 58 additions & 13 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .base_session_service import _decode_page_token
from .base_session_service import _encode_page_token
from .base_session_service import _resolve_page_size
from .session import Session
from .state import State

Expand Down Expand Up @@ -220,42 +223,84 @@ def _merge_state(

@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
return self._list_sessions_impl(
app_name=app_name,
user_id=user_id,
page_size=page_size,
page_token=page_token,
)

def list_sessions_sync(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
logger.warning('Deprecated. Please migrate to the async method.')
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
return self._list_sessions_impl(
app_name=app_name,
user_id=user_id,
page_size=page_size,
page_token=page_token,
)

def _list_sessions_impl(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
empty_response = ListSessionsResponse()
if app_name not in self.sessions:
return empty_response
if user_id is not None and user_id not in self.sessions[app_name]:
return empty_response

sessions_without_events = []
all_sessions = []

if user_id is None:
for user_id in self.sessions[app_name]:
for session_id in self.sessions[app_name][user_id]:
session = self.sessions[app_name][user_id][session_id]
for uid in self.sessions[app_name]:
for session_id in self.sessions[app_name][uid]:
session = self.sessions[app_name][uid][session_id]
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
copied_session = self._merge_state(app_name, uid, copied_session)
all_sessions.append(copied_session)
else:
for session in self.sessions[app_name][user_id].values():
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)
all_sessions.append(copied_session)

# Sort by last_update_time descending (most recently updated first)
all_sessions.sort(
key=lambda s: s.last_update_time if s.last_update_time else 0,
reverse=True,
)

effective_page_size = _resolve_page_size(page_size)
offset = _decode_page_token(page_token)

page = all_sessions[offset : offset + effective_page_size]
has_next_page = (offset + effective_page_size) < len(all_sessions)
next_page_token = (
_encode_page_token(offset + effective_page_size)
if has_next_page
else None
)
return ListSessionsResponse(sessions=page, next_page_token=next_page_token)

@override
async def delete_session(
Expand Down
39 changes: 32 additions & 7 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListSessionsResponse
from .base_session_service import _decode_page_token
from .base_session_service import _encode_page_token
from .base_session_service import _resolve_page_size
from .session import Session
from .state import State

Expand Down Expand Up @@ -292,24 +295,38 @@ async def get_session(

@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
effective_page_size = _resolve_page_size(page_size)
offset = _decode_page_token(page_token)

sessions_list = []
async with self._get_db_connection() as db:
# Fetch sessions
# Fetch sessions with ORDER BY / LIMIT / OFFSET
if user_id:
session_rows = await db.execute_fetchall(
"SELECT id, user_id, state, update_time FROM sessions WHERE"
" app_name=? AND user_id=?",
(app_name, user_id),
" app_name=? AND user_id=?"
" ORDER BY update_time DESC LIMIT ? OFFSET ?",
(app_name, user_id, effective_page_size + 1, offset),
)
else:
session_rows = await db.execute_fetchall(
"SELECT id, user_id, state, update_time FROM sessions WHERE"
" app_name=?",
(app_name,),
" app_name=?"
" ORDER BY update_time DESC LIMIT ? OFFSET ?",
(app_name, effective_page_size + 1, offset),
)

has_next_page = len(session_rows) > effective_page_size
if has_next_page:
session_rows = session_rows[:effective_page_size]

# Fetch app state
app_state = await self._get_app_state(db, app_name)

Expand Down Expand Up @@ -343,7 +360,15 @@ async def list_sessions(
last_update_time=row["update_time"],
)
)
return ListSessionsResponse(sessions=sessions_list)

next_page_token = (
_encode_page_token(offset + effective_page_size)
if has_next_page
else None
)
return ListSessionsResponse(
sessions=sessions_list, next_page_token=next_page_token
)

@override
async def delete_session(
Expand Down
7 changes: 6 additions & 1 deletion src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,12 @@ async def get_session(

@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
self,
*,
app_name: str,
user_id: Optional[str] = None,
page_size: Optional[int] = None,
page_token: Optional[str] = None,
) -> ListSessionsResponse:
reasoning_engine_id = self._get_reasoning_engine_id(app_name)

Expand Down
Loading