diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index eb22a83bb9..20df92ec0a 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -15,6 +15,7 @@ from __future__ import annotations import abc +import base64 from typing import Any from typing import Optional @@ -25,6 +26,29 @@ from .session import Session from .state import State +_DEFAULT_PAGE_SIZE = 20 +_MAX_PAGE_SIZE = 100 + + +def _resolve_page_size(page_size: Optional[int]) -> int: + """Clamp *page_size* to [1, _MAX_PAGE_SIZE], defaulting to _DEFAULT_PAGE_SIZE.""" + if page_size is None: + return _DEFAULT_PAGE_SIZE + return max(1, min(page_size, _MAX_PAGE_SIZE)) + + +def _encode_page_token(offset: int) -> str: + return base64.b64encode(str(offset).encode()).decode() + + +def _decode_page_token(token: Optional[str]) -> int: + if not token: + return 0 + try: + return max(0, int(base64.b64decode(token).decode())) + except (ValueError, TypeError): + return 0 + class GetSessionConfig(BaseModel): """The configuration of getting a session.""" @@ -40,6 +64,7 @@ class ListSessionsResponse(BaseModel): """ sessions: list[Session] = Field(default_factory=list) + next_page_token: Optional[str] = None class BaseSessionService(abc.ABC): @@ -83,17 +108,27 @@ async def get_session( @abc.abstractmethod async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: - """Lists all the sessions for a user. + """Lists sessions, optionally filtered by user, with pagination. Args: app_name: The name of the app. user_id: The ID of the user. If not provided, lists all sessions for all users. + page_size: Maximum number of sessions to return per page. Defaults to 20, + maximum 100. + page_token: Token returned from a previous ``list_sessions`` call to + fetch the next page. Returns: - A ListSessionsResponse containing the sessions. + A ListSessionsResponse containing the sessions and an optional + ``next_page_token`` for fetching subsequent pages. """ @abc.abstractmethod diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index bb00aefe4f..6265f3513e 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -44,6 +44,9 @@ from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse +from .base_session_service import _decode_page_token +from .base_session_service import _encode_page_token +from .base_session_service import _resolve_page_size from .migration import _schema_check_utils from .schemas.v0 import Base as BaseV0 from .schemas.v0 import StorageAppState as StorageAppStateV0 @@ -492,10 +495,19 @@ async def get_session( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: await self._prepare_tables() schema = self._get_schema_classes() + + effective_page_size = _resolve_page_size(page_size) + offset = _decode_page_token(page_token) + async with self._rollback_on_exception_session() as sql_session: stmt = select(schema.StorageSession).filter( schema.StorageSession.app_name == app_name @@ -503,8 +515,16 @@ async def list_sessions( if user_id is not None: stmt = stmt.filter(schema.StorageSession.user_id == user_id) + stmt = stmt.order_by(schema.StorageSession.update_time.desc()) + # Fetch one extra row to determine if there is a next page. + stmt = stmt.offset(offset).limit(effective_page_size + 1) + result = await sql_session.execute(stmt) - results = result.scalars().all() + results = list(result.scalars().all()) + + has_next_page = len(results) > effective_page_size + if has_next_page: + results = results[:effective_page_size] # Fetch app state from storage storage_app_state = await sql_session.get( @@ -538,7 +558,15 @@ async def list_sessions( sessions.append( storage_session.to_session(state=merged_state, is_sqlite=is_sqlite) ) - return ListSessionsResponse(sessions=sessions) + + next_page_token = ( + _encode_page_token(offset + effective_page_size) + if has_next_page + else None + ) + return ListSessionsResponse( + sessions=sessions, next_page_token=next_page_token + ) @override async def delete_session( diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index e0f9b49ff3..e629d0898c 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -28,6 +28,9 @@ from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse +from .base_session_service import _decode_page_token +from .base_session_service import _encode_page_token +from .base_session_service import _resolve_page_size from .session import Session from .state import State @@ -220,18 +223,43 @@ def _merge_state( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: - return self._list_sessions_impl(app_name=app_name, user_id=user_id) + return self._list_sessions_impl( + app_name=app_name, + user_id=user_id, + page_size=page_size, + page_token=page_token, + ) def list_sessions_sync( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: logger.warning('Deprecated. Please migrate to the async method.') - return self._list_sessions_impl(app_name=app_name, user_id=user_id) + return self._list_sessions_impl( + app_name=app_name, + user_id=user_id, + page_size=page_size, + page_token=page_token, + ) def _list_sessions_impl( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: empty_response = ListSessionsResponse() if app_name not in self.sessions: @@ -239,23 +267,40 @@ def _list_sessions_impl( if user_id is not None and user_id not in self.sessions[app_name]: return empty_response - sessions_without_events = [] + all_sessions = [] if user_id is None: - for user_id in self.sessions[app_name]: - for session_id in self.sessions[app_name][user_id]: - session = self.sessions[app_name][user_id][session_id] + for uid in self.sessions[app_name]: + for session_id in self.sessions[app_name][uid]: + session = self.sessions[app_name][uid][session_id] copied_session = copy.deepcopy(session) copied_session.events = [] - copied_session = self._merge_state(app_name, user_id, copied_session) - sessions_without_events.append(copied_session) + copied_session = self._merge_state(app_name, uid, copied_session) + all_sessions.append(copied_session) else: for session in self.sessions[app_name][user_id].values(): copied_session = copy.deepcopy(session) copied_session.events = [] copied_session = self._merge_state(app_name, user_id, copied_session) - sessions_without_events.append(copied_session) - return ListSessionsResponse(sessions=sessions_without_events) + all_sessions.append(copied_session) + + # Sort by last_update_time descending (most recently updated first) + all_sessions.sort( + key=lambda s: s.last_update_time if s.last_update_time else 0, + reverse=True, + ) + + effective_page_size = _resolve_page_size(page_size) + offset = _decode_page_token(page_token) + + page = all_sessions[offset : offset + effective_page_size] + has_next_page = (offset + effective_page_size) < len(all_sessions) + next_page_token = ( + _encode_page_token(offset + effective_page_size) + if has_next_page + else None + ) + return ListSessionsResponse(sessions=page, next_page_token=next_page_token) @override async def delete_session( diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 3ad84e9d1a..6adf62da34 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -35,6 +35,9 @@ from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse +from .base_session_service import _decode_page_token +from .base_session_service import _encode_page_token +from .base_session_service import _resolve_page_size from .session import Session from .state import State @@ -292,24 +295,38 @@ async def get_session( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: + effective_page_size = _resolve_page_size(page_size) + offset = _decode_page_token(page_token) + sessions_list = [] async with self._get_db_connection() as db: - # Fetch sessions + # Fetch sessions with ORDER BY / LIMIT / OFFSET if user_id: session_rows = await db.execute_fetchall( "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=? AND user_id=?", - (app_name, user_id), + " app_name=? AND user_id=?" + " ORDER BY update_time DESC LIMIT ? OFFSET ?", + (app_name, user_id, effective_page_size + 1, offset), ) else: session_rows = await db.execute_fetchall( "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=?", - (app_name,), + " app_name=?" + " ORDER BY update_time DESC LIMIT ? OFFSET ?", + (app_name, effective_page_size + 1, offset), ) + has_next_page = len(session_rows) > effective_page_size + if has_next_page: + session_rows = session_rows[:effective_page_size] + # Fetch app state app_state = await self._get_app_state(db, app_name) @@ -343,7 +360,15 @@ async def list_sessions( last_update_time=row["update_time"], ) ) - return ListSessionsResponse(sessions=sessions_list) + + next_page_token = ( + _encode_page_token(offset + effective_page_size) + if has_next_page + else None + ) + return ListSessionsResponse( + sessions=sessions_list, next_page_token=next_page_token + ) @override async def delete_session( diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 9e5c9bb2ec..0a83dc5057 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -216,7 +216,12 @@ async def get_session( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, ) -> ListSessionsResponse: reasoning_engine_id = self._get_reasoning_engine_id(app_name) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 0a7721088e..1a7b36b0f2 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1351,3 +1351,164 @@ async def tracking_fn(**kwargs): finally: database_session_service._select_required_state = original_fn await service.close() + + +# --------------------------------------------------------------------------- +# Pagination tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_list_sessions_default_pagination(session_service): + """Without explicit page_size, the first 20 sessions are returned.""" + app_name = 'pagination_app' + user_id = 'user' + num_sessions = 25 + + for i in range(num_sessions): + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=f's{i:03d}' + ) + + response = await session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + assert len(response.sessions) == 20 + assert response.next_page_token is not None + + +@pytest.mark.asyncio +async def test_list_sessions_custom_page_size(session_service): + """Explicit page_size is respected.""" + app_name = 'pagination_app2' + user_id = 'user' + + for i in range(10): + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=f's{i}' + ) + + response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, page_size=3 + ) + assert len(response.sessions) == 3 + assert response.next_page_token is not None + + +@pytest.mark.asyncio +async def test_list_sessions_page_size_clamped_to_max(session_service): + """page_size > 100 is clamped to 100.""" + app_name = 'pagination_clamp' + user_id = 'user' + + for i in range(5): + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=f's{i}' + ) + + response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, page_size=999 + ) + # Only 5 sessions exist, so all are returned and no next page. + assert len(response.sessions) == 5 + assert response.next_page_token is None + + +@pytest.mark.asyncio +async def test_list_sessions_iterate_all_pages(session_service): + """Iterating with page_token collects every session exactly once.""" + app_name = 'pagination_iter' + user_id = 'user' + total = 7 + page_size = 3 + + for i in range(total): + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=f's{i}' + ) + + collected_ids = [] + page_token = None + while True: + response = await session_service.list_sessions( + app_name=app_name, + user_id=user_id, + page_size=page_size, + page_token=page_token, + ) + collected_ids.extend(s.id for s in response.sessions) + if response.next_page_token is None: + break + page_token = response.next_page_token + + assert sorted(collected_ids) == sorted([f's{i}' for i in range(total)]) + + +@pytest.mark.asyncio +async def test_list_sessions_no_next_token_when_exact_fit(session_service): + """When total == page_size, next_page_token should be None.""" + app_name = 'pagination_exact' + user_id = 'user' + + for i in range(5): + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=f's{i}' + ) + + response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, page_size=5 + ) + assert len(response.sessions) == 5 + assert response.next_page_token is None + + +@pytest.mark.asyncio +async def test_list_sessions_empty_result(session_service): + """Listing sessions for a non-existent app returns empty with no token.""" + response = await session_service.list_sessions( + app_name='nonexistent_app', user_id='nobody' + ) + assert len(response.sessions) == 0 + assert response.next_page_token is None + + +@pytest.mark.asyncio +async def test_list_sessions_backward_compatible_no_args(session_service): + """Calling list_sessions without pagination args still works (backward compat).""" + app_name = 'compat_app' + user_id = 'user' + + for i in range(3): + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=f's{i}' + ) + + response = await session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + assert len(response.sessions) == 3 + assert response.next_page_token is None + + +@pytest.mark.asyncio +async def test_list_sessions_ordered_by_update_time_desc(session_service): + """Sessions are returned most-recently-updated first.""" + app_name = 'order_app' + user_id = 'user' + + s0 = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s0' + ) + s1 = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s1' + ) + s2 = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s2' + ) + + response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, page_size=100 + ) + ids = [s.id for s in response.sessions] + # Most recently created (and thus updated) should come first. + assert ids == ['s2', 's1', 's0']