diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index bb00aefe4f..c82a94ca0f 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -257,12 +257,17 @@ async def _with_session_lock( else: self._session_lock_ref_count[lock_key] = remaining - async def _prepare_tables(self): + async def prepare_tables(self): """Ensure database tables are ready for use. This method is called lazily before each database operation. It checks the DB schema version to use and creates the tables (including setting the schema version metadata) if needed. + + It can also be called eagerly right after construction to pay the + table-creation cost upfront (e.g. during application startup) instead of + on the first database operation. It is safe to call more than once and + is recommended for latency-sensitive applications. """ # Early return if tables are already created if self._tables_created: @@ -361,7 +366,7 @@ async def create_session( # 3. Add the object to the table # 4. Build the session object with generated id # 5. Return the session - await self._prepare_tables() + await self.prepare_tables() schema = self._get_schema_classes() async with self._rollback_on_exception_session() as sql_session: if session_id and await sql_session.get( @@ -436,7 +441,7 @@ async def get_session( session_id: str, config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: - await self._prepare_tables() + await self.prepare_tables() # 1. Get the storage session entry from session table # 2. Get all the events based on session id and filtering config # 3. Convert and return the session @@ -494,7 +499,7 @@ async def get_session( async def list_sessions( self, *, app_name: str, user_id: Optional[str] = None ) -> ListSessionsResponse: - await self._prepare_tables() + await self.prepare_tables() schema = self._get_schema_classes() async with self._rollback_on_exception_session() as sql_session: stmt = select(schema.StorageSession).filter( @@ -544,7 +549,7 @@ async def list_sessions( async def delete_session( self, app_name: str, user_id: str, session_id: str ) -> None: - await self._prepare_tables() + await self.prepare_tables() schema = self._get_schema_classes() async with self._rollback_on_exception_session() as sql_session: stmt = delete(schema.StorageSession).where( @@ -557,7 +562,7 @@ async def delete_session( @override async def append_event(self, session: Session, event: Event) -> Event: - await self._prepare_tables() + await self.prepare_tables() if event.partial: return event diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 0a7721088e..e49f7a9357 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1163,10 +1163,10 @@ def _spy_factory(): @pytest.mark.asyncio async def test_concurrent_prepare_tables_no_race_condition(): - """Verifies that concurrent calls to _prepare_tables wait for table creation. + """Verifies that concurrent calls to prepare_tables wait for table creation. Reproduces the race condition from https://github.com/google/adk-python/issues/4445: when concurrent requests - arrive at startup, _prepare_tables must not return before tables exist. + arrive at startup, prepare_tables must not return before tables exist. Previously, the early-return guard checked _db_schema_version (set during schema detection) instead of _tables_created, so a second request could slip through after schema detection but before table creation finished. @@ -1179,7 +1179,7 @@ async def test_concurrent_prepare_tables_no_race_condition(): # Launch several concurrent create_session calls, each with a unique # app_name to avoid IntegrityError on the shared app_states row. - # Each will call _prepare_tables internally. If the race condition + # Each will call prepare_tables internally. If the race condition # exists, some of these will fail because the "sessions" table doesn't # exist yet. num_concurrent = 5 @@ -1197,7 +1197,7 @@ async def test_concurrent_prepare_tables_no_race_condition(): for i, result in enumerate(results): assert not isinstance(result, BaseException), ( f'Concurrent create_session #{i} raised {result!r}; tables were' - ' likely not ready due to the _prepare_tables race condition.' + ' likely not ready due to the prepare_tables race condition.' ) # All sessions should be retrievable. @@ -1216,7 +1216,7 @@ async def test_concurrent_prepare_tables_no_race_condition(): async def test_prepare_tables_serializes_schema_detection_and_creation(): """Verifies schema detection and table creation happen atomically under one lock, so concurrent callers cannot observe a partially-initialized state. - After _prepare_tables completes, both _db_schema_version and _tables_created + After prepare_tables completes, both _db_schema_version and _tables_created must be set. """ service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') @@ -1224,9 +1224,9 @@ async def test_prepare_tables_serializes_schema_detection_and_creation(): assert not service._tables_created assert service._db_schema_version is None - await service._prepare_tables() + await service.prepare_tables() - # Both must be set after a single _prepare_tables call. + # Both must be set after a single prepare_tables call. assert service._tables_created assert service._db_schema_version is not None @@ -1242,17 +1242,17 @@ async def test_prepare_tables_serializes_schema_detection_and_creation(): @pytest.mark.asyncio async def test_prepare_tables_idempotent_after_creation(): - """Calling _prepare_tables multiple times is safe and idempotent. + """Calling prepare_tables multiple times is safe and idempotent. After tables are created, subsequent calls should return immediately via the fast path without errors. """ service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') try: - await service._prepare_tables() + await service.prepare_tables() assert service._tables_created # Call again — should be a no-op via the fast path. - await service._prepare_tables() + await service.prepare_tables() assert service._tables_created # Service should still work. @@ -1264,6 +1264,30 @@ async def test_prepare_tables_idempotent_after_creation(): await service.close() +@pytest.mark.asyncio +async def test_public_prepare_tables_eager_initialization(): + """Calling the public prepare_tables() eagerly initializes tables so that + the first real database operation does not pay the setup cost. + """ + async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as service: + # Before calling prepare_tables, tables are not created. + assert not service._tables_created + assert service._db_schema_version is None + + # Eagerly prepare tables via the public API. + await service.prepare_tables() + + # Tables should now be ready. + assert service._tables_created + assert service._db_schema_version is not None + + # Subsequent operations should work without any additional setup cost. + session = await service.create_session( + app_name='app', user_id='user', session_id='s1' + ) + assert session.id == 's1' + + @pytest.mark.asyncio @pytest.mark.parametrize( 'state_delta, expect_app_lock, expect_user_lock',