Skip to content

Commit f508474

Browse files
GWealecopybara-github
authored andcommitted
fix: Per-session locking and row-level locking in DatabaseSessionService.append_event
This change introduces an in-process `asyncio.Lock` per session to serialize `append_event` calls for the same session ID within a single process. For supported database dialects (MySQL, PostgreSQL, MariaDB), it also uses `SELECT ... FOR UPDATE` to acquire row-level locks on the session, app state, and user state records, preventing race conditions across different processes or database connections. A new test case verifies that concurrent updates to stale session objects correctly merge all state changes. Close #1049 Co-authored-by: George Weale <[email protected]> PiperOrigin-RevId: 867752676
1 parent 32ee07d commit f508474

File tree

2 files changed

+295
-69
lines changed

2 files changed

+295
-69
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 189 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
from typing import Any
2323
from typing import AsyncIterator
2424
from typing import Optional
25+
from typing import TypeAlias
26+
from typing import TypeVar
2527

2628
from sqlalchemy import delete
2729
from sqlalchemy import event
2830
from sqlalchemy import select
29-
from sqlalchemy import text
3031
from sqlalchemy.engine import make_url
3132
from sqlalchemy.exc import ArgumentError
3233
from sqlalchemy.ext.asyncio import async_sessionmaker
@@ -59,6 +60,40 @@
5960

6061
logger = logging.getLogger("google_adk." + __name__)
6162

63+
_SQLITE_DIALECT = "sqlite"
64+
_MARIADB_DIALECT = "mariadb"
65+
_MYSQL_DIALECT = "mysql"
66+
_POSTGRESQL_DIALECT = "postgresql"
67+
# Tuple key order for in-process per-session lock maps:
68+
# (app_name, user_id, session_id).
69+
_SessionLockKey: TypeAlias = tuple[str, str, str]
70+
_StorageStateT = TypeVar(
71+
"_StorageStateT",
72+
StorageAppStateV0,
73+
StorageAppStateV1,
74+
StorageUserStateV0,
75+
StorageUserStateV1,
76+
)
77+
78+
79+
async def _select_required_state(
80+
*,
81+
sql_session: DatabaseSessionFactory,
82+
state_model: type[_StorageStateT],
83+
predicates: tuple[Any, ...],
84+
use_row_level_locking: bool,
85+
missing_message: str,
86+
) -> _StorageStateT:
87+
"""Returns a state row, raising if the row is missing."""
88+
stmt = select(state_model).filter(*predicates)
89+
if use_row_level_locking:
90+
stmt = stmt.with_for_update()
91+
result = await sql_session.execute(stmt)
92+
state_row = result.scalars().one_or_none()
93+
if state_row is None:
94+
raise ValueError(missing_message)
95+
return state_row
96+
6297

6398
def _set_sqlite_pragma(dbapi_connection, connection_record):
6499
cursor = dbapi_connection.cursor()
@@ -107,16 +142,19 @@ def __init__(self, db_url: str, **kwargs: Any):
107142
try:
108143
engine_kwargs = dict(kwargs)
109144
url = make_url(db_url)
110-
if url.get_backend_name() == "sqlite" and url.database == ":memory:":
145+
if (
146+
url.get_backend_name() == _SQLITE_DIALECT
147+
and url.database == ":memory:"
148+
):
111149
engine_kwargs.setdefault("poolclass", StaticPool)
112150
connect_args = dict(engine_kwargs.get("connect_args", {}))
113151
connect_args.setdefault("check_same_thread", False)
114152
engine_kwargs["connect_args"] = connect_args
115-
elif url.get_backend_name() != "sqlite":
153+
elif url.get_backend_name() != _SQLITE_DIALECT:
116154
engine_kwargs.setdefault("pool_pre_ping", True)
117155

118156
db_engine = create_async_engine(db_url, **engine_kwargs)
119-
if db_engine.dialect.name == "sqlite":
157+
if db_engine.dialect.name == _SQLITE_DIALECT:
120158
# Set sqlite pragma to enable foreign keys constraints
121159
event.listen(db_engine.sync_engine, "connect", _set_sqlite_pragma)
122160

@@ -152,6 +190,11 @@ def __init__(self, db_url: str, **kwargs: Any):
152190
# Lock to ensure thread-safe schema version check
153191
self._db_schema_lock = asyncio.Lock()
154192

193+
# Per-session locks used to serialize append_event calls in this process.
194+
self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {}
195+
self._session_lock_ref_count: dict[_SessionLockKey, int] = {}
196+
self._session_locks_guard = asyncio.Lock()
197+
155198
def _get_schema_classes(self) -> _SchemaClasses:
156199
return _SchemaClasses(self._db_schema_version)
157200

@@ -172,6 +215,45 @@ async def _rollback_on_exception_session(
172215
await sql_session.rollback()
173216
raise
174217

218+
def _supports_row_level_locking(self) -> bool:
219+
return self.db_engine.dialect.name in (
220+
_MARIADB_DIALECT,
221+
_MYSQL_DIALECT,
222+
_POSTGRESQL_DIALECT,
223+
)
224+
225+
@asynccontextmanager
226+
async def _with_session_lock(
227+
self, *, app_name: str, user_id: str, session_id: str
228+
) -> AsyncIterator[None]:
229+
"""Serializes event appends for the same session within this process."""
230+
# Use one lock per logical ADK session to prevent concurrent append_event
231+
# writes from racing in the same process.
232+
lock_key = (app_name, user_id, session_id)
233+
async with self._session_locks_guard:
234+
lock = self._session_locks.get(lock_key)
235+
if lock is None:
236+
lock = asyncio.Lock()
237+
self._session_locks[lock_key] = lock
238+
# Reference counting keeps lock objects alive while they are in use by
239+
# concurrent tasks and allows cleanup once all waiters complete.
240+
self._session_lock_ref_count[lock_key] = (
241+
self._session_lock_ref_count.get(lock_key, 0) + 1
242+
)
243+
244+
try:
245+
async with lock:
246+
yield
247+
finally:
248+
async with self._session_locks_guard:
249+
remaining = self._session_lock_ref_count.get(lock_key, 0) - 1
250+
# Remove lock bookkeeping after the last waiter exits.
251+
if remaining <= 0 and not lock.locked():
252+
self._session_lock_ref_count.pop(lock_key, None)
253+
self._session_locks.pop(lock_key, None)
254+
else:
255+
self._session_lock_ref_count[lock_key] = remaining
256+
175257
async def _prepare_tables(self):
176258
"""Ensure database tables are ready for use.
177259
@@ -291,7 +373,7 @@ async def create_session(
291373

292374
# Store the session
293375
now = datetime.now(timezone.utc)
294-
is_sqlite = self.db_engine.dialect.name == "sqlite"
376+
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
295377
if is_sqlite:
296378
now = now.replace(tzinfo=None)
297379

@@ -372,7 +454,7 @@ async def get_session(
372454

373455
# Convert storage session to session
374456
events = [e.to_event() for e in reversed(storage_events)]
375-
is_sqlite = self.db_engine.dialect.name == "sqlite"
457+
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
376458
session = storage_session.to_session(
377459
state=merged_state, events=events, is_sqlite=is_sqlite
378460
)
@@ -418,7 +500,7 @@ async def list_sessions(
418500
user_states_map[storage_user_state.user_id] = storage_user_state.state
419501

420502
sessions = []
421-
is_sqlite = self.db_engine.dialect.name == "sqlite"
503+
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
422504
for storage_session in results:
423505
session_state = storage_session.state
424506
user_state = user_states_map.get(storage_session.user_id, {})
@@ -456,71 +538,109 @@ async def append_event(self, session: Session, event: Event) -> Event:
456538
# 2. Update session attributes based on event config
457539
# 3. Store event to table
458540
schema = self._get_schema_classes()
459-
async with self._rollback_on_exception_session() as sql_session:
460-
storage_session = await sql_session.get(
461-
schema.StorageSession, (session.app_name, session.user_id, session.id)
462-
)
463-
464-
# Fetch states from storage
465-
storage_app_state = await sql_session.get(
466-
schema.StorageAppState, (session.app_name)
467-
)
468-
storage_user_state = await sql_session.get(
469-
schema.StorageUserState, (session.app_name, session.user_id)
470-
)
471-
472-
is_sqlite = self.db_engine.dialect.name == "sqlite"
473-
if (
474-
storage_session.get_update_timestamp(is_sqlite)
475-
> session.last_update_time
476-
):
477-
# Reload the session from storage if it has been updated since it was
478-
# loaded.
479-
app_state = storage_app_state.state if storage_app_state else {}
480-
user_state = storage_user_state.state if storage_user_state else {}
481-
session_state = storage_session.state
482-
session.state = _merge_state(app_state, user_state, session_state)
483-
484-
stmt = (
485-
select(schema.StorageEvent)
486-
.filter(schema.StorageEvent.app_name == session.app_name)
487-
.filter(schema.StorageEvent.session_id == session.id)
488-
.filter(schema.StorageEvent.user_id == session.user_id)
489-
.order_by(schema.StorageEvent.timestamp.asc())
541+
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
542+
use_row_level_locking = self._supports_row_level_locking()
543+
async with self._with_session_lock(
544+
app_name=session.app_name,
545+
user_id=session.user_id,
546+
session_id=session.id,
547+
):
548+
async with self._rollback_on_exception_session() as sql_session:
549+
storage_session_stmt = (
550+
select(schema.StorageSession)
551+
.filter(schema.StorageSession.app_name == session.app_name)
552+
.filter(schema.StorageSession.user_id == session.user_id)
553+
.filter(schema.StorageSession.id == session.id)
490554
)
491-
result = await sql_session.stream_scalars(stmt)
492-
storage_events = [e async for e in result]
493-
session.events = [e.to_event() for e in storage_events]
494-
495-
# Extract state delta
496-
if event.actions and event.actions.state_delta:
497-
state_deltas = _session_util.extract_state_delta(
498-
event.actions.state_delta
555+
if use_row_level_locking:
556+
storage_session_stmt = storage_session_stmt.with_for_update()
557+
storage_session_result = await sql_session.execute(storage_session_stmt)
558+
storage_session = storage_session_result.scalars().one_or_none()
559+
if storage_session is None:
560+
raise ValueError(f"Session {session.id} not found.")
561+
562+
storage_app_state = await _select_required_state(
563+
sql_session=sql_session,
564+
state_model=schema.StorageAppState,
565+
predicates=(schema.StorageAppState.app_name == session.app_name,),
566+
use_row_level_locking=use_row_level_locking,
567+
missing_message=(
568+
"App state missing for app_name="
569+
f"{session.app_name!r}. Session state tables should be "
570+
"initialized by create_session."
571+
),
572+
)
573+
storage_user_state = await _select_required_state(
574+
sql_session=sql_session,
575+
state_model=schema.StorageUserState,
576+
predicates=(
577+
schema.StorageUserState.app_name == session.app_name,
578+
schema.StorageUserState.user_id == session.user_id,
579+
),
580+
use_row_level_locking=use_row_level_locking,
581+
missing_message=(
582+
"User state missing for app_name="
583+
f"{session.app_name!r}, user_id={session.user_id!r}. "
584+
"Session state tables should be initialized by "
585+
"create_session."
586+
),
499587
)
500-
app_state_delta = state_deltas["app"]
501-
user_state_delta = state_deltas["user"]
502-
session_state_delta = state_deltas["session"]
503-
# Merge state and update storage
504-
if app_state_delta:
505-
storage_app_state.state = storage_app_state.state | app_state_delta
506-
if user_state_delta:
507-
storage_user_state.state = storage_user_state.state | user_state_delta
508-
if session_state_delta:
509-
storage_session.state = storage_session.state | session_state_delta
510-
511-
if is_sqlite:
512-
update_time = datetime.fromtimestamp(
513-
event.timestamp, timezone.utc
514-
).replace(tzinfo=None)
515-
else:
516-
update_time = datetime.fromtimestamp(event.timestamp)
517-
storage_session.update_time = update_time
518-
sql_session.add(schema.StorageEvent.from_event(session, event))
519-
520-
await sql_session.commit()
521588

522-
# Update timestamp with commit time
523-
session.last_update_time = storage_session.get_update_timestamp(is_sqlite)
589+
if (
590+
storage_session.get_update_timestamp(is_sqlite)
591+
> session.last_update_time
592+
):
593+
# Reload the session from storage if it has been updated since it was
594+
# loaded.
595+
app_state = storage_app_state.state
596+
user_state = storage_user_state.state
597+
session_state = storage_session.state
598+
session.state = _merge_state(app_state, user_state, session_state)
599+
600+
stmt = (
601+
select(schema.StorageEvent)
602+
.filter(schema.StorageEvent.app_name == session.app_name)
603+
.filter(schema.StorageEvent.session_id == session.id)
604+
.filter(schema.StorageEvent.user_id == session.user_id)
605+
.order_by(schema.StorageEvent.timestamp.asc())
606+
)
607+
result = await sql_session.stream_scalars(stmt)
608+
storage_events = [e async for e in result]
609+
session.events = [e.to_event() for e in storage_events]
610+
611+
# Extract state delta
612+
if event.actions and event.actions.state_delta:
613+
state_deltas = _session_util.extract_state_delta(
614+
event.actions.state_delta
615+
)
616+
app_state_delta = state_deltas["app"]
617+
user_state_delta = state_deltas["user"]
618+
session_state_delta = state_deltas["session"]
619+
# Merge state and update storage
620+
if app_state_delta:
621+
storage_app_state.state = storage_app_state.state | app_state_delta
622+
if user_state_delta:
623+
storage_user_state.state = (
624+
storage_user_state.state | user_state_delta
625+
)
626+
if session_state_delta:
627+
storage_session.state = storage_session.state | session_state_delta
628+
629+
if is_sqlite:
630+
update_time = datetime.fromtimestamp(
631+
event.timestamp, timezone.utc
632+
).replace(tzinfo=None)
633+
else:
634+
update_time = datetime.fromtimestamp(event.timestamp)
635+
storage_session.update_time = update_time
636+
sql_session.add(schema.StorageEvent.from_event(session, event))
637+
638+
await sql_session.commit()
639+
640+
# Update timestamp with commit time
641+
session.last_update_time = storage_session.get_update_timestamp(
642+
is_sqlite
643+
)
524644

525645
# Also update the in-memory session
526646
await super().append_event(session=session, event=event)

0 commit comments

Comments
 (0)