2222from typing import Any
2323from typing import AsyncIterator
2424from typing import Optional
25+ from typing import TypeAlias
26+ from typing import TypeVar
2527
2628from sqlalchemy import delete
2729from sqlalchemy import event
2830from sqlalchemy import select
29- from sqlalchemy import text
3031from sqlalchemy .engine import make_url
3132from sqlalchemy .exc import ArgumentError
3233from sqlalchemy .ext .asyncio import async_sessionmaker
5960
6061logger = logging .getLogger ("google_adk." + __name__ )
6162
63+ _SQLITE_DIALECT = "sqlite"
64+ _MARIADB_DIALECT = "mariadb"
65+ _MYSQL_DIALECT = "mysql"
66+ _POSTGRESQL_DIALECT = "postgresql"
67+ # Tuple key order for in-process per-session lock maps:
68+ # (app_name, user_id, session_id).
69+ _SessionLockKey : TypeAlias = tuple [str , str , str ]
70+ _StorageStateT = TypeVar (
71+ "_StorageStateT" ,
72+ StorageAppStateV0 ,
73+ StorageAppStateV1 ,
74+ StorageUserStateV0 ,
75+ StorageUserStateV1 ,
76+ )
77+
78+
79+ async def _select_required_state (
80+ * ,
81+ sql_session : DatabaseSessionFactory ,
82+ state_model : type [_StorageStateT ],
83+ predicates : tuple [Any , ...],
84+ use_row_level_locking : bool ,
85+ missing_message : str ,
86+ ) -> _StorageStateT :
87+ """Returns a state row, raising if the row is missing."""
88+ stmt = select (state_model ).filter (* predicates )
89+ if use_row_level_locking :
90+ stmt = stmt .with_for_update ()
91+ result = await sql_session .execute (stmt )
92+ state_row = result .scalars ().one_or_none ()
93+ if state_row is None :
94+ raise ValueError (missing_message )
95+ return state_row
96+
6297
6398def _set_sqlite_pragma (dbapi_connection , connection_record ):
6499 cursor = dbapi_connection .cursor ()
@@ -107,16 +142,19 @@ def __init__(self, db_url: str, **kwargs: Any):
107142 try :
108143 engine_kwargs = dict (kwargs )
109144 url = make_url (db_url )
110- if url .get_backend_name () == "sqlite" and url .database == ":memory:" :
145+ if (
146+ url .get_backend_name () == _SQLITE_DIALECT
147+ and url .database == ":memory:"
148+ ):
111149 engine_kwargs .setdefault ("poolclass" , StaticPool )
112150 connect_args = dict (engine_kwargs .get ("connect_args" , {}))
113151 connect_args .setdefault ("check_same_thread" , False )
114152 engine_kwargs ["connect_args" ] = connect_args
115- elif url .get_backend_name () != "sqlite" :
153+ elif url .get_backend_name () != _SQLITE_DIALECT :
116154 engine_kwargs .setdefault ("pool_pre_ping" , True )
117155
118156 db_engine = create_async_engine (db_url , ** engine_kwargs )
119- if db_engine .dialect .name == "sqlite" :
157+ if db_engine .dialect .name == _SQLITE_DIALECT :
120158 # Set sqlite pragma to enable foreign keys constraints
121159 event .listen (db_engine .sync_engine , "connect" , _set_sqlite_pragma )
122160
@@ -152,6 +190,11 @@ def __init__(self, db_url: str, **kwargs: Any):
152190 # Lock to ensure thread-safe schema version check
153191 self ._db_schema_lock = asyncio .Lock ()
154192
193+ # Per-session locks used to serialize append_event calls in this process.
194+ self ._session_locks : dict [_SessionLockKey , asyncio .Lock ] = {}
195+ self ._session_lock_ref_count : dict [_SessionLockKey , int ] = {}
196+ self ._session_locks_guard = asyncio .Lock ()
197+
155198 def _get_schema_classes (self ) -> _SchemaClasses :
156199 return _SchemaClasses (self ._db_schema_version )
157200
@@ -172,6 +215,45 @@ async def _rollback_on_exception_session(
172215 await sql_session .rollback ()
173216 raise
174217
218+ def _supports_row_level_locking (self ) -> bool :
219+ return self .db_engine .dialect .name in (
220+ _MARIADB_DIALECT ,
221+ _MYSQL_DIALECT ,
222+ _POSTGRESQL_DIALECT ,
223+ )
224+
225+ @asynccontextmanager
226+ async def _with_session_lock (
227+ self , * , app_name : str , user_id : str , session_id : str
228+ ) -> AsyncIterator [None ]:
229+ """Serializes event appends for the same session within this process."""
230+ # Use one lock per logical ADK session to prevent concurrent append_event
231+ # writes from racing in the same process.
232+ lock_key = (app_name , user_id , session_id )
233+ async with self ._session_locks_guard :
234+ lock = self ._session_locks .get (lock_key )
235+ if lock is None :
236+ lock = asyncio .Lock ()
237+ self ._session_locks [lock_key ] = lock
238+ # Reference counting keeps lock objects alive while they are in use by
239+ # concurrent tasks and allows cleanup once all waiters complete.
240+ self ._session_lock_ref_count [lock_key ] = (
241+ self ._session_lock_ref_count .get (lock_key , 0 ) + 1
242+ )
243+
244+ try :
245+ async with lock :
246+ yield
247+ finally :
248+ async with self ._session_locks_guard :
249+ remaining = self ._session_lock_ref_count .get (lock_key , 0 ) - 1
250+ # Remove lock bookkeeping after the last waiter exits.
251+ if remaining <= 0 and not lock .locked ():
252+ self ._session_lock_ref_count .pop (lock_key , None )
253+ self ._session_locks .pop (lock_key , None )
254+ else :
255+ self ._session_lock_ref_count [lock_key ] = remaining
256+
175257 async def _prepare_tables (self ):
176258 """Ensure database tables are ready for use.
177259
@@ -291,7 +373,7 @@ async def create_session(
291373
292374 # Store the session
293375 now = datetime .now (timezone .utc )
294- is_sqlite = self .db_engine .dialect .name == "sqlite"
376+ is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
295377 if is_sqlite :
296378 now = now .replace (tzinfo = None )
297379
@@ -372,7 +454,7 @@ async def get_session(
372454
373455 # Convert storage session to session
374456 events = [e .to_event () for e in reversed (storage_events )]
375- is_sqlite = self .db_engine .dialect .name == "sqlite"
457+ is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
376458 session = storage_session .to_session (
377459 state = merged_state , events = events , is_sqlite = is_sqlite
378460 )
@@ -418,7 +500,7 @@ async def list_sessions(
418500 user_states_map [storage_user_state .user_id ] = storage_user_state .state
419501
420502 sessions = []
421- is_sqlite = self .db_engine .dialect .name == "sqlite"
503+ is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
422504 for storage_session in results :
423505 session_state = storage_session .state
424506 user_state = user_states_map .get (storage_session .user_id , {})
@@ -456,71 +538,109 @@ async def append_event(self, session: Session, event: Event) -> Event:
456538 # 2. Update session attributes based on event config
457539 # 3. Store event to table
458540 schema = self ._get_schema_classes ()
459- async with self ._rollback_on_exception_session () as sql_session :
460- storage_session = await sql_session .get (
461- schema .StorageSession , (session .app_name , session .user_id , session .id )
462- )
463-
464- # Fetch states from storage
465- storage_app_state = await sql_session .get (
466- schema .StorageAppState , (session .app_name )
467- )
468- storage_user_state = await sql_session .get (
469- schema .StorageUserState , (session .app_name , session .user_id )
470- )
471-
472- is_sqlite = self .db_engine .dialect .name == "sqlite"
473- if (
474- storage_session .get_update_timestamp (is_sqlite )
475- > session .last_update_time
476- ):
477- # Reload the session from storage if it has been updated since it was
478- # loaded.
479- app_state = storage_app_state .state if storage_app_state else {}
480- user_state = storage_user_state .state if storage_user_state else {}
481- session_state = storage_session .state
482- session .state = _merge_state (app_state , user_state , session_state )
483-
484- stmt = (
485- select (schema .StorageEvent )
486- .filter (schema .StorageEvent .app_name == session .app_name )
487- .filter (schema .StorageEvent .session_id == session .id )
488- .filter (schema .StorageEvent .user_id == session .user_id )
489- .order_by (schema .StorageEvent .timestamp .asc ())
541+ is_sqlite = self .db_engine .dialect .name == _SQLITE_DIALECT
542+ use_row_level_locking = self ._supports_row_level_locking ()
543+ async with self ._with_session_lock (
544+ app_name = session .app_name ,
545+ user_id = session .user_id ,
546+ session_id = session .id ,
547+ ):
548+ async with self ._rollback_on_exception_session () as sql_session :
549+ storage_session_stmt = (
550+ select (schema .StorageSession )
551+ .filter (schema .StorageSession .app_name == session .app_name )
552+ .filter (schema .StorageSession .user_id == session .user_id )
553+ .filter (schema .StorageSession .id == session .id )
490554 )
491- result = await sql_session .stream_scalars (stmt )
492- storage_events = [e async for e in result ]
493- session .events = [e .to_event () for e in storage_events ]
494-
495- # Extract state delta
496- if event .actions and event .actions .state_delta :
497- state_deltas = _session_util .extract_state_delta (
498- event .actions .state_delta
555+ if use_row_level_locking :
556+ storage_session_stmt = storage_session_stmt .with_for_update ()
557+ storage_session_result = await sql_session .execute (storage_session_stmt )
558+ storage_session = storage_session_result .scalars ().one_or_none ()
559+ if storage_session is None :
560+ raise ValueError (f"Session { session .id } not found." )
561+
562+ storage_app_state = await _select_required_state (
563+ sql_session = sql_session ,
564+ state_model = schema .StorageAppState ,
565+ predicates = (schema .StorageAppState .app_name == session .app_name ,),
566+ use_row_level_locking = use_row_level_locking ,
567+ missing_message = (
568+ "App state missing for app_name="
569+ f"{ session .app_name !r} . Session state tables should be "
570+ "initialized by create_session."
571+ ),
572+ )
573+ storage_user_state = await _select_required_state (
574+ sql_session = sql_session ,
575+ state_model = schema .StorageUserState ,
576+ predicates = (
577+ schema .StorageUserState .app_name == session .app_name ,
578+ schema .StorageUserState .user_id == session .user_id ,
579+ ),
580+ use_row_level_locking = use_row_level_locking ,
581+ missing_message = (
582+ "User state missing for app_name="
583+ f"{ session .app_name !r} , user_id={ session .user_id !r} . "
584+ "Session state tables should be initialized by "
585+ "create_session."
586+ ),
499587 )
500- app_state_delta = state_deltas ["app" ]
501- user_state_delta = state_deltas ["user" ]
502- session_state_delta = state_deltas ["session" ]
503- # Merge state and update storage
504- if app_state_delta :
505- storage_app_state .state = storage_app_state .state | app_state_delta
506- if user_state_delta :
507- storage_user_state .state = storage_user_state .state | user_state_delta
508- if session_state_delta :
509- storage_session .state = storage_session .state | session_state_delta
510-
511- if is_sqlite :
512- update_time = datetime .fromtimestamp (
513- event .timestamp , timezone .utc
514- ).replace (tzinfo = None )
515- else :
516- update_time = datetime .fromtimestamp (event .timestamp )
517- storage_session .update_time = update_time
518- sql_session .add (schema .StorageEvent .from_event (session , event ))
519-
520- await sql_session .commit ()
521588
522- # Update timestamp with commit time
523- session .last_update_time = storage_session .get_update_timestamp (is_sqlite )
589+ if (
590+ storage_session .get_update_timestamp (is_sqlite )
591+ > session .last_update_time
592+ ):
593+ # Reload the session from storage if it has been updated since it was
594+ # loaded.
595+ app_state = storage_app_state .state
596+ user_state = storage_user_state .state
597+ session_state = storage_session .state
598+ session .state = _merge_state (app_state , user_state , session_state )
599+
600+ stmt = (
601+ select (schema .StorageEvent )
602+ .filter (schema .StorageEvent .app_name == session .app_name )
603+ .filter (schema .StorageEvent .session_id == session .id )
604+ .filter (schema .StorageEvent .user_id == session .user_id )
605+ .order_by (schema .StorageEvent .timestamp .asc ())
606+ )
607+ result = await sql_session .stream_scalars (stmt )
608+ storage_events = [e async for e in result ]
609+ session .events = [e .to_event () for e in storage_events ]
610+
611+ # Extract state delta
612+ if event .actions and event .actions .state_delta :
613+ state_deltas = _session_util .extract_state_delta (
614+ event .actions .state_delta
615+ )
616+ app_state_delta = state_deltas ["app" ]
617+ user_state_delta = state_deltas ["user" ]
618+ session_state_delta = state_deltas ["session" ]
619+ # Merge state and update storage
620+ if app_state_delta :
621+ storage_app_state .state = storage_app_state .state | app_state_delta
622+ if user_state_delta :
623+ storage_user_state .state = (
624+ storage_user_state .state | user_state_delta
625+ )
626+ if session_state_delta :
627+ storage_session .state = storage_session .state | session_state_delta
628+
629+ if is_sqlite :
630+ update_time = datetime .fromtimestamp (
631+ event .timestamp , timezone .utc
632+ ).replace (tzinfo = None )
633+ else :
634+ update_time = datetime .fromtimestamp (event .timestamp )
635+ storage_session .update_time = update_time
636+ sql_session .add (schema .StorageEvent .from_event (session , event ))
637+
638+ await sql_session .commit ()
639+
640+ # Update timestamp with commit time
641+ session .last_update_time = storage_session .get_update_timestamp (
642+ is_sqlite
643+ )
524644
525645 # Also update the in-memory session
526646 await super ().append_event (session = session , event = event )
0 commit comments