diff --git a/apps/worker/celery_task_router.py b/apps/worker/celery_task_router.py
index 1c99f579bc..106dbfa59d 100644
--- a/apps/worker/celery_task_router.py
+++ b/apps/worker/celery_task_router.py
@@ -1,5 +1,5 @@
 import shared.celery_config as shared_celery_config
-from database.engine import get_db_session
+from database.engine import create_task_session
 from database.models.core import Commit, CompareCommit, Owner, Repository
 from shared.celery_router import route_tasks_based_on_user_plan
 from shared.plan.constants import DEFAULT_FREE_PLAN
@@ -137,14 +137,24 @@ def _get_user_plan_from_task(dbsession, task_name: str, task_kwargs: dict) -> st
 def route_task(name, args, kwargs, options, task=None, **kw):
     """Function to dynamically route tasks to the proper queue.
     Docs: https://docs.celeryq.dev/en/stable/userguide/routing.html#routers
+
+    Note: Routing happens BEFORE task execution, so we use a temporary session
+    for routing lookups to avoid transaction contamination.
     """
     user_plan = options.get("user_plan")
     ownerid = options.get("ownerid")
     if user_plan is None or ownerid is None:
-        db_session = get_db_session()
-        if user_plan is None:
-            user_plan = _get_user_plan_from_task(db_session, name, kwargs)
-        if ownerid is None:
-            ownerid = _get_ownerid_from_task(db_session, name, kwargs)
+        # Use temporary session for routing (read-only, short-lived)
+        routing_session = create_task_session()
+        try:
+            if user_plan is None:
+                user_plan = _get_user_plan_from_task(routing_session, name, kwargs)
+            if ownerid is None:
+                ownerid = _get_ownerid_from_task(routing_session, name, kwargs)
+        finally:
+            # Always clean up the routing session (read-only, so rollback is safe)
+            if routing_session.in_transaction():
+                routing_session.rollback()
+            routing_session.close()
 
     return route_tasks_based_on_user_plan(name, user_plan, ownerid)
diff --git a/apps/worker/database/engine.py b/apps/worker/database/engine.py
index 3091da25e2..e1687f4f9e 100644
--- a/apps/worker/database/engine.py
+++ b/apps/worker/database/engine.py
@@ -1,5 +1,6 @@
 import dataclasses
 import json
+from collections.abc import Callable
 from decimal import Decimal
 
 from sqlalchemy import create_engine
@@ -96,3 +97,73 @@ def _fix_engine(database_url: str) -> str:
 session = session_factory.create_session()
 
 get_db_session = session
+
+
+class TaskSessionManager:
+    """Manages per-task database sessions with support for test overrides."""
+
+    def __init__(self):
+        self._session_maker = None
+        self._test_session_factory: Callable[[], Session] | None = None
+
+    def set_test_session_factory(self, factory: Callable[[], Session] | None):
+        """Override create_task_session() to return a shared test session."""
+        self._test_session_factory = factory
+
+    def _get_session_maker(self):
+        if self._session_maker is None:
+            if session_factory.main_engine is None:
+                _ = session_factory.create_session()
+
+            main_engine = session_factory.main_engine
+            timeseries_engine = session_factory.timeseries_engine
+
+            if main_engine is None:
+                raise RuntimeError(
+                    "Cannot create task session: database engine not initialized"
+                )
+
+            if is_timeseries_enabled() and timeseries_engine is not None:
+                main_engine_ref = main_engine
+                timeseries_engine_ref = timeseries_engine
+
+                class RoutingSession(Session):
+                    def get_bind(self, mapper=None, clause=None, **kwargs):
+                        if mapper is not None and issubclass(
+                            mapper.class_, TimeseriesBaseModel
+                        ):
+                            return timeseries_engine_ref
+                        if (
+                            clause is not None
+                            and hasattr(clause, "table")
+                            and clause.table.name.startswith("timeseries_")
+                        ):
+                            return timeseries_engine_ref
+                        return main_engine_ref
+
+                self._session_maker = sessionmaker(class_=RoutingSession)
+            else:
+                self._session_maker = sessionmaker(bind=main_engine)
+
+        return self._session_maker
+
+    def create_task_session(self):
+        """Create a new isolated session for a task. Caller must clean up (rollback/close)."""
+        if self._test_session_factory is not None:
+            return self._test_session_factory()
+
+        session_maker = self._get_session_maker()
+        return session_maker()
+
+
+_task_session_manager = TaskSessionManager()
+
+
+def set_test_session_factory(factory: Callable[[], Session] | None):
+    """Override create_task_session() to return a shared test session."""
+    _task_session_manager.set_test_session_factory(factory)
+
+
+def create_task_session():
+    """Create a new isolated session for a task. Caller must clean up (rollback/close)."""
+    return _task_session_manager.create_task_session()
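# --- Editor's example (sketch, not part of the patch) ------------------------
# The contract implied by the engine changes above: create_task_session() hands
# out an isolated session and leaves cleanup entirely to the caller, while
# set_test_session_factory() lets tests substitute one shared session. The
# lookup body below is a placeholder, and the `dbsession` parameter is a
# hypothetical test fixture; only the session handling is the point.

from database.engine import create_task_session, set_test_session_factory


def routing_lookup() -> None:
    session = create_task_session()
    try:
        ...  # read-only queries against `session`
    finally:
        # Mirror the cleanup used by route_task()/apply_async(): the work is
        # read-only, so rolling back anything pending is always safe.
        if session.in_transaction():
            session.rollback()
        session.close()


def with_test_override(dbsession) -> None:
    # In tests, every create_task_session() call returns the shared fixture
    # session until the override is cleared again.
    set_test_session_factory(lambda: dbsession)
    try:
        assert create_task_session() is dbsession
    finally:
        set_test_session_factory(None)  # restore real per-task sessions
# ------------------------------------------------------------------------------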
diff --git a/apps/worker/database/tests/unit/test_engine.py b/apps/worker/database/tests/unit/test_engine.py
index 50760a1b40..a8caee5c2f 100644
--- a/apps/worker/database/tests/unit/test_engine.py
+++ b/apps/worker/database/tests/unit/test_engine.py
@@ -1,12 +1,193 @@
+import dataclasses
+from decimal import Decimal
+from unittest.mock import Mock, patch
+
+import pytest
+from sqlalchemy import create_engine
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.orm import Session
 from sqlalchemy_utils import get_mapper
 
-from database.engine import SessionFactory
+from database.base import Base
+from database.engine import (
+    DatabaseEncoder,
+    SessionFactory,
+    TaskSessionManager,
+    _fix_engine,
+    create_all,
+    create_task_session,
+    json_dumps,
+    set_test_session_factory,
+)
 from database.models import Commit
 from database.models.timeseries import Measurement
 
 
-class TestDatabaseEngine:
+class TestFixEngine:
+    def test_fix_engine_converts_postgres_to_postgresql(self):
+        """Test that postgres:// URLs are converted to postgresql://"""
+        assert (
+            _fix_engine("postgres://user:pass@host:5432/db")
+            == "postgresql://user:pass@host:5432/db"
+        )
+
+    def test_fix_engine_leaves_postgresql_unchanged(self):
+        """Test that postgresql:// URLs are left unchanged"""
+        url = "postgresql://user:pass@host:5432/db"
+        assert _fix_engine(url) == url
+
+    def test_fix_engine_handles_multiple_occurrences(self):
+        """Test that all occurrences are replaced (str.replace replaces all)"""
+        # This shouldn't happen in practice, but test the behavior
+        # Note: str.replace() replaces ALL occurrences, not just the first
+        url = "postgres://user:pass@host:5432/postgres://db"
+        result = _fix_engine(url)
+        assert result == "postgresql://user:pass@host:5432/postgresql://db"
+        assert "postgres://" not in result
+
+
+class TestDatabaseEncoder:
+    def test_database_encoder_dataclass(self):
+        """Test that dataclasses are serialized using astuple"""
+
+        @dataclasses.dataclass
+        class TestDataClass:
+            x: int
+            y: str
+
+        obj = TestDataClass(x=1, y="test")
+        encoder = DatabaseEncoder()
+        result = encoder.default(obj)
+        assert result == (1, "test")
+
+    def test_database_encoder_decimal(self):
+        """Test that Decimal objects are converted to strings"""
+        encoder = DatabaseEncoder()
+        result = encoder.default(Decimal("123.45"))
+        assert result == "123.45"
+
+    def test_database_encoder_fallback_to_parent(self):
+        """Test that other types fall back to parent class"""
+        encoder = DatabaseEncoder()
+        # Should raise TypeError for unsupported types (parent behavior)
+        with pytest.raises(TypeError):
+            encoder.default(object())
+
+
+class TestJsonDumps:
+    def test_json_dumps_with_dataclass(self):
+        """Test json_dumps with a dataclass"""
+
+        @dataclasses.dataclass
+        class TestDataClass:
+            x: int
+            y: str
+
+        obj = TestDataClass(x=1, y="test")
+        data = {"key": obj}
+        result = json_dumps(data)
+        assert '"key": [1, "test"]' in result
+
+    def test_json_dumps_with_decimal(self):
+        """Test json_dumps with a Decimal"""
+        data = {"value": Decimal("123.45")}
+        result = json_dumps(data)
+        assert '"value": "123.45"' in result
+
+    def test_json_dumps_with_regular_types(self):
+        """Test json_dumps with regular JSON-serializable types"""
+        data = {"string": "test", "number": 42, "bool": True}
+        result = json_dumps(data)
+        assert '"string": "test"' in result
+        assert '"number": 42' in result
+        assert '"bool": true' in result
+
+
+class TestCreateAll:
+    def test_create_all_calls_base_metadata_create_all(self, mocker):
+        """Test that create_all calls Base.metadata.create_all"""
+        mock_engine = Mock()
+        mock_create_all = mocker.patch.object(Base.metadata, "create_all")
+
+        create_all(mock_engine)
+
+        mock_create_all.assert_called_once_with(mock_engine)
+
+
+class TestSessionFactory:
+    def test_session_factory_init_calls_fix_engine(
+        self, mocker, sqlalchemy_connect_url
+    ):
+        """Test that __init__ calls _fix_engine on both URLs"""
+        mock_fix_engine = mocker.patch(
+            "database.engine._fix_engine", side_effect=lambda x: x
+        )
+
+        SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=sqlalchemy_connect_url,
+        )
+
+        assert mock_fix_engine.call_count == 2
+
+    def test_session_factory_init_with_none_timeseries_url(
+        self, sqlalchemy_connect_url
+    ):
+        """Test SessionFactory initialization with None timeseries_database_url raises AttributeError"""
+        # _fix_engine(None) will raise AttributeError since None doesn't have replace method
+        with pytest.raises(AttributeError):
+            SessionFactory(
+                database_url=sqlalchemy_connect_url,
+                timeseries_database_url=None,
+            )
+
+    def test_session_factory_create_session_sets_json_serializer(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test that create_session sets json_serializer correctly"""
+        mock_create_engine = mocker.patch("database.engine.create_engine")
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        mocker.patch("database.engine.scoped_session")
+        # Mock _fix_engine to handle None (though it won't be used when timeseries is disabled)
+        mocker.patch(
+            "database.engine._fix_engine",
+            side_effect=lambda x: x if x is not None else None,
+        )
+
+        factory = SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=sqlalchemy_connect_url,  # Use valid URL instead of None
+        )
+        factory.create_session()
+
+        # Check that create_engine was called with json_serializer
+        call_args = mock_create_engine.call_args
+        assert "json_serializer" in call_args.kwargs
+        assert call_args.kwargs["json_serializer"] == json_dumps
+        assert call_args.kwargs["pool_pre_ping"] is True
+
+    def test_session_factory_create_session_timeseries_url_none(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test create_session when timeseries_database_url is None and timeseries is enabled"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=True)
+        # Mock _fix_engine to return None for None input (simulating the error case)
+        mocker.patch(
+            "database.engine._fix_engine",
+            side_effect=lambda x: x if x is not None else None,
+        )
+
+        factory = SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=None,
+        )
+        # create_engine will fail with None URL - SQLAlchemy raises AttributeError
+        # when trying to instantiate plugins with None
+        with pytest.raises(
+            (TypeError, ValueError, AttributeError)
+        ):  # create_engine will fail with None URL
+            factory.create_session()
+
     def test_session_get_bind_timeseries_disabled(self, sqlalchemy_connect_url, mocker):
         mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
 
@@ -56,3 +237,251 @@ def test_session_get_bind_timeseries_enabled(self, sqlalchemy_connect_url, mocke
         clause = insert(Measurement.__table__)
         engine = session.get_bind(clause=clause)
         assert engine == session_factory.timeseries_engine
+
+    def test_routing_session_get_bind_no_mapper_no_clause(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test RoutingSession.get_bind when both mapper and clause are None"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=True)
+
+        session_factory = SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=sqlalchemy_connect_url,
+        )
+        session = session_factory.create_session()
+
+        # When both are None, should return main_engine
+        engine = session.get_bind()
+        assert engine == session_factory.main_engine
+
+    def test_routing_session_get_bind_clause_without_table_attribute(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test RoutingSession.get_bind when clause doesn't have table attribute"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=True)
+
+        session_factory = SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=sqlalchemy_connect_url,
+        )
+        session = session_factory.create_session()
+
+        # Create a mock clause without table attribute
+        mock_clause = Mock()
+        del mock_clause.table  # Ensure it doesn't have table attribute
+
+        engine = session.get_bind(clause=mock_clause)
+        assert engine == session_factory.main_engine
+
+    def test_routing_session_get_bind_clause_table_not_timeseries(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test RoutingSession.get_bind when clause.table.name doesn't start with timeseries_"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=True)
+
+        session_factory = SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=sqlalchemy_connect_url,
+        )
+        session = session_factory.create_session()
+
+        # Create a mock clause with table that doesn't start with timeseries_
+        mock_clause = Mock()
+        mock_table = Mock()
+        mock_table.name = "regular_table"
+        mock_clause.table = mock_table
+
+        engine = session.get_bind(clause=mock_clause)
+        assert engine == session_factory.main_engine
+
+    def test_routing_session_get_bind_mapper_not_timeseries_subclass(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test RoutingSession.get_bind when mapper is not a TimeseriesBaseModel subclass"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=True)
+
+        session_factory = SessionFactory(
+            database_url=sqlalchemy_connect_url,
+            timeseries_database_url=sqlalchemy_connect_url,
+        )
+        session = session_factory.create_session()
+
+        # Commit is not a TimeseriesBaseModel subclass
+        engine = session.get_bind(mapper=get_mapper(Commit))
+        assert engine == session_factory.main_engine
+
+
+class TestTaskSessionManager:
+    def test_task_session_manager_init(self):
+        """Test TaskSessionManager initialization"""
+        manager = TaskSessionManager()
+        assert manager._session_maker is None
+        assert manager._test_session_factory is None
+
+    def test_set_test_session_factory(self):
+        """Test setting a test session factory"""
+        manager = TaskSessionManager()
+        mock_factory = Mock()
+
+        manager.set_test_session_factory(mock_factory)
+
+        assert manager._test_session_factory == mock_factory
+
+    def test_set_test_session_factory_none(self):
+        """Test clearing test session factory by setting to None"""
+        manager = TaskSessionManager()
+        manager.set_test_session_factory(Mock())
+
+        manager.set_test_session_factory(None)
+
+        assert manager._test_session_factory is None
+
+    def test_create_task_session_with_test_factory(self):
+        """Test create_task_session uses test factory when set"""
+        manager = TaskSessionManager()
+        mock_session = Mock()
+        mock_factory = Mock(return_value=mock_session)
+
+        manager.set_test_session_factory(mock_factory)
+        result = manager.create_task_session()
+
+        assert result == mock_session
+        mock_factory.assert_called_once()
+
+    def test_get_session_maker_creates_session_if_needed(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test _get_session_maker creates session if main_engine is None"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        # Mock the global session_factory
+        with patch("database.engine.session_factory") as mock_session_factory:
+            mock_session_factory.main_engine = None
+            mock_session_factory.timeseries_engine = None
+            mock_session_factory.create_session = Mock()
+            mock_engine = create_engine(sqlalchemy_connect_url)
+            mock_session_factory.create_session.return_value = None
+
+            # After create_session is called, set the engines
+            def set_engines():
+                mock_session_factory.main_engine = mock_engine
+
+            mock_session_factory.create_session.side_effect = set_engines
+
+            manager = TaskSessionManager()
+            session_maker = manager._get_session_maker()
+
+            assert session_maker is not None
+            mock_session_factory.create_session.assert_called_once()
+
+    def test_get_session_maker_raises_when_engine_none_after_create(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test _get_session_maker raises RuntimeError when main_engine is None after create_session"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        with patch("database.engine.session_factory") as mock_session_factory:
+            mock_session_factory.main_engine = None
+            mock_session_factory.timeseries_engine = None
+            mock_session_factory.create_session = Mock()
+            # Even after create_session, main_engine remains None
+            mock_session_factory.create_session.return_value = None
+
+            manager = TaskSessionManager()
+            with pytest.raises(
+                RuntimeError,
+                match="Cannot create task session: database engine not initialized",
+            ):
+                manager._get_session_maker()
+
+    def test_get_session_maker_with_timeseries_enabled(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test _get_session_maker when timeseries is enabled"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=True)
+        with patch("database.engine.session_factory") as mock_session_factory:
+            main_engine = create_engine(sqlalchemy_connect_url)
+            timeseries_engine = create_engine(sqlalchemy_connect_url)
+            mock_session_factory.main_engine = main_engine
+            mock_session_factory.timeseries_engine = timeseries_engine
+
+            manager = TaskSessionManager()
+            session_maker = manager._get_session_maker()
+
+            assert session_maker is not None
+            # Create a session and verify it's a RoutingSession
+            session = session_maker()
+            assert isinstance(session, Session)
+
+    def test_get_session_maker_with_timeseries_disabled(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test _get_session_maker when timeseries is disabled"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        with patch("database.engine.session_factory") as mock_session_factory:
+            main_engine = create_engine(sqlalchemy_connect_url)
+            mock_session_factory.main_engine = main_engine
+            mock_session_factory.timeseries_engine = None
+
+            manager = TaskSessionManager()
+            session_maker = manager._get_session_maker()
+
+            assert session_maker is not None
+            session = session_maker()
+            assert isinstance(session, Session)
+
+    def test_get_session_maker_caches_result(self, sqlalchemy_connect_url, mocker):
+        """Test that _get_session_maker caches the session maker"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        with patch("database.engine.session_factory") as mock_session_factory:
+            main_engine = create_engine(sqlalchemy_connect_url)
+            mock_session_factory.main_engine = main_engine
+            mock_session_factory.timeseries_engine = None
+
+            manager = TaskSessionManager()
+            session_maker1 = manager._get_session_maker()
+            session_maker2 = manager._get_session_maker()
+
+            assert session_maker1 is session_maker2
+
+    def test_create_task_session_without_test_factory(
+        self, sqlalchemy_connect_url, mocker
+    ):
+        """Test create_task_session creates session via session maker when no test factory"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        with patch("database.engine.session_factory") as mock_session_factory:
+            main_engine = create_engine(sqlalchemy_connect_url)
+            mock_session_factory.main_engine = main_engine
+            mock_session_factory.timeseries_engine = None
+
+            manager = TaskSessionManager()
+            session = manager.create_task_session()
+
+            assert isinstance(session, Session)
+
+
+class TestModuleLevelFunctions:
+    def test_set_test_session_factory_module_level(self):
+        """Test the module-level set_test_session_factory function"""
+        mock_factory = Mock()
+        set_test_session_factory(mock_factory)
+
+        # Verify it was set by checking create_task_session uses it
+        mock_session = Mock()
+        mock_factory.return_value = mock_session
+        result = create_task_session()
+
+        assert result == mock_session
+
+        # Clean up
+        set_test_session_factory(None)
+
+    def test_create_task_session_module_level(self, sqlalchemy_connect_url, mocker):
+        """Test the module-level create_task_session function"""
+        mocker.patch("database.engine.is_timeseries_enabled", return_value=False)
+        with patch("database.engine.session_factory") as mock_session_factory:
+            main_engine = create_engine(sqlalchemy_connect_url)
+            mock_session_factory.main_engine = main_engine
+            mock_session_factory.timeseries_engine = None
+
+            session = create_task_session()
+
+            assert isinstance(session, Session)
diff --git a/apps/worker/helpers/log_context.py b/apps/worker/helpers/log_context.py
index 673dcb5b8c..beb53e872f 100644
--- a/apps/worker/helpers/log_context.py
+++ b/apps/worker/helpers/log_context.py
@@ -80,22 +80,24 @@ def populate_from_sqlalchemy(self, dbsession):
                     .filter(Repository.repoid == self.repo_id)
                 )
 
-                (
-                    self.repo_name,
-                    self.owner_id,
-                    self.owner_username,
-                    self.owner_service,
-                    self.owner_plan,
-                ) = query.first()
+                result = query.first()
+                if result:
+                    (
+                        self.repo_name,
+                        self.owner_id,
+                        self.owner_username,
+                        self.owner_service,
+                        self.owner_plan,
+                    ) = result
 
             elif self.owner_id:
                 query = dbsession.query(
                     Owner.username, Owner.service, Owner.plan
                 ).filter(Owner.ownerid == self.owner_id)
-                (self.owner_username, self.owner_service, self.owner_plan) = (
-                    query.first()
-                )
+                result = query.first()
+                if result:
+                    (self.owner_username, self.owner_service, self.owner_plan) = result
 
         except Exception:
             log.exception("Failed to populate log context")
diff --git a/apps/worker/services/bundle_analysis/report.py b/apps/worker/services/bundle_analysis/report.py
index 05283dad77..ec0590d979 100644
--- a/apps/worker/services/bundle_analysis/report.py
+++ b/apps/worker/services/bundle_analysis/report.py
@@ -7,6 +7,7 @@
 import sentry_sdk
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session as DbSession
 
 from database.enums import ReportType
 from database.models.core import Commit
@@ -68,12 +69,15 @@ def as_dict(self):
             "error": self.error.as_dict() if self.error else None,
         }
 
-    def update_upload(self, carriedforward: bool | None = False) -> None:
+    def update_upload(
+        self, carriedforward: bool | None = False, db_session: DbSession | None = None
+    ) -> None:
         """
         Updates this result's `Upload` record with information from this result.
         """
-        db_session = self.upload.get_db_session()
+        if db_session is None:
+            db_session = self.upload.get_db_session()
 
         if self.error:
             self.commit.state = "error"
@@ -106,8 +110,11 @@ def update_upload(self, carriedforward: bool | None = False) -> None:
 
 
 class BundleAnalysisReportService(BaseReportService):
-    def initialize_and_save_report(self, commit: Commit) -> CommitReport:
-        db_session = commit.get_db_session()
+    def initialize_and_save_report(
+        self, commit: Commit, db_session: DbSession | None = None
+    ) -> CommitReport:
+        if db_session is None:
+            db_session = commit.get_db_session()
 
         commit_report = (
             db_session.query(CommitReport)
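# --- Editor's example (sketch, not part of the patch) ------------------------
# The recurring shape introduced across these service files: helpers accept an
# optional `db_session` and fall back to the ORM object's own session only when
# the caller passes nothing. `helper` and `obj` below are hypothetical
# stand-ins for the real service methods and models.

from sqlalchemy.orm import Session


def helper(obj, db_session: Session | None = None) -> None:
    if db_session is None:
        # Legacy fallback: use whatever session the object happens to be bound to.
        db_session = obj.get_db_session()
    # From here on, every read and write goes through one explicit session,
    # so work done inside a task stays in that task's own transaction.
    db_session.add(obj)
    db_session.flush()
# ------------------------------------------------------------------------------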
""" - db_session = commit_report.get_db_session() + if db_session is None: + db_session = commit_report.get_db_session() name = arguments.get("name") upload = Upload( report_id=commit_report.id_, @@ -208,7 +212,9 @@ def has_initialized_report(self, commit: Commit) -> bool: ) @sentry_sdk.trace - def initialize_and_save_report(self, commit: Commit) -> CommitReport: + def initialize_and_save_report( + self, commit: Commit, db_session: DbSession | None = None + ) -> CommitReport: """ Initializes the commit report @@ -223,11 +229,13 @@ def initialize_and_save_report(self, commit: Commit) -> CommitReport: Args: commit (Commit): The commit we want to initialize + db_session: Optional database session to use Returns: CommitReport: The CommitReport for that commit """ - db_session = commit.get_db_session() + if db_session is None: + db_session = commit.get_db_session() current_report_row = ( db_session.query(CommitReport) .filter_by(commit_id=commit.id_, code=None) @@ -263,7 +271,7 @@ def initialize_and_save_report(self, commit: Commit) -> CommitReport: ) # This case means the report exists in our system, it was just not saved # yet into the new models therefore it needs backfilling - self.save_full_report(commit, actual_report) + self.save_full_report(commit, actual_report, db_session=db_session) if not self.has_initialized_report(commit): report = self.create_new_report_for_commit(commit) @@ -277,18 +285,21 @@ def initialize_and_save_report(self, commit: Commit) -> CommitReport: "files_count": report.totals.files, }, ) - self.save_full_report(commit, report) + self.save_full_report(commit, report, db_session=db_session) return current_report_row - def _attach_flags_to_upload(self, upload: Upload, flag_names: list[str]): + def _attach_flags_to_upload( + self, upload: Upload, flag_names: list[str], db_session: DbSession | None = None + ): """ Internal function that manages creating the proper `RepositoryFlag`s, and attach them to the `Upload` """ all_flags = [] - db_session = upload.get_db_session() + if db_session is None: + db_session = upload.get_db_session() repoid = upload.report.commit.repoid flag_dict = self.fetch_repo_flags(db_session, repoid) @@ -711,7 +722,9 @@ def build_report_from_raw_content( return result @sentry_sdk.trace - def save_report(self, commit: Commit, report: Report): + def save_report( + self, commit: Commit, report: Report, db_session: DbSession | None = None + ): archive_service = ArchiveService(commit.repository) report_json, chunks, _totals = report.serialize() @@ -742,16 +755,19 @@ def save_report(self, commit: Commit, report: Report): ) # `report_json` is an `ArchiveField`, so this will trigger an upload # FIXME: we do an unnecessary `loads` roundtrip because of this abstraction, - # and we should just save the `report_json` to archive storage directly instead. 
@@ -711,7 +722,9 @@ def build_report_from_raw_content(
         return result
 
     @sentry_sdk.trace
-    def save_report(self, commit: Commit, report: Report):
+    def save_report(
+        self, commit: Commit, report: Report, db_session: DbSession | None = None
+    ):
         archive_service = ArchiveService(commit.repository)
 
         report_json, chunks, _totals = report.serialize()
@@ -742,16 +755,19 @@ def save_report(self, commit: Commit, report: Report):
         )
         # `report_json` is an `ArchiveField`, so this will trigger an upload
         # FIXME: we do an unnecessary `loads` roundtrip because of this abstraction,
-        # and we should just save the `report_json` to archive storage directly instead.
         commit.report_json = orjson.loads(report_json)
-        # `report` is an accessor which implicitly queries `CommitReport`
         if commit_report := commit.report:
-            db_session = commit.get_db_session()
+            if db_session is None:
+                db_session = commit.get_db_session()
 
-            report_totals = commit_report.totals
+            report_totals = (
+                db_session.query(ReportLevelTotals)
+                .filter_by(report_id=commit_report.id_)
+                .first()
+            )
             if report_totals is None:
-                report_totals = ReportLevelTotals(report_id=commit_report.id)
+                report_totals = ReportLevelTotals(report_id=commit_report.id_)
                 db_session.add(report_totals)
 
             rounding: str = read_yaml_field(
@@ -779,7 +795,9 @@ def save_report(self, commit: Commit, report: Report):
         return {"url": chunks_url}
 
     @sentry_sdk.trace
-    def save_full_report(self, commit: Commit, report: Report) -> dict:
+    def save_full_report(
+        self, commit: Commit, report: Report, db_session: DbSession | None = None
+    ) -> dict:
         """
         Saves the report (into database and storage) AND takes care of backfilling its sessions
         like they were never in the database (useful for backfilling and carryforward cases)
@@ -790,8 +808,9 @@ def save_full_report(self, commit: Commit, report: Report) -> dict:
         precision: int = read_yaml_field(
             self.current_yaml, ("coverage", "precision"), 2
         )
-        res = self.save_report(commit, report)
-        db_session = commit.get_db_session()
+        res = self.save_report(commit, report, db_session=db_session)
+        if db_session is None:
+            db_session = commit.get_db_session()
         for sess_id, session in report.sessions.items():
             upload = Upload(
                 build_code=session.build,
@@ -814,7 +833,9 @@ def save_full_report(self, commit: Commit, report: Report) -> dict:
             )
             db_session.add(upload)
             db_session.flush()
-            self._attach_flags_to_upload(upload, session.flags if session.flags else [])
+            self._attach_flags_to_upload(
+                upload, session.flags if session.flags else [], db_session=db_session
+            )
             if session.totals is not None:
                 upload_totals = UploadLevelTotals(upload_id=upload.id_)
                 db_session.add(upload_totals)
diff --git a/apps/worker/services/test_results.py b/apps/worker/services/test_results.py
index 37c506cf99..c5cd786658 100644
--- a/apps/worker/services/test_results.py
+++ b/apps/worker/services/test_results.py
@@ -7,6 +7,7 @@
 from urllib.parse import quote_plus
 
 import sentry_sdk
+from sqlalchemy.orm import Session as DbSession
 
 from database.enums import ReportType
 from database.models import (
@@ -47,8 +48,11 @@ def __init__(self, current_yaml: UserYaml):
         super().__init__(current_yaml)
         self.flag_dict = None
 
-    def initialize_and_save_report(self, commit: Commit) -> CommitReport:
-        db_session = commit.get_db_session()
+    def initialize_and_save_report(
+        self, commit: Commit, db_session: DbSession | None = None
+    ) -> CommitReport:
+        if db_session is None:
+            db_session = commit.get_db_session()
         current_report_row = (
             db_session.query(CommitReport)
             .filter_by(
@@ -73,13 +77,23 @@ def initialize_and_save_report(self, commit: Commit) -> CommitReport:
     # support flags in test results
     def create_report_upload(
-        self, arguments: UploadArguments, commit_report: CommitReport
+        self,
+        arguments: UploadArguments,
+        commit_report: CommitReport,
+        db_session: DbSession | None = None,
     ) -> Upload:
-        upload = super().create_report_upload(arguments, commit_report)
-        self._attach_flags_to_upload(upload, arguments["flags"])
+        upload = super().create_report_upload(
+            arguments, commit_report, db_session=db_session
+        )
+        self._attach_flags_to_upload(upload, arguments["flags"], db_session=db_session)
         return upload
 
-    def _attach_flags_to_upload(self, upload: Upload, flag_names: Sequence[str]):
+    def _attach_flags_to_upload(
+        self,
+        upload: Upload,
+        flag_names: Sequence[str],
+        db_session: DbSession | None = None,
+    ):
         """Internal function that manages creating the proper `RepositoryFlag`s
            and attach the sessions to them
 
         Args:
@@ -90,7 +104,8 @@ def _attach_flags_to_upload(self, upload: Upload, flag_names: Sequence[str]):
             TYPE: Description
         """
         all_flags = []
-        db_session = upload.get_db_session()
+        if db_session is None:
+            db_session = upload.get_db_session()
         repoid = upload.report.commit.repoid
 
         if self.flag_dict is None:
diff --git a/apps/worker/services/tests/test_timeseries.py b/apps/worker/services/tests/test_timeseries.py
index d4cc8bf594..ace858c9b1 100644
--- a/apps/worker/services/tests/test_timeseries.py
+++ b/apps/worker/services/tests/test_timeseries.py
@@ -1,12 +1,16 @@
+import logging
+import types
 from datetime import UTC, datetime
 
 import pytest
 from celery import group
+from sqlalchemy.orm import Session as SQLAlchemySession
 
 from database.models.timeseries import Dataset, Measurement, MeasurementName
 from database.tests.factories import CommitFactory, RepositoryFactory
 from database.tests.factories.reports import RepositoryFlagFactory
 from database.tests.factories.timeseries import DatasetFactory, MeasurementFactory
+from services.report import ReportService
 from services.timeseries import (
     backfill_batch_size,
     delete_repository_data,
@@ -21,6 +25,9 @@
 from shared.utils.sessions import Session
 from shared.yaml import UserYaml
 from tasks.save_commit_measurements import save_commit_measurements
+from tasks.tests.utils import hook_session
+
+logger = logging.getLogger(__name__)
 
 
 @pytest.fixture
@@ -587,10 +594,11 @@ def test_commit_measurement_update_component_parallel(
         dataset_names,
         mocker,
         mock_repo_provider,
+        dbsession,
+        request,
     ):
-        dbsession = repository.get_db_session()
-        mocker.patch.object(dbsession, "close")
-        mocker.patch("tasks.base.get_db_session", return_value=dbsession)
+        # Use dbsession fixture directly - repository and commit are bound to it
+        hook_session(mocker, dbsession, request=request)
 
         mocker.patch.object(group, "apply_async", group.apply)
         mocker.patch(
@@ -598,17 +606,52 @@ def test_commit_measurement_update_component_parallel(
             return_value=True,
         )
 
-        mocker.patch(
-            "services.report.ReportService.get_existing_report_for_commit",
-            return_value=ReadOnlyReport.create_from_report(
-                sample_report_for_components
-            ),
+        # Mock ReportService.get_existing_report_for_commit for both save_commit_measurements and upsert_component
+        mock_report = ReadOnlyReport.create_from_report(sample_report_for_components)
+        # Mock the class method (for when ReportService is instantiated)
+        mocker.patch.object(
+            ReportService,
+            "get_existing_report_for_commit",
+            return_value=mock_report,
+        )
+        # Also mock has_initialized_report so get_existing_report_for_commit doesn't return None
+        mocker.patch.object(
+            ReportService,
+            "has_initialized_report",
+            return_value=True,
         )
 
         commit = CommitFactory.create(branch="foo", repository=repository)
         dbsession.add(commit)
         dbsession.flush()
 
+        # Ensure commit and repository are properly bound and visible
+        dbsession.refresh(commit)
+        dbsession.refresh(repository)
+
+        # Verify commit is bound before patching
+        actual_session = SQLAlchemySession.object_session(commit)
+        assert actual_session is dbsession, (
+            f"Commit should be bound to dbsession, got {actual_session}"
+        )
+
+        # Patch get_db_session directly on the commit instance using MethodType
+        # This properly binds the function as an instance method
+        def mock_get_db_session(self):
+            logger.debug(
+                f"DEBUG mock_get_db_session: Returning dbsession for commit {self.id_}"
+            )
+            return dbsession
+
+        commit.get_db_session = types.MethodType(mock_get_db_session, commit)
+
+        # Verify patch worked
+        patched_result = commit.get_db_session()
+        assert patched_result is dbsession, (
+            f"Patch should return dbsession, got {patched_result}"
+        )
+        logger.debug(f"DEBUG: commit.get_db_session()={patched_result}")
+
         get_repo_yaml = mocker.patch("tasks.save_commit_measurements.get_repo_yaml")
         get_current_yaml = mocker.patch("tasks.upsert_component.get_repo_yaml")
         yaml_dict = {
@@ -626,11 +669,15 @@ def test_commit_measurement_update_component_parallel(
         get_repo_yaml.return_value = UserYaml(yaml_dict)
         get_current_yaml.return_value = UserYaml(yaml_dict)
 
-        save_commit_measurements(commit, dataset_names=dataset_names)
+        logger.debug(
+            f"DEBUG: About to call save_commit_measurements, commit.get_db_session()={commit.get_db_session()}"
+        )
+        save_commit_measurements(
+            commit, dataset_names=dataset_names, db_session=dbsession
+        )
 
-        # Want to commit here to have the results persisted properly.
-        # Otherwise the results aren't going to be reflected in the select below.
-        # dbsession.commit()
+        # Flush to ensure measurements created by parallel tasks are visible
+        dbsession.flush()
 
         measurements = (
             dbsession.query(Measurement)
@@ -938,10 +985,14 @@ def test_delete_repository_data_side_effects(
         other_commit = CommitFactory.create(branch="foo", repository=other_repository)
         dbsession.add(other_commit)
         dbsession.flush()
+        # Mock other_commit.get_db_session() directly to return the test session
+        mocker.patch.object(other_commit, "get_db_session", return_value=dbsession)
         save_commit_measurements(other_commit, dataset_names=dataset_names)
 
        other_commit = CommitFactory.create(branch="bar", repository=other_repository)
         dbsession.add(other_commit)
         dbsession.flush()
+        # Mock other_commit.get_db_session() directly to return the test session
+        mocker.patch.object(other_commit, "get_db_session", return_value=dbsession)
         save_commit_measurements(other_commit, dataset_names=dataset_names)
 
         assert (
@@ -991,6 +1042,7 @@ def test_delete_repository_data_measurements_only(
         dataset_names,
         mocker,
         mock_repo_provider,
+        request,
     ):
         def validate_invariants(repository, other_repository):
             assert (
@@ -1017,16 +1069,23 @@ def validate_invariants(repository, other_repository):
             "tasks.save_commit_measurements.PARALLEL_COMPONENT_COMPARISON.check_value",
             return_value=True,
         )
-        dbsession = repository.get_db_session()
-        mocker.patch.object(dbsession, "close")
-        mocker.patch("tasks.base.get_db_session", return_value=dbsession)
+        # Use dbsession fixture directly - repository and commit are bound to it
+        hook_session(mocker, dbsession, request=request)
 
         mocker.patch.object(group, "apply_async", group.apply)
 
-        mocker.patch(
-            "services.report.ReportService.get_existing_report_for_commit",
-            return_value=ReadOnlyReport.create_from_report(
-                sample_report_for_components
-            ),
+        # Mock ReportService.get_existing_report_for_commit for both save_commit_measurements and upsert_component
+        mock_report = ReadOnlyReport.create_from_report(sample_report_for_components)
+        # Mock the class method (for when ReportService is instantiated)
+        mocker.patch.object(
+            ReportService,
+            "get_existing_report_for_commit",
+            return_value=mock_report,
+        )
+        # Also mock has_initialized_report so get_existing_report_for_commit doesn't return None
+        mocker.patch.object(
+            ReportService,
+            "has_initialized_report",
+            return_value=True,
         )
 
         get_repo_yaml = mocker.patch("tasks.save_commit_measurements.get_repo_yaml")
@@ -1059,22 +1118,76 @@ def validate_invariants(repository, other_repository):
         commit = CommitFactory.create(branch="foo", repository=repository)
         dbsession.add(commit)
         dbsession.flush()
-        save_commit_measurements(commit, dataset_names=dataset_names)
+        dbsession.refresh(commit)  # Ensure commit is properly bound
+
+        # Patch get_db_session directly on this commit instance using MethodType
+        def mock_get_db_session(self):
+            return dbsession
+
+        commit.get_db_session = types.MethodType(mock_get_db_session, commit)
+        assert commit.get_db_session() is dbsession, "Patch should return dbsession"
+        logger.debug(
+            f"DEBUG: commit.id_={commit.id_}, commit.get_db_session()={commit.get_db_session()}"
+        )
+        logger.debug(
+            f"DEBUG: About to call save_commit_measurements for commit {commit.id_}"
+        )
+        save_commit_measurements(
+            commit, dataset_names=dataset_names, db_session=dbsession
+        )
 
         commit = CommitFactory.create(branch="bar", repository=repository)
         dbsession.add(commit)
         dbsession.flush()
-        save_commit_measurements(commit, dataset_names=dataset_names)
+        dbsession.refresh(commit)  # Ensure commit is properly bound
+
+        # Patch get_db_session directly on this commit instance using MethodType
+        def mock_get_db_session(self):
+            return dbsession
+
+        commit.get_db_session = types.MethodType(mock_get_db_session, commit)
+        assert commit.get_db_session() is dbsession, "Patch should return dbsession"
+        logger.debug(
+            f"DEBUG: commit.id_={commit.id_}, commit.get_db_session()={commit.get_db_session()}"
+        )
+        logger.debug(
+            f"DEBUG: About to call save_commit_measurements for commit {commit.id_}"
+        )
+        save_commit_measurements(
+            commit, dataset_names=dataset_names, db_session=dbsession
+        )
 
         # Another unrelated repository, make sure that this one isn't deleted as a side effect
         other_repository = _create_repository(dbsession)
         other_commit = CommitFactory.create(branch="foo", repository=other_repository)
         dbsession.add(other_commit)
         dbsession.flush()
-        save_commit_measurements(other_commit, dataset_names=dataset_names)
+        dbsession.refresh(other_commit)  # Ensure commit is properly bound
+
+        # Patch get_db_session directly on this commit instance using MethodType
+        def mock_get_db_session(self):
+            return dbsession
+
+        other_commit.get_db_session = types.MethodType(
+            mock_get_db_session, other_commit
+        )
+        save_commit_measurements(
+            other_commit, dataset_names=dataset_names, db_session=dbsession
+        )
 
         other_commit = CommitFactory.create(branch="bar", repository=other_repository)
         dbsession.add(other_commit)
         dbsession.flush()
-        save_commit_measurements(other_commit, dataset_names=dataset_names)
+        dbsession.refresh(other_commit)  # Ensure commit is properly bound
+
+        # Patch get_db_session directly on this commit instance using MethodType
+        def mock_get_db_session(self):
+            return dbsession
+
+        other_commit.get_db_session = types.MethodType(
+            mock_get_db_session, other_commit
+        )
+        save_commit_measurements(
+            other_commit, dataset_names=dataset_names, db_session=dbsession
+        )
 
         flag_ids = {
             flag.measurable_id
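# --- Editor's note (illustrative, not part of the patch) ---------------------
# Why the tests above patch `get_db_session` per instance with types.MethodType:
# an ORM object only "knows" the session it is currently attached to, which the
# tests inspect with:
#
#     from sqlalchemy.orm import Session
#     assert Session.object_session(commit) is dbsession
#
# types.MethodType binds the replacement as a real instance method, so
# `commit.get_db_session()` receives `self` just as the original would; a bare
# `def mock_get_db_session(self)` assigned directly to the attribute would be
# called without `self` and fail.
# ------------------------------------------------------------------------------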
diff --git a/apps/worker/services/timeseries.py b/apps/worker/services/timeseries.py
index c7e9140d7b..465914019d 100644
--- a/apps/worker/services/timeseries.py
+++ b/apps/worker/services/timeseries.py
@@ -33,7 +33,7 @@ def maybe_upsert_coverage_measurement(commit, dataset_names, db_session, report):
 
 def maybe_upsert_flag_measurements(commit, dataset_names, db_session, report):
     if MeasurementName.flag_coverage.value in dataset_names:
-        flag_ids = repository_flag_ids(commit.repository)
+        flag_ids = repository_flag_ids(commit.repository, db_session)
         measurements = []
 
         for flag_name, flag in report.flags.items():
@@ -98,7 +98,10 @@ def get_relevant_components(
 
 def upsert_components_measurements(
-    commit: Commit, report: Report, components: list[ComponentForMeasurement]
+    commit: Commit,
+    report: Report,
+    components: list[ComponentForMeasurement],
+    db_session: Session | None = None,
 ):
     measurements = []
     for component in components:
@@ -114,7 +117,8 @@ def upsert_components_measurements(
         )
 
     if len(measurements) > 0:
-        db_session = commit.get_db_session()
+        if db_session is None:
+            db_session = commit.get_db_session()
         upsert_measurements(db_session, measurements)
         log.info(
             "Upserted component coverage measurements",
@@ -196,8 +200,11 @@ def repository_datasets_query(
     return datasets
 
 
-def repository_flag_ids(repository: Repository) -> Mapping[str, int]:
-    db_session = repository.get_db_session()
+def repository_flag_ids(
+    repository: Repository, db_session: Session | None = None
+) -> Mapping[str, int]:
+    if db_session is None:
+        db_session = repository.get_db_session()
     repo_flags = (
         db_session.query(RepositoryFlag).filter_by(repository=repository).yield_per(100)
     )
diff --git a/apps/worker/tasks/base.py b/apps/worker/tasks/base.py
index e539775454..b733de201f 100644
--- a/apps/worker/tasks/base.py
+++ b/apps/worker/tasks/base.py
@@ -14,10 +14,11 @@
     InvalidRequestError,
     SQLAlchemyError,
 )
+from sqlalchemy.orm import Session
 
 from app import celery_app
 from celery_task_router import _get_ownerid_from_task, _get_user_plan_from_task
-from database.engine import get_db_session
+from database.engine import create_task_session, get_db_session
 from database.enums import CommitErrorTypes
 from database.models.core import (
     GITHUB_APP_INSTALLATION_DEFAULT_NAME,
@@ -193,9 +194,30 @@ def get_lock_timeout(self, default_timeout: int) -> int:
 
     @sentry_sdk.trace
     def apply_async(self, args=None, kwargs=None, **options):
-        db_session = get_db_session()
-        user_plan = _get_user_plan_from_task(db_session, self.name, kwargs)
-        ownerid = _get_ownerid_from_task(db_session, self.name, kwargs)
+        # Use temporary session for routing (read-only, short-lived)
+        routing_session = create_task_session()
+        log.debug(
+            "Created routing session for apply_async",
+            extra={"task": self.name, "session_id": id(routing_session)},
+        )
+        try:
+            user_plan = _get_user_plan_from_task(routing_session, self.name, kwargs)
+            ownerid = _get_ownerid_from_task(routing_session, self.name, kwargs)
+        finally:
+            # Always clean up the routing session
+            had_transaction = routing_session.in_transaction()
+            if had_transaction:
+                routing_session.rollback()
+            routing_session.close()
+            log.debug(
+                "Closed routing session for apply_async",
+                extra={
+                    "task": self.name,
+                    "session_id": id(routing_session),
+                    "had_transaction": had_transaction,
+                },
+            )
+
         route_with_extra_config = route_tasks_based_on_user_plan(
             self.name, user_plan, ownerid
         )
@@ -216,7 +238,9 @@ def apply_async(self, args=None, kwargs=None, **options):
         headers = {
             **opt_headers,
             "created_timestamp": current_time.isoformat(),
-            "attempts": 1,
+            # Preserve existing attempts if present (e.g., from retry or re-delivery)
+            # Only set to 1 if this is a new task creation
+            "attempts": opt_headers.get("attempts", 1),
         }
         return super().apply_async(args=args, kwargs=kwargs, headers=headers, **options)
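# --- Editor's note (illustrative, not part of the patch) ---------------------
# The "attempts" header change above: re-publishing a task (e.g., on retry)
# goes through apply_async() again, and unconditionally writing
# `"attempts": 1` would reset whatever count the message already carried
# (however it is incremented elsewhere). The new expression keeps an existing
# value and only defaults for genuinely new tasks:
#
#     opt_headers = {}                     # fresh task
#     opt_headers.get("attempts", 1)       # -> 1
#
#     opt_headers = {"attempts": 3}        # re-delivered/retried message
#     opt_headers.get("attempts", 1)       # -> 3 (preserved, not reset)
# ------------------------------------------------------------------------------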
@@ -384,118 +408,216 @@ def _analyse_error(self, exception: SQLAlchemyError, *args, **kwargs):
 
     @sentry_sdk.trace
     def run(self, *args, **kwargs):
-        with self.task_full_runtime.time():
-            db_session = get_db_session()
-
-            log_context = LogContext(
-                repo_id=kwargs.get("repoid") or kwargs.get("repo_id"),
-                owner_id=kwargs.get("ownerid"),
-                commit_sha=kwargs.get("commitid") or kwargs.get("commit_id"),
-            )
-
-            task = get_current_task()
-            if task and task.request:
-                log_context.task_name = task.name
-                task_id = getattr(task.request, "id", None)
-                if task_id:
-                    log_context.task_id = task_id
-
-            log_context.populate_from_sqlalchemy(db_session)
-            set_log_context(log_context)
-            load_checkpoints_from_kwargs([UploadFlow, TestResultsFlow], kwargs)
-
-            self.task_run_counter.inc()
-            if (
-                hasattr(self, "request")
-                and self.request is not None
-                and hasattr(self.request, "get")
-            ):
-                created_timestamp = self.request.get("created_timestamp", None)
-                if created_timestamp:
-                    enqueued_time = datetime.fromisoformat(created_timestamp)
-                    now = datetime.now()
-                    delta = now - enqueued_time
-
-                    delivery_info = self.request.get("delivery_info", {})
-                    queue_name = (
-                        delivery_info.get("routing_key", None)
-                        if isinstance(delivery_info, dict)
-                        else None
-                    )
-                    time_in_queue_timer = TASK_TIME_IN_QUEUE.labels(
-                        task=self.name, queue=queue_name
-                    )
-                    time_in_queue_timer.observe(delta.total_seconds())
-
-            close_old_connections()
-
-            try:
-                with self.task_core_runtime.time():
-                    return self.run_impl(db_session, *args, **kwargs)
-            except InterfaceError as ex:
-                sentry_sdk.capture_exception(
-                    ex,
-                )
-            except (DataError, IntegrityError):
-                log.exception(
-                    "Errors related to the constraints of database happened",
-                    extra={"task_args": args, "task_kwargs": kwargs},
-                )
-                db_session.rollback()
-                retry_count = getattr(self.request, "retries", 0)
-                countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2**retry_count)
-                # Use safe_retry to handle max retries exceeded gracefully
-                # Returns False if max retries exceeded, otherwise raises Retry
-                # Wrap in try-except to catch MaxRetriesExceededError if it escapes safe_retry
-                # (exceptions raised inside except blocks aren't caught by sibling except clauses)
-                try:
-                    if not self.safe_retry(countdown=countdown):
-                        # Max retries exceeded - return None to match old behavior
-                        return None
-                except MaxRetriesExceededError:
-                    # Handle MaxRetriesExceededError if it escapes safe_retry
-                    if UploadFlow.has_begun():
-                        UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
-                    if TestResultsFlow.has_begun():
-                        TestResultsFlow.log(TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION)
-                    # Return None to match old behavior
-                    return None
-            except SQLAlchemyError as ex:
-                self._analyse_error(ex, args, kwargs)
-                db_session.rollback()
-                retry_count = getattr(self.request, "retries", 0)
-                countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2**retry_count)
-                # Use safe_retry to handle max retries exceeded gracefully
-                # Returns False if max retries exceeded, otherwise raises Retry
-                # Wrap in try-except to catch MaxRetriesExceededError if it escapes safe_retry
-                # (exceptions raised inside except blocks aren't caught by sibling except clauses)
-                try:
-                    if not self.safe_retry(countdown=countdown):
-                        # Max retries exceeded - return None to match old behavior
-                        return None
-                except MaxRetriesExceededError:
-                    # Handle MaxRetriesExceededError if it escapes safe_retry
-                    if UploadFlow.has_begun():
-                        UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
-                    if TestResultsFlow.has_begun():
-                        TestResultsFlow.log(TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION)
-                    # Return None to match old behavior
-                    return None
-            except MaxRetriesExceededError as ex:
-                if UploadFlow.has_begun():
-                    UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
-                if TestResultsFlow.has_begun():
-                    TestResultsFlow.log(TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION)
-            finally:
-                self.wrap_up_dbsession(db_session)
+        with self.task_full_runtime.time():  # Timer isn't tested
+            # Wrap entire section in try-finally to ensure session cleanup
+            # even if exceptions occur during setup or session creation
+            db_session = None
+            try:
+                # Create NEW session for this task (per-task session)
+                db_session = create_task_session()
+                task_id = getattr(self.request, "id", None)
+                log.debug(
+                    "Created task session",
+                    extra={
+                        "task": self.name,
+                        "task_id": task_id,
+                        "session_id": id(db_session),
+                    },
+                )
+
+                # Validate session is clean (no inherited transaction)
+                if db_session.in_transaction():
+                    log.warning(
+                        "New task session has open transaction, rolling back",
+                        extra={
+                            "task": self.name,
+                            "task_id": task_id,
+                            "session_id": id(db_session),
+                        },
+                    )
+                    db_session.rollback()
+
+                log_context = LogContext(
+                    repo_id=kwargs.get("repoid") or kwargs.get("repo_id"),
+                    owner_id=kwargs.get("ownerid"),
+                    commit_sha=kwargs.get("commitid") or kwargs.get("commit_id"),
+                )
+
+                task = get_current_task()
+                if task and task.request:
+                    log_context.task_name = task.name
+                    task_id = getattr(task.request, "id", None)
+                    if task_id:
+                        log_context.task_id = task_id
+
+                log_context.populate_from_sqlalchemy(db_session)
+                set_log_context(log_context)
+                load_checkpoints_from_kwargs([UploadFlow, TestResultsFlow], kwargs)
+
+                self.task_run_counter.inc()
+                if (
+                    hasattr(self, "request")
+                    and self.request is not None
+                    and hasattr(self.request, "get")
+                ):
+                    created_timestamp = self.request.get("created_timestamp", None)
+                    if created_timestamp:
+                        enqueued_time = datetime.fromisoformat(created_timestamp)
+                        now = datetime.now()
+                        delta = now - enqueued_time
+
+                        delivery_info = self.request.get("delivery_info", {})
+                        queue_name = (
+                            delivery_info.get("routing_key", None)
+                            if isinstance(delivery_info, dict)
+                            else None
+                        )
+                        time_in_queue_timer = TASK_TIME_IN_QUEUE.labels(
+                            task=self.name, queue=queue_name
+                        )
+                        time_in_queue_timer.observe(delta.total_seconds())
+
+                close_old_connections()
+
+                try:
+                    with self.task_core_runtime.time():  # Timer isn't tested
+                        result = self.run_impl(db_session, *args, **kwargs)
+                        # If we got here, task succeeded - commit if transaction is open
+                        if db_session.in_transaction():
+                            db_session.commit()
+                        return result
+                except InterfaceError as ex:
+                    sentry_sdk.capture_exception(
+                        ex,
+                    )
+                    db_session.rollback()
+                except (DataError, IntegrityError):
+                    log.exception(
+                        "Errors related to the constraints of database happened",
+                        extra={"task_args": args, "task_kwargs": kwargs},
+                    )
+                    db_session.rollback()
+                    retry_count = getattr(self.request, "retries", 0)
+                    countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2**retry_count)
+                    # Use safe_retry to handle max retries exceeded gracefully
+                    # Returns False if max retries exceeded, otherwise raises Retry
+                    # Wrap in try-except to catch MaxRetriesExceededError if it escapes safe_retry
+                    # (exceptions raised inside except blocks aren't caught by sibling except clauses)
+                    try:
+                        if not self.safe_retry(countdown=countdown):
+                            # Max retries exceeded - return None to match old behavior
+                            return None
+                    except MaxRetriesExceededError:
+                        # Handle MaxRetriesExceededError if it escapes safe_retry
+                        if UploadFlow.has_begun():
+                            UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
+                        if TestResultsFlow.has_begun():
+                            TestResultsFlow.log(
+                                TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION
+                            )
+                        # Return None to match old behavior
+                        return None
+                except SQLAlchemyError as ex:
+                    self._analyse_error(ex, args, kwargs)
+                    db_session.rollback()
+                    retry_count = getattr(self.request, "retries", 0)
+                    countdown = TASK_RETRY_BACKOFF_BASE_SECONDS * (2**retry_count)
+                    # Use safe_retry to handle max retries exceeded gracefully
+                    # Returns False if max retries exceeded, otherwise raises Retry
+                    # Wrap in try-except to catch MaxRetriesExceededError if it escapes safe_retry
+                    # (exceptions raised inside except blocks aren't caught by sibling except clauses)
+                    try:
+                        if not self.safe_retry(countdown=countdown):
+                            # Max retries exceeded - return None to match old behavior
+                            return None
+                    except MaxRetriesExceededError:
+                        # Handle MaxRetriesExceededError if it escapes safe_retry
+                        if UploadFlow.has_begun():
+                            UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
+                        if TestResultsFlow.has_begun():
+                            TestResultsFlow.log(
+                                TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION
+                            )
+                        # Return None to match old behavior
+                        return None
+                except MaxRetriesExceededError as ex:
+                    if UploadFlow.has_begun():
+                        UploadFlow.log(UploadFlow.UNCAUGHT_RETRY_EXCEPTION)
+                    if TestResultsFlow.has_begun():
+                        TestResultsFlow.log(TestResultsFlow.UNCAUGHT_RETRY_EXCEPTION)
+                    db_session.rollback()
+                except Exception:
+                    # Catch-all: rollback on any exception
+                    db_session.rollback()
+                    raise
+            finally:
+                # Always clean up the task session
+                # This ensures cleanup even if exceptions occur during setup or session creation
+                if db_session is not None:
+                    self.wrap_up_task_session(db_session)
+
+    def wrap_up_task_session(self, db_session: Session):
+        """
+        Clean up task-specific session.
+
+        Ensures session is properly closed and any remaining transaction is rolled back.
+        This is called for per-task sessions created via create_task_session().
+
+        Args:
+            db_session: The task session to clean up
+        """
+        session_id = id(db_session)
+        had_transaction = False
+        try:
+            # Roll back any remaining transaction
+            had_transaction = db_session.in_transaction()
+            if had_transaction:
+                db_session.rollback()
+        except Exception as e:
+            log.warning(
+                "Error rolling back task session",
+                extra={"error": str(e), "task": self.name, "session_id": session_id},
+                exc_info=True,
+            )
+        finally:
+            try:
+                db_session.close()
+                log.debug(
+                    "Closed task session",
+                    extra={
+                        "task": self.name,
+                        "session_id": session_id,
+                        "had_transaction": had_transaction,
+                    },
+                )
+            except Exception as e:
+                log.warning(
+                    "Error closing task session",
+                    extra={
+                        "error": str(e),
+                        "task": self.name,
+                        "session_id": session_id,
+                    },
+                    exc_info=True,
+                )
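# --- Editor's note (illustrative, not part of the patch) ---------------------
# The per-task session lifecycle that run() and wrap_up_task_session() above
# implement, reduced to its skeleton:
#
#     db_session = create_task_session()     # fresh, isolated session per task
#     try:
#         result = self.run_impl(db_session, *args, **kwargs)
#         if db_session.in_transaction():
#             db_session.commit()            # success is the only commit path
#         return result
#     except Exception:
#         db_session.rollback()              # every failure path rolls back
#         raise
#     finally:
#         self.wrap_up_task_session(db_session)  # roll back leftovers, close
# ------------------------------------------------------------------------------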
+ So we need to do two ugly exception-catching: + 1) For if `SoftTimeLimitExceeded` was raised while committing + 2) For if the exception left `db_session` in an unusable state """ try: db_session.commit() diff --git a/apps/worker/tasks/bundle_analysis_processor.py b/apps/worker/tasks/bundle_analysis_processor.py index e30a4485a2..0b37a67eff 100644 --- a/apps/worker/tasks/bundle_analysis_processor.py +++ b/apps/worker/tasks/bundle_analysis_processor.py @@ -174,9 +174,13 @@ def process_impl_within_lock( return processing_results else: # If the commit report does not exist, we will create a new one - commit_report = report_service.initialize_and_save_report(commit) + commit_report = report_service.initialize_and_save_report( + commit, db_session=db_session + ) - upload = report_service.create_report_upload({"url": ""}, commit_report) + upload = report_service.create_report_upload( + {"url": ""}, commit_report, db_session=db_session + ) carriedforward = True assert upload is not None @@ -260,7 +264,7 @@ def process_impl_within_lock( max_retries=self.max_retries, countdown=30 * (2**self.request.retries), ) - result.update_upload(carriedforward=carriedforward) + result.update_upload(carriedforward=carriedforward, db_session=db_session) db_session.commit() processing_results.append(result.as_dict()) diff --git a/apps/worker/tasks/preprocess_upload.py b/apps/worker/tasks/preprocess_upload.py index 05795fb5b7..ac8e737a65 100644 --- a/apps/worker/tasks/preprocess_upload.py +++ b/apps/worker/tasks/preprocess_upload.py @@ -123,7 +123,9 @@ def process_impl_within_lock(self, db_session, repoid, commitid): report_service = ReportService( commit_yaml, gh_app_installation_name=installation_name_to_use ) - commit_report = report_service.initialize_and_save_report(commit) + commit_report = report_service.initialize_and_save_report( + commit, db_session=db_session + ) # Persist changes from within the lock db_session.commit() self._call_upload_breadcrumb_task( diff --git a/apps/worker/tasks/save_commit_measurements.py b/apps/worker/tasks/save_commit_measurements.py index aa87277085..c02bf61f79 100644 --- a/apps/worker/tasks/save_commit_measurements.py +++ b/apps/worker/tasks/save_commit_measurements.py @@ -24,8 +24,16 @@ log = logging.getLogger(__name__) -def save_commit_measurements(commit: Commit, dataset_names: Sequence[str]) -> None: - db_session = commit.get_db_session() +def save_commit_measurements( + commit: Commit, dataset_names: Sequence[str], db_session: Session | None = None +) -> None: + if db_session is None: + db_session = commit.get_db_session() + if db_session is None: + raise ValueError( + f"commit.get_db_session() returned None for commit {commit.id_}. " + "Pass db_session parameter explicitly." 
diff --git a/apps/worker/tasks/save_commit_measurements.py b/apps/worker/tasks/save_commit_measurements.py
index aa87277085..c02bf61f79 100644
--- a/apps/worker/tasks/save_commit_measurements.py
+++ b/apps/worker/tasks/save_commit_measurements.py
@@ -24,8 +24,16 @@
 log = logging.getLogger(__name__)


-def save_commit_measurements(commit: Commit, dataset_names: Sequence[str]) -> None:
-    db_session = commit.get_db_session()
+def save_commit_measurements(
+    commit: Commit, dataset_names: Sequence[str], db_session: Session | None = None
+) -> None:
+    if db_session is None:
+        db_session = commit.get_db_session()
+        if db_session is None:
+            raise ValueError(
+                f"commit.get_db_session() returned None for commit {commit.id_}. "
+                "Pass db_session parameter explicitly."
+            )
     current_yaml = get_repo_yaml(commit.repository)
     report_service = ReportService(current_yaml)

@@ -59,7 +67,7 @@ def save_commit_measurements(commit: Commit, dataset_names: Sequence[str]) -> No
         ]
         group(task_signatures).apply_async()
     else:
-        upsert_components_measurements(commit, report, components)
+        upsert_components_measurements(commit, report, components, db_session)

     maybe_upsert_flag_measurements(commit, dataset_names, db_session, report)
@@ -104,7 +112,9 @@ def run_impl(

     try:
         # TODO: We should improve on the error handling/logs inside this fn
-        save_commit_measurements(commit=commit, dataset_names=dataset_names)
+        save_commit_measurements(
+            commit=commit, dataset_names=dataset_names, db_session=db_session
+        )
         return {"successful": True}
     except Exception:
         log.exception(
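Callers inside a task should hand their own session to save_commit_measurements rather than relying on the fallback; this mirrors the updated test_save_commit_measurements unit test later in this diff:

    save_commit_measurements(
        commit=commit,
        dataset_names=[MeasurementName.coverage.value],
        db_session=db_session,  # the task's session, not a fresh one
    )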
diff --git a/apps/worker/tasks/tests/integration/test_upload_e2e.py b/apps/worker/tasks/tests/integration/test_upload_e2e.py
index 485d81553b..483eb9f890 100644
--- a/apps/worker/tasks/tests/integration/test_upload_e2e.py
+++ b/apps/worker/tasks/tests/integration/test_upload_e2e.py
@@ -129,11 +129,12 @@ def setup_mocks(
     mock_configuration,
     mock_repo_provider,
     user_yaml=None,
+    request=None,
 ):
     # patch various `get_db_session` imports
-    hook_session(mocker, dbsession)
+    hook_session(mocker, dbsession, request=request)
     # to not close the session after each task
-    mocker.patch("tasks.base.BaseCodecovTask.wrap_up_dbsession")
+    mocker.patch("tasks.base.BaseCodecovTask.wrap_up_task_session")
    # patch various `get_repo_provider_service` imports
     hook_repo_provider(mocker, mock_repo_provider)
     # avoid some calls reaching out to git providers
@@ -172,9 +173,12 @@ def test_full_upload(
     mock_repo_provider,
     mock_storage,
     mock_configuration,
+    request,
 ):
     mock_all_plans_and_tiers()
-    setup_mocks(mocker, dbsession, mock_configuration, mock_repo_provider)
+    setup_mocks(
+        mocker, dbsession, mock_configuration, mock_repo_provider, request=request
+    )

     repository = RepositoryFactory.create()
     dbsession.add(repository)
@@ -200,6 +204,10 @@ def test_full_upload(
     dbsession.add(commit)
     dbsession.flush()

+    # Ensure commit and repository are properly bound and visible
+    dbsession.refresh(commit)
+    dbsession.refresh(repository)
+
     setup_mock_get_compare(base_commit, commit, mock_repo_provider)

     archive_service = ArchiveService(repository)
@@ -379,11 +387,17 @@ def test_full_carryforward(
     mock_repo_provider,
     mock_storage,
     mock_configuration,
+    request,
 ):
     mock_all_plans_and_tiers()
     user_yaml = {"flag_management": {"default_rules": {"carryforward": True}}}
     setup_mocks(
-        mocker, dbsession, mock_configuration, mock_repo_provider, user_yaml=user_yaml
+        mocker,
+        dbsession,
+        mock_configuration,
+        mock_repo_provider,
+        user_yaml=user_yaml,
+        request=request,
     )
     mocker.patch("tasks.compute_comparison.ComputeComparisonTask.run_impl")
@@ -397,6 +411,10 @@ def test_full_carryforward(
     dbsession.add(base_commit)
     dbsession.flush()

+    # Ensure commit and repository are properly bound and visible
+    dbsession.refresh(base_commit)
+    dbsession.refresh(repository)
+
     archive_service = ArchiveService(repository)
     do_upload = partial(
         write_raw_upload,
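The new refresh calls are presumably needed because the shared test session has commit patched to flush (see tasks/tests/utils.py later in this diff): flushed rows get primary keys, but in-memory attribute state can go stale as tasks reuse the session. Session.refresh re-reads the row:

    # Session.refresh() issues a SELECT and overwrites in-memory attribute
    # state for the instance, so assertions see what is actually in the DB.
    dbsession.refresh(commit)  # reload every loaded attribute of `commit`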
mocker.patch("tasks.base.BaseCodecovTask.wrap_up_task_session") mock_dbsession_rollback = mocker.patch.object(dbsession, "rollback") - mock_get_db_session = mocker.patch( - "tasks.base.get_db_session", return_value=dbsession + mock_dbsession_commit = mocker.patch.object(dbsession, "commit") + # in_transaction is called: + # 1. At start to validate session (line 353) - return False (clean session) + # 2. After run_impl to check if commit needed (line 387) - return True (transaction open) + # 3. In wrap_up_task_session to check if rollback needed (line 432) - return False (after commit) + mock_dbsession_in_transaction = mocker.patch.object( + dbsession, "in_transaction", side_effect=[False, True, False] + ) + mock_create_task_session = mocker.patch( + "tasks.base.create_task_session", return_value=dbsession ) # Mock request to skip queue metrics mocker.patch("tasks.base.BaseCodecovTask.request", None) @@ -244,14 +252,16 @@ def test_run_success_commits_sqlalchemy(self, mocker, dbsession): task.run() assert mock_wrap_up.call_args_list == [call(dbsession)] - + # Should commit on success if transaction is open + assert mock_dbsession_commit.call_count == 1 + # wrap_up_task_session will check in_transaction again (returns False), so no rollback assert mock_dbsession_rollback.call_count == 0 def test_run_db_errors_rollback(self, mocker, dbsession, celery_app): mock_dbsession_rollback = mocker.patch.object(dbsession, "rollback") - mock_wrap_up = mocker.patch("tasks.base.BaseCodecovTask.wrap_up_dbsession") - mock_get_db_session = mocker.patch( - "tasks.base.get_db_session", return_value=dbsession + mock_wrap_up = mocker.patch("tasks.base.BaseCodecovTask.wrap_up_task_session") + mock_create_task_session = mocker.patch( + "tasks.base.create_task_session", return_value=dbsession ) # IntegrityError and DataError are subclasses of SQLAlchemyError that @@ -263,15 +273,16 @@ def test_run_db_errors_rollback(self, mocker, dbsession, celery_app): task = celery_app.tasks[registered_task.name] task.apply() - assert mock_dbsession_rollback.call_args_list == [call()] + # Should rollback in exception handler + assert mock_dbsession_rollback.call_count >= 1 assert mock_wrap_up.call_args_list == [call(dbsession)] def test_run_sqlalchemy_error_rollback(self, mocker, dbsession, celery_app): mock_dbsession_rollback = mocker.patch.object(dbsession, "rollback") - mock_wrap_up = mocker.patch("tasks.base.BaseCodecovTask.wrap_up_dbsession") - mock_get_db_session = mocker.patch( - "tasks.base.get_db_session", return_value=dbsession + mock_wrap_up = mocker.patch("tasks.base.BaseCodecovTask.wrap_up_task_session") + mock_create_task_session = mocker.patch( + "tasks.base.create_task_session", return_value=dbsession ) # Mock request to skip queue metrics mocker.patch("tasks.base.BaseCodecovTask.request", None) @@ -284,7 +295,8 @@ def test_run_sqlalchemy_error_rollback(self, mocker, dbsession, celery_app): task = celery_app.tasks[registered_task.name] task.apply() - assert mock_dbsession_rollback.call_args_list == [call()] + # Should rollback in exception handler + assert mock_dbsession_rollback.call_count >= 1 assert mock_wrap_up.call_args_list == [call(dbsession)] @@ -906,7 +918,11 @@ def fake_repos(self, dbsession, fake_owners): @pytest.mark.freeze_time("2023-06-13T10:01:01.000123") def test_apply_async_override(self, mocker): - mock_get_db_session = mocker.patch("tasks.base.get_db_session") + mock_db_session = mocker.MagicMock() + mock_create_task_session = mocker.patch( + "tasks.base.create_task_session", + 
@@ -906,7 +918,11 @@ def fake_repos(self, dbsession, fake_owners):

     @pytest.mark.freeze_time("2023-06-13T10:01:01.000123")
     def test_apply_async_override(self, mocker):
-        mock_get_db_session = mocker.patch("tasks.base.get_db_session")
+        mock_db_session = mocker.MagicMock()
+        mock_create_task_session = mocker.patch(
+            "tasks.base.create_task_session",
+            return_value=mock_db_session,
+        )
         mock_celery_task_router = mocker.patch("tasks.base._get_user_plan_from_task")
         mock_route_tasks = mocker.patch(
             "tasks.base.route_tasks_based_on_user_plan",
@@ -922,7 +938,7 @@ def test_apply_async_override(self, mocker):
         kwargs = {"n": 10}
         task.apply_async(kwargs=kwargs)

-        assert mock_get_db_session.call_count == 1
+        assert mock_create_task_session.call_count == 1
         assert mock_celery_task_router.call_count == 1
         assert mock_route_tasks.call_count == 1
         call_kwargs = mocked_apply_async.call_args[1]
@@ -938,7 +954,11 @@ def test_apply_async_override(self, mocker):

     @pytest.mark.freeze_time("2023-06-13T10:01:01.000123")
     def test_apply_async_override_with_chain(self, mocker):
-        mock_get_db_session = mocker.patch("tasks.base.get_db_session")
+        mock_db_session = mocker.MagicMock()
+        mock_create_task_session = mocker.patch(
+            "tasks.base.create_task_session",
+            return_value=mock_db_session,
+        )
         mock_celery_task_router = mocker.patch("tasks.base._get_user_plan_from_task")
         mock_route_tasks = mocker.patch(
             "tasks.base.route_tasks_based_on_user_plan",
@@ -955,7 +975,7 @@ def test_apply_async_override_with_chain(self, mocker):
         chain(
             [task.signature(kwargs={"n": 1}), task.signature(kwargs={"n": 10})]
         ).apply_async()
-        assert mock_get_db_session.call_count == 1
+        assert mock_create_task_session.call_count == 1
         assert mock_celery_task_router.call_count == 1
         assert mock_route_tasks.call_count == 1
         assert mocked_apply_async.call_count == 1
@@ -993,8 +1013,9 @@ def test_real_example_no_override(
             }
         }
     )
-    mock_get_db_session = mocker.patch(
-        "tasks.base.get_db_session", return_value=dbsession
+    mock_create_task_session = mocker.patch(
+        "tasks.base.create_task_session",
+        return_value=dbsession,
     )
     task = BaseCodecovTask()
     mocker.patch.object(task, "run", return_value="success")
@@ -1007,7 +1028,7 @@ def test_real_example_no_override(
     kwargs = {"ownerid": repo.ownerid}
     task.apply_async(kwargs=kwargs)

-    assert mock_get_db_session.call_count == 1
+    assert mock_create_task_session.call_count == 1
     call_kwargs = mocked_super_apply_async.call_args[1]
     assert call_kwargs["args"] is None
     assert call_kwargs["kwargs"] == kwargs
@@ -1042,8 +1063,9 @@ def test_real_example_override_from_celery(
             }
         }
     )
-    mock_get_db_session = mocker.patch(
-        "tasks.base.get_db_session", return_value=dbsession
+    mock_create_task_session = mocker.patch(
+        "tasks.base.create_task_session",
+        return_value=dbsession,
     )
     task = BaseCodecovTask()
     mocker.patch.object(task, "run", return_value="success")
@@ -1056,7 +1078,7 @@ def test_real_example_override_from_celery(
     kwargs = {"ownerid": repo_enterprise_cloud.ownerid}
     task.apply_async(kwargs=kwargs)

-    assert mock_get_db_session.call_count == 1
+    assert mock_create_task_session.call_count == 1
     call_kwargs = mocked_super_apply_async.call_args[1]
     assert call_kwargs["args"] is None
     assert call_kwargs["kwargs"] == kwargs
@@ -1091,8 +1113,9 @@ def test_real_example_override_from_upload(
             }
         }
     )
-    mock_get_db_session = mocker.patch(
-        "tasks.base.get_db_session", return_value=dbsession
+    mock_create_task_session = mocker.patch(
+        "tasks.base.create_task_session",
+        return_value=dbsession,
     )
     task = BaseCodecovTask()
     mocker.patch.object(task, "run", return_value="success")
@@ -1105,7 +1128,7 @@ def test_real_example_override_from_upload(
     kwargs = {"repoid": repo_enterprise_cloud.repoid}
     task.apply_async(kwargs=kwargs)

-    assert mock_get_db_session.call_count == 1
+    assert mock_create_task_session.call_count == 1
     call_kwargs = mocked_super_apply_async.call_args[1]
     assert call_kwargs["args"] is None
     assert call_kwargs["kwargs"] == kwargs
call_kwargs["args"] is None assert call_kwargs["kwargs"] == kwargs diff --git a/apps/worker/tasks/tests/unit/test_compute_comparison.py b/apps/worker/tasks/tests/unit/test_compute_comparison.py index 0a811c9739..e87d5d6efa 100644 --- a/apps/worker/tasks/tests/unit/test_compute_comparison.py +++ b/apps/worker/tasks/tests/unit/test_compute_comparison.py @@ -560,7 +560,7 @@ def test_compute_component_comparisons( def test_compute_component_comparisons_parallel( self, dbsession, mocker, mock_repo_provider, mock_storage, sample_report ): - mocker.patch("tasks.base.get_db_session", return_value=dbsession) + mocker.patch("tasks.base.create_task_session", return_value=dbsession) mocker.patch.object(group, "apply_async", group.apply) mocker.patch.object( diff --git a/apps/worker/tasks/tests/unit/test_save_commit_measurements.py b/apps/worker/tasks/tests/unit/test_save_commit_measurements.py index 651ee4a703..2eef111607 100644 --- a/apps/worker/tasks/tests/unit/test_save_commit_measurements.py +++ b/apps/worker/tasks/tests/unit/test_save_commit_measurements.py @@ -36,6 +36,7 @@ def test_save_commit_measurements_success(self, dbsession, mocker): MeasurementName.coverage.value, MeasurementName.flag_coverage.value, ], + db_session=dbsession, ) def test_save_commit_measurements_no_commit(self, dbsession): diff --git a/apps/worker/tasks/tests/utils.py b/apps/worker/tasks/tests/utils.py index fdf0d66c86..cc2ae740a4 100644 --- a/apps/worker/tasks/tests/utils.py +++ b/apps/worker/tasks/tests/utils.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import Session from app import celery_app +from database.engine import set_test_session_factory @contextlib.contextmanager @@ -17,20 +18,44 @@ def run_tasks() -> Generator[None]: GLOBALS_USING_SESSION = [ - "celery_task_router.get_db_session", "database.engine.get_db_session", "tasks.base.get_db_session", ] +GLOBALS_USING_TASK_SESSION = [ + "tasks.base.create_task_session", + "database.engine.create_task_session", + "celery_task_router.create_task_session", +] + + +def hook_session(mocker, dbsession: Session, request=None): + """Configure all tasks to use the shared test session.""" -def hook_session(mocker, dbsession: Session): - """ - This patches various module-local imports related to `get_db_session`. - """ mocker.patch("shared.metrics") for path in GLOBALS_USING_SESSION: mocker.patch(path, return_value=dbsession) + mocker.patch("tasks.base.close_old_connections") + mocker.patch.object(dbsession, "close", lambda: None) + mocker.patch.object(dbsession, "in_transaction", lambda: False) + + original_commit = dbsession.commit + + def flush_instead_of_commit(): + dbsession.flush() + + mocker.patch.object(dbsession, "commit", flush_instead_of_commit) + + set_test_session_factory(lambda: dbsession) + + def cleanup(): + set_test_session_factory(None) + dbsession.commit = original_commit + + if request is not None: + request.addfinalizer(cleanup) + GLOBALS_USING_REPO_PROVIDER = [ "services.comparison.get_repo_provider_service", @@ -42,14 +67,7 @@ def hook_session(mocker, dbsession: Session): def hook_repo_provider(mocker, mock_repo_provider): - """ - Hooks / mocks various `get_repo_provider_service` locals. - Due to how import resolution works in python, we have to patch this - *everywhere* that is *imported* into, instead of patching the function where - it is defined. - The reason is that imports are resolved at import time, and overriding the - function definition after the fact does not work. 
- """ + """Patch get_repo_provider_service in all modules that import it.""" for path in GLOBALS_USING_REPO_PROVIDER: mocker.patch(path, return_value=mock_repo_provider) @@ -57,32 +75,14 @@ def hook_repo_provider(mocker, mock_repo_provider): def ensure_hard_time_limit_task_is_numeric( mocker, task_instance, default_value: int = 720 ): - """ - Ensures that hard_time_limit_task returns a numeric value for testing. - - This helper patches hard_time_limit_task to return a proper integer value, - preventing issues where MagicMock objects might be returned when app.conf - is mocked. - - Use this helper when testing code that calls get_lock_timeout() to ensure - hard_time_limit_task returns a proper numeric value. - - Args: - mocker: The pytest mocker fixture - task_instance: The task instance to patch - default_value: The default value to return (default: 720) - """ - # Get the original property getter + """Patch hard_time_limit_task to return a numeric value instead of MagicMock.""" original_getter = task_instance.__class__.hard_time_limit_task.fget def safe_hard_time_limit_task(self): try: value = original_getter(self) - if isinstance(value, int | float): - # Use the original value if it's valid and greater than 0 - # Otherwise fall back to default_value - if value > 0: - return int(value) + if isinstance(value, int | float) and value > 0: + return int(value) except AttributeError | TypeError: pass return default_value diff --git a/apps/worker/tasks/upload.py b/apps/worker/tasks/upload.py index 92a70b8693..0a8fc74753 100644 --- a/apps/worker/tasks/upload.py +++ b/apps/worker/tasks/upload.py @@ -530,7 +530,9 @@ def run_impl_within_lock( try: log.info("Initializing and saving report", extra=upload_context.log_extra()) - commit_report = report_service.initialize_and_save_report(commit) + commit_report = report_service.initialize_and_save_report( + commit, db_session=db_session + ) except NotReadyToBuildReportYetError: log.warning( "Commit not yet ready to build its initial report. Retrying in 60s.", @@ -692,7 +694,9 @@ def _possibly_create_uploads_to_insert( normalize_flags(arguments) if "upload_id" not in arguments: - upload = report_service.create_report_upload(arguments, commit_report) + upload = report_service.create_report_upload( + arguments, commit_report, db_session=db_session + ) arguments["upload_id"] = upload.id_ # Adds objects to insert later in bulk upload_flag_map[upload] = arguments.get("flags", []) diff --git a/apps/worker/tasks/upload_finisher.py b/apps/worker/tasks/upload_finisher.py index 0da4dc2094..304e2701ef 100644 --- a/apps/worker/tasks/upload_finisher.py +++ b/apps/worker/tasks/upload_finisher.py @@ -445,7 +445,7 @@ def _process_reports_with_lock( report.apply_diff(diff) log.info("run_impl: Saving report") - report_service.save_report(commit, report) + report_service.save_report(commit, report, db_session=db_session) db_session.commit() diff --git a/apps/worker/tasks/upsert_component.py b/apps/worker/tasks/upsert_component.py index 828681a5a9..a3c3db83f6 100644 --- a/apps/worker/tasks/upsert_component.py +++ b/apps/worker/tasks/upsert_component.py @@ -34,6 +34,16 @@ def run_impl( .first() ) + if commit is None: + raise ValueError( + f"Commit not found: repoid={repoid}, commitid={commitid}. " + "This may indicate a session visibility issue." 
diff --git a/apps/worker/tests/unit/test_task_router.py b/apps/worker/tests/unit/test_task_router.py
index e17e60092e..9723637dfa 100644
--- a/apps/worker/tests/unit/test_task_router.py
+++ b/apps/worker/tests/unit/test_task_router.py
@@ -289,11 +289,11 @@ def test_get_ownerid_from_task(


 def test_route_task(mocker, dbsession, fake_repos):
-    mock_get_db_session = mocker.patch("celery_task_router.get_db_session")
+    mock_create_task_session = mocker.patch("celery_task_router.create_task_session")
     mock_route_tasks_shared = mocker.patch(
         "celery_task_router.route_tasks_based_on_user_plan"
     )
-    mock_get_db_session.return_value = dbsession
+    mock_create_task_session.return_value = dbsession
     mock_route_tasks_shared.return_value = {"queue": "correct queue"}
     repo = fake_repos[0]
     task_kwargs = {
@@ -304,7 +304,7 @@ def test_route_task(mocker, dbsession, fake_repos):
     }
     response = route_task(shared_celery_config.upload_task_name, [], task_kwargs, {})
     assert response == {"queue": "correct queue"}
-    mock_get_db_session.assert_called()
+    mock_create_task_session.assert_called()
     mock_route_tasks_shared.assert_called_with(
         shared_celery_config.upload_task_name,
         PlanName.CODECOV_PRO_MONTHLY.value,