Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,17 @@ async def _with_session_lock(
else:
self._session_lock_ref_count[lock_key] = remaining

async def _prepare_tables(self):
async def prepare_tables(self):
"""Ensure database tables are ready for use.

This method is called lazily before each database operation. It checks the
DB schema version to use and creates the tables (including setting the
schema version metadata) if needed.

It can also be called eagerly right after construction to pay the
table-creation cost upfront (e.g. during application startup) instead of
on the first database operation. It is safe to call more than once and
is recommended for latency-sensitive applications.
"""
# Early return if tables are already created
if self._tables_created:
Expand Down Expand Up @@ -361,7 +366,7 @@ async def create_session(
# 3. Add the object to the table
# 4. Build the session object with generated id
# 5. Return the session
await self._prepare_tables()
await self.prepare_tables()
schema = self._get_schema_classes()
async with self._rollback_on_exception_session() as sql_session:
if session_id and await sql_session.get(
Expand Down Expand Up @@ -436,7 +441,7 @@ async def get_session(
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
await self._prepare_tables()
await self.prepare_tables()
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
Expand Down Expand Up @@ -494,7 +499,7 @@ async def get_session(
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
) -> ListSessionsResponse:
await self._prepare_tables()
await self.prepare_tables()
schema = self._get_schema_classes()
async with self._rollback_on_exception_session() as sql_session:
stmt = select(schema.StorageSession).filter(
Expand Down Expand Up @@ -544,7 +549,7 @@ async def list_sessions(
async def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
await self._prepare_tables()
await self.prepare_tables()
schema = self._get_schema_classes()
async with self._rollback_on_exception_session() as sql_session:
stmt = delete(schema.StorageSession).where(
Expand All @@ -557,7 +562,7 @@ async def delete_session(

@override
async def append_event(self, session: Session, event: Event) -> Event:
await self._prepare_tables()
await self.prepare_tables()
if event.partial:
return event

Expand Down
44 changes: 34 additions & 10 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,10 +1163,10 @@ def _spy_factory():

@pytest.mark.asyncio
async def test_concurrent_prepare_tables_no_race_condition():
"""Verifies that concurrent calls to _prepare_tables wait for table creation.
"""Verifies that concurrent calls to prepare_tables wait for table creation.
Reproduces the race condition from
https://github.com/google/adk-python/issues/4445: when concurrent requests
arrive at startup, _prepare_tables must not return before tables exist.
arrive at startup, prepare_tables must not return before tables exist.
Previously, the early-return guard checked _db_schema_version (set during
schema detection) instead of _tables_created, so a second request could
slip through after schema detection but before table creation finished.
Expand All @@ -1179,7 +1179,7 @@ async def test_concurrent_prepare_tables_no_race_condition():

# Launch several concurrent create_session calls, each with a unique
# app_name to avoid IntegrityError on the shared app_states row.
# Each will call _prepare_tables internally. If the race condition
# Each will call prepare_tables internally. If the race condition
# exists, some of these will fail because the "sessions" table doesn't
# exist yet.
num_concurrent = 5
Expand All @@ -1197,7 +1197,7 @@ async def test_concurrent_prepare_tables_no_race_condition():
for i, result in enumerate(results):
assert not isinstance(result, BaseException), (
f'Concurrent create_session #{i} raised {result!r}; tables were'
' likely not ready due to the _prepare_tables race condition.'
' likely not ready due to the prepare_tables race condition.'
)

# All sessions should be retrievable.
Expand All @@ -1216,17 +1216,17 @@ async def test_concurrent_prepare_tables_no_race_condition():
async def test_prepare_tables_serializes_schema_detection_and_creation():
"""Verifies schema detection and table creation happen atomically under one
lock, so concurrent callers cannot observe a partially-initialized state.
After _prepare_tables completes, both _db_schema_version and _tables_created
After prepare_tables completes, both _db_schema_version and _tables_created
must be set.
"""
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
try:
assert not service._tables_created
assert service._db_schema_version is None

await service._prepare_tables()
await service.prepare_tables()

# Both must be set after a single _prepare_tables call.
# Both must be set after a single prepare_tables call.
assert service._tables_created
assert service._db_schema_version is not None

Expand All @@ -1242,17 +1242,17 @@ async def test_prepare_tables_serializes_schema_detection_and_creation():

@pytest.mark.asyncio
async def test_prepare_tables_idempotent_after_creation():
"""Calling _prepare_tables multiple times is safe and idempotent.
"""Calling prepare_tables multiple times is safe and idempotent.
After tables are created, subsequent calls should return immediately via
the fast path without errors.
"""
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
try:
await service._prepare_tables()
await service.prepare_tables()
assert service._tables_created

# Call again — should be a no-op via the fast path.
await service._prepare_tables()
await service.prepare_tables()
assert service._tables_created

# Service should still work.
Expand All @@ -1264,6 +1264,30 @@ async def test_prepare_tables_idempotent_after_creation():
await service.close()


@pytest.mark.asyncio
async def test_public_prepare_tables_eager_initialization():
"""Calling the public prepare_tables() eagerly initializes tables so that
the first real database operation does not pay the setup cost.
"""
async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as service:
# Before calling prepare_tables, tables are not created.
assert not service._tables_created
assert service._db_schema_version is None

# Eagerly prepare tables via the public API.
await service.prepare_tables()

# Tables should now be ready.
assert service._tables_created
assert service._db_schema_version is not None

# Subsequent operations should work without any additional setup cost.
session = await service.create_session(
app_name='app', user_id='user', session_id='s1'
)
assert session.id == 's1'


@pytest.mark.asyncio
@pytest.mark.parametrize(
'state_delta, expect_app_lock, expect_user_lock',
Expand Down
Loading