Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions apps/worker/celery_task_router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import shared.celery_config as shared_celery_config
from database.engine import get_db_session
from database.engine import create_task_session
from database.models.core import Commit, CompareCommit, Owner, Repository
from shared.celery_router import route_tasks_based_on_user_plan
from shared.plan.constants import DEFAULT_FREE_PLAN
Expand Down Expand Up @@ -137,14 +137,24 @@ def _get_user_plan_from_task(dbsession, task_name: str, task_kwargs: dict) -> st
def route_task(name, args, kwargs, options, task=None, **kw):
"""Function to dynamically route tasks to the proper queue.
Docs: https://docs.celeryq.dev/en/stable/userguide/routing.html#routers

Note: Routing happens BEFORE task execution, so we use a temporary session
for routing lookups to avoid transaction contamination.
"""

user_plan = options.get("user_plan")
ownerid = options.get("ownerid")
if user_plan is None or ownerid is None:
db_session = get_db_session()
if user_plan is None:
user_plan = _get_user_plan_from_task(db_session, name, kwargs)
if ownerid is None:
ownerid = _get_ownerid_from_task(db_session, name, kwargs)
# Use temporary session for routing (read-only, short-lived)
routing_session = create_task_session()
try:
if user_plan is None:
user_plan = _get_user_plan_from_task(routing_session, name, kwargs)
if ownerid is None:
ownerid = _get_ownerid_from_task(routing_session, name, kwargs)
finally:
# Always cleanup routing session (read-only, so rollback is safe)
if routing_session.in_transaction():
routing_session.rollback()
routing_session.close()
return route_tasks_based_on_user_plan(name, user_plan, ownerid)
71 changes: 71 additions & 0 deletions apps/worker/database/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import json
from collections.abc import Callable
from decimal import Decimal

from sqlalchemy import create_engine
Expand Down Expand Up @@ -96,3 +97,73 @@ def _fix_engine(database_url: str) -> str:
session = session_factory.create_session()

get_db_session = session


class TaskSessionManager:
"""Manages per-task database sessions with support for test overrides."""

def __init__(self):
self._session_maker = None
self._test_session_factory: Callable[[], Session] | None = None

def set_test_session_factory(self, factory: Callable[[], Session] | None):
"""Override create_task_session() to return a shared test session."""
self._test_session_factory = factory

def _get_session_maker(self):
if self._session_maker is None:
if session_factory.main_engine is None:
_ = session_factory.create_session()

main_engine = session_factory.main_engine
timeseries_engine = session_factory.timeseries_engine
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this also for TA timeseries? We have 3 engines don't we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per line 80 we only have two. 🤔


if main_engine is None:
raise RuntimeError(
"Cannot create task session: database engine not initialized"
)

if is_timeseries_enabled() and timeseries_engine is not None:
main_engine_ref = main_engine
timeseries_engine_ref = timeseries_engine

class RoutingSession(Session):
def get_bind(self, mapper=None, clause=None, **kwargs):
if mapper is not None and issubclass(
mapper.class_, TimeseriesBaseModel
):
return timeseries_engine_ref
if (
clause is not None
and hasattr(clause, "table")
and clause.table.name.startswith("timeseries_")
):
return timeseries_engine_ref
return main_engine_ref

self._session_maker = sessionmaker(class_=RoutingSession)
else:
self._session_maker = sessionmaker(bind=main_engine)

return self._session_maker

def create_task_session(self):
"""Create a new isolated session for a task. Caller must clean up (rollback/close)."""
if self._test_session_factory is not None:
return self._test_session_factory()

session_maker = self._get_session_maker()
return session_maker()


_task_session_manager = TaskSessionManager()


def set_test_session_factory(factory: Callable[[], Session] | None):
"""Override create_task_session() to return a shared test session."""
_task_session_manager.set_test_session_factory(factory)


def create_task_session():
"""Create a new isolated session for a task. Caller must clean up (rollback/close)."""
return _task_session_manager.create_task_session()
Loading
Loading