diff --git a/.gitignore b/.gitignore index afb5191aa..30406afe6 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,6 @@ burr/tracking/server/build examples/*/statemachine examples/*/*/statemachine .vscode + +# Agent-authored planning artifacts (never commit) +docs/superpowers/ diff --git a/burr/core/__init__.py b/burr/core/__init__.py index c4da5a48e..abf1c779d 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -22,7 +22,9 @@ ApplicationContext, ApplicationGraph, ) +from burr.core.durable import DeterminismError, JournalEntry, SuspensionRecord from burr.core.graph import Graph, GraphBuilder +from burr.core.resume import aresume, resume from burr.core.state import State __all__ = [ @@ -32,11 +34,16 @@ "ApplicationBuilder", "ApplicationGraph", "ApplicationContext", + "aresume", "Condition", "default", + "DeterminismError", "expr", + "JournalEntry", "Result", + "resume", "State", + "SuspensionRecord", "when", "Graph", "GraphBuilder", diff --git a/burr/core/application.py b/burr/core/application.py index dc8067c4b..92718a4db 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -60,6 +60,7 @@ StreamingAction, StreamingResultContainer, ) +from burr.core.durable import _Suspended from burr.core.graph import Graph, GraphBuilder from burr.core.persistence import ( AsyncBaseStateLoader, @@ -584,6 +585,125 @@ def my_action(state: State, __context: ApplicationContext) -> State: state_persister: Optional[BaseStateSaver] action_name: Optional[str] # Store just the action name + # --- Durable execution runtime state (populated per-step) --- + _resume_signals: Dict[str, Any] = dataclasses.field(default_factory=dict) + _loaded_journal: list = dataclasses.field(default_factory=list) + _journal_sink: list = dataclasses.field(default_factory=list) + _journal_call_index: int = 0 + + def suspend( + self, + channel: str, + *, + schema: Optional[type] = None, + metadata: Optional[dict] = None, + ) -> Any: + """Pause this run until an external event 
arrives on ``channel``. + + First execution: raises the internal suspend signal; the run loop + persists a suspension record and stops. On resume: returns the payload + delivered to that channel. When ``schema`` is supplied and the payload + is a dict, the payload is coerced into the schema type via + ``schema(**payload)``; a non-dict payload is returned unchanged. + + Pure control flow, no IO. Do not call from inside a ``durable()`` fn. + """ + if channel in self._resume_signals: + payload = self._resume_signals[channel] + if schema is not None and isinstance(payload, dict): + return schema(**payload) + return payload + schema_json = ( + schema.model_json_schema() + if schema is not None and hasattr(schema, "model_json_schema") + else None + ) + from burr.core.durable import _Suspended + + raise _Suspended(channel, schema_json, metadata) + + def durable(self, key: str, fn: Callable, *args, **kwargs) -> Any: + """Memoize a sub-step. First run: execute ``fn`` and journal its result. + On resume: replay the journaled result without executing ``fn`` again. + + ``key`` must be stable and called in the same order across re-runs of + the same action invocation (see the determinism contract). Do not call + ``suspend()`` from inside ``fn``, and do not wrap ``durable()`` in a + try/except that varies the key on the exception branch: the call index + is consumed even when ``fn`` raises, so a retry with a different key + triggers a ``DeterminismError``. + """ + from burr.core.durable import DeterminismError, JournalEntry, supports_durable_storage + + idx = self._journal_call_index + self._journal_call_index += 1 + + if idx < len(self._loaded_journal): + recorded = self._loaded_journal[idx] + if recorded.step_key != key: + raise DeterminismError( + f"Durable sub-step #{idx} replayed as key {key!r} but the " + f"journal recorded key {recorded.step_key!r}. The action's " + f"durable() calls must occur in the same order with the same " + f"keys on every re-run." 
+ ) + return recorded.result + + result = fn(*args, **kwargs) + entry = JournalEntry( + partition_key=self.partition_key, + app_id=self.app_id, + sequence_id=self.sequence_id, + step_key=key, + call_index=idx, + result=result, + ) + self._journal_sink.append(entry) + if self.state_persister is not None and supports_durable_storage( + self.state_persister + ): + # First-party storage: persist immediately for crash resilience. + self.state_persister.save_journal_entry(entry) + return result + + async def adurable(self, key: str, fn: Callable, *args, **kwargs) -> Any: + """Async variant of durable(): ``fn`` is a coroutine function.""" + from burr.core.durable import ( + DeterminismError, + JournalEntry, + supports_durable_storage, + ) + + idx = self._journal_call_index + self._journal_call_index += 1 + + if idx < len(self._loaded_journal): + recorded = self._loaded_journal[idx] + if recorded.step_key != key: + raise DeterminismError( + f"Durable sub-step #{idx} replayed as key {key!r} but the " + f"journal recorded key {recorded.step_key!r}." + ) + return recorded.result + + result = await fn(*args, **kwargs) + entry = JournalEntry( + partition_key=self.partition_key, + app_id=self.app_id, + sequence_id=self.sequence_id, + step_key=key, + call_index=idx, + result=result, + ) + self._journal_sink.append(entry) + if self.state_persister is not None and supports_durable_storage( + self.state_persister + ): + saver = self.state_persister.save_journal_entry(entry) + if self.state_persister.is_async(): + await saver + return result + @staticmethod def get() -> Optional["ApplicationContext"]: """Provides the context-local application context. 
@@ -841,6 +961,10 @@ def __init__( self._spawning_parent_pointer = spawning_parent_pointer self._state_initializer = state_initializer self._state_persister = state_persister + self._resume_signals: Dict[str, Any] = {} + self._loaded_journal: list = [] + self._journal_sink: list = [] + self._suspended = None self._adapter_set.call_all_lifecycle_hooks_sync( "post_application_create", state=self._state, @@ -870,6 +994,8 @@ def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action # we need to increment the sequence before we start computing # that way if we're replaying from state, we don't get stuck self.validate_correct_async_use() + self._journal_sink = [] + self._suspended = None self._increment_sequence_id() out = self._step(inputs=inputs, _run_hooks=True) return out @@ -885,6 +1011,10 @@ def _context_factory(self, action: Action, sequence_id: int) -> ApplicationConte state_initializer=self._state_initializer, state_persister=self._state_persister, action_name=action.name if action else None, # Pass just the action name + # _journal_call_index is intentionally not forwarded; it starts at 0 per action context. 
+ _resume_signals=self._resume_signals, + _loaded_journal=self._loaded_journal, + _journal_sink=self._journal_sink, ) def _step( @@ -912,6 +1042,7 @@ def _step( exc = None result = None new_state = self._state + suspended_signal = None try: if next_action.single_step: result, new_state = _run_single_step_action( @@ -924,25 +1055,164 @@ def _step( new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) + if self._journal_sink and self._state_persister is not None: + from burr.core.durable import ( + read_journal_from_state, + supports_durable_storage, + write_journal_into_state, + ) + + if not supports_durable_storage(self._state_persister): + merged = read_journal_from_state(self._state) + self._journal_sink + new_state = write_journal_into_state(new_state, merged) self._set_state(new_state) + except _Suspended as suspended: + suspended_signal = suspended + try: + self._handle_suspension(next_action, action_inputs, suspended) + except Exception as handler_exc: + exc = handler_exc + suspended_signal = None + logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) + raise except Exception as e: exc = e logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) raise e finally: if _run_hooks: - self._adapter_set.call_all_lifecycle_hooks_sync( - "post_run_step", - app_id=self._uid, - partition_key=self._partition_key, - action=next_action, - state=new_state, - result=result, - sequence_id=self.sequence_id, - exception=exc, - ) + if suspended_signal is not None: + self._adapter_set.call_all_lifecycle_hooks_sync( + "post_run_step", + app_id=self._uid, + partition_key=self._partition_key, + action=next_action, + state=self._state, + result=None, + sequence_id=self.sequence_id, + exception=None, + ) + else: + self._adapter_set.call_all_lifecycle_hooks_sync( + "post_run_step", + app_id=self._uid, + partition_key=self._partition_key, + 
action=next_action, + state=new_state, + result=result, + sequence_id=self.sequence_id, + exception=exc, + ) + if suspended_signal is not None: + return next_action, None, self._state return next_action, result, new_state + def _handle_suspension(self, action, action_inputs, suspended): + """Build and persist a SuspensionRecord, then set self._suspended.""" + from burr.core.durable import ( + SuspensionRecord, + supports_durable_storage, + write_journal_into_state, + write_suspension_into_state, + ) + + record = SuspensionRecord( + suspension_id=str(uuid.uuid4()), + partition_key=self._partition_key, + app_id=self._uid, + sequence_id=self.sequence_id, + position=action.name, + channel=suspended.channel, + schema_json=suspended.schema_json, + metadata=suspended.metadata, + inputs=action_inputs, + state=dict(self._state.get_all()), + created_at=system.now().isoformat(), + resolved=False, + ) + persister = self._state_persister + if persister is not None and supports_durable_storage(persister): + persister.save_suspension(record) + for entry in self._journal_sink: + persister.save_journal_entry(entry) + elif persister is not None: + # In-state fallback: embed the record + journal into State only. + # We deliberately do NOT call persister.save here. The post_run_step + # lifecycle hook fires for this suspended step and PersisterHook.save + # persists the embedded State once. Saving here too would write the + # same (partition_key, app_id, sequence_id, position) row twice and + # break persisters with a UNIQUE constraint (e.g. SQLitePersister). 
+ state = write_suspension_into_state(self._state, record) + state = write_journal_into_state(state, self._journal_sink) + self._set_state(state) + self._adapter_set.call_all_lifecycle_hooks_sync( + "post_action_suspend", + app_id=self._uid, + partition_key=self._partition_key, + action=action, + sequence_id=self.sequence_id, + suspension=record, + ) + self._suspended = record + + async def _ahandle_suspension(self, action, action_inputs, suspended): + """Async sibling of _handle_suspension. Called from _astep when the action suspends.""" + from burr.core.durable import ( + SuspensionRecord, + supports_durable_storage, + write_journal_into_state, + write_suspension_into_state, + ) + + record = SuspensionRecord( + suspension_id=str(uuid.uuid4()), + partition_key=self._partition_key, + app_id=self._uid, + sequence_id=self.sequence_id, + position=action.name, + channel=suspended.channel, + schema_json=suspended.schema_json, + metadata=suspended.metadata, + inputs=action_inputs, + state=dict(self._state.get_all()), + created_at=system.now().isoformat(), + resolved=False, + ) + persister = self._state_persister + if persister is not None and supports_durable_storage(persister): + if persister.is_async(): + await persister.save_suspension(record) + for entry in self._journal_sink: + await persister.save_journal_entry(entry) + else: + persister.save_suspension(record) + for entry in self._journal_sink: + persister.save_journal_entry(entry) + elif persister is not None: + # In-state fallback: embed the record + journal into State only. + # We deliberately do NOT call persister.save here. The post_run_step + # lifecycle hook fires for this suspended step and PersisterHook.save + # persists the embedded State once. Saving here too would write the + # same (partition_key, app_id, sequence_id, position) row twice and + # break persisters with a UNIQUE constraint (e.g. SQLitePersister). 
+ state = write_suspension_into_state(self._state, record) + state = write_journal_into_state(state, self._journal_sink) + self._set_state(state) + await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( + "post_action_suspend", + app_id=self._uid, + partition_key=self._partition_key, + action=action, + sequence_id=self.sequence_id, + suspension=record, + ) + self._suspended = record + + @property + def suspended(self): + """The SuspensionRecord if the last run() suspended, else None.""" + return self._suspended + def reset_to_entrypoint(self) -> None: """Resets the state machine to the entrypoint action -- you probably want to consider having a loop in your graph, but this will do the trick if you need it!""" @@ -1025,6 +1295,8 @@ async def astep( :return: Tuple[Function, dict, State] -- the action that was just ran, the result of running it, and the new state """ + self._journal_sink = [] + self._suspended = None self._increment_sequence_id() out = await self._astep(inputs=inputs, _run_hooks=True) return out @@ -1050,6 +1322,7 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True exc = None result = None new_state = self._state + suspended_signal = None try: if not next_action.is_async(): # we can just delegate to the synchronous version, it will block the event loop, @@ -1076,24 +1349,56 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True ) new_state = _run_reducer(next_action, self._state, result, next_action.name) new_state = self._update_internal_state_value(new_state, next_action) + if self._journal_sink and self._state_persister is not None: + from burr.core.durable import ( + read_journal_from_state, + supports_durable_storage, + write_journal_into_state, + ) + + if not supports_durable_storage(self._state_persister): + merged = read_journal_from_state(self._state) + self._journal_sink + new_state = write_journal_into_state(new_state, merged) self._set_state(new_state) + except 
_Suspended as suspended: + suspended_signal = suspended + try: + await self._ahandle_suspension(next_action, action_inputs, suspended) + except Exception as handler_exc: + exc = handler_exc + suspended_signal = None + logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) + raise except Exception as e: exc = e logger.exception(_format_BASE_ERROR_MESSAGE(next_action, self._state, inputs)) raise e finally: if _run_hooks: - await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( - "post_run_step", - action=next_action, - state=new_state, - result=result, - sequence_id=self.sequence_id, - exception=exc, - app_id=self._uid, - partition_key=self._partition_key, - ) - + if suspended_signal is not None: + await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( + "post_run_step", + action=next_action, + state=self._state, + result=None, + sequence_id=self.sequence_id, + exception=None, + app_id=self._uid, + partition_key=self._partition_key, + ) + else: + await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( + "post_run_step", + action=next_action, + state=self._state, + result=result, + sequence_id=self.sequence_id, + exception=exc, + app_id=self._uid, + partition_key=self._partition_key, + ) + if suspended_signal is not None: + return next_action, None, self._state return next_action, result, new_state def _parse_action_list(self, action_list: list[str]) -> Tuple[List[str], List[str]]: @@ -1237,7 +1542,7 @@ def iterate( result = None prior_action: Optional[Action] = None - while self.has_next_action(): + while self.has_next_action() and self._suspended is None: # self.step will only return None if there is no next action, so we can rely on tuple unpacking prior_action, result, state = self.step(inputs=inputs) yield prior_action, result, state @@ -1270,7 +1575,7 @@ async def aiterate( halt_before, halt_after, inputs ) self._validate_halt_conditions(halt_before, halt_after) - while self.has_next_action(): + while 
self.has_next_action() and self._suspended is None: # self.step will only return None if there is no next action, so we can rely on tuple unpacking prior_action, result, state = await self.astep(inputs=inputs) yield prior_action, result, state @@ -1521,6 +1826,7 @@ def callback( if not next_action.streaming: # In this case we are halting at a non-streaming condition # This is allowed as we want to maintain a more consistent API + self._journal_sink = [] action, result, state = self._step(inputs=inputs, _run_hooks=False) self._adapter_set.call_all_lifecycle_hooks_sync( "post_run_step", @@ -1773,6 +2079,7 @@ async def callback( if not next_action.streaming: # In this case we are halting at a non-streaming condition # This is allowed as we want to maintain a more consistent API + self._journal_sink = [] action, result, state = await self._astep(inputs=inputs, _run_hooks=False) await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( "post_run_step", diff --git a/burr/core/durable.py b/burr/core/durable.py new file mode 100644 index 000000000..adedbb823 --- /dev/null +++ b/burr/core/durable.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Primitives for durable execution: the suspend control-flow signal, the +determinism error, and the records persisted to support resume.""" + +import dataclasses +from typing import Any, Dict, Optional + + +class _Suspended(BaseException): + """Internal control-flow signal raised by ``ApplicationContext.suspend()``. + + Subclasses ``BaseException`` (not ``Exception``) on purpose: a user + ``try/except Exception`` wrapping an LLM/IO call inside an action must NOT + swallow it. The run loop catches it explicitly. It is never an error and is + never logged or persisted as a failure. + """ + + def __init__( + self, + channel: str, + schema_json: Optional[dict] = None, + metadata: Optional[dict] = None, + ): + self.channel = channel + self.schema_json = schema_json + self.metadata = metadata + super().__init__(f"Execution suspended on channel '{channel}'") + + +class DeterminismError(Exception): + """Raised on resume when ``ctx.durable()`` calls do not replay in the same + order, or with the same keys, as the recorded journal. This converts a + silent footgun (lost re-execution or stale cache) into a loud failure.""" + + +@dataclasses.dataclass +class SuspensionRecord: + """Everything needed to resume a suspended run. Persisted when an action + calls ``suspend()``. ``metadata``, ``inputs`` and ``state`` are serialized + through ``burr.core.serde``.""" + + suspension_id: str + partition_key: Optional[str] + app_id: str + sequence_id: int + position: str # name of the suspended action + channel: str + schema_json: Optional[dict] + metadata: Optional[dict] + inputs: Dict[str, Any] + state: Dict[str, Any] # entry state of the suspended action + created_at: str + resolved: bool = False + + +@dataclasses.dataclass +class JournalEntry: + """One memoized ``ctx.durable()`` sub-step. ``result`` is serialized through + ``burr.core.serde``. + + .. 
note:: + The in-state codec (``write_journal_into_state`` / ``read_journal_from_state``) + serializes entries via ``dataclasses.asdict``, which recurses into nested + dataclasses and converts them to plain dicts. Those dicts are NOT + reconstructed back into their original types on read. Callers must keep + ``result`` (and any nested fields) to plain JSON-friendly types, or accept + that nested dataclasses come back as plain dicts after a round-trip through + the in-state codec. + """ + + partition_key: Optional[str] + app_id: str + sequence_id: int + step_key: str + call_index: int + result: Any + + +def supports_durable_storage(persister) -> bool: + """True if the persister overrides the durable-storage methods. When False, + the Application stores suspensions and journal entries inside the State. + + .. note:: + All-or-nothing override contract: a persister is considered to support + durable storage only when it overrides ALL five durable-storage methods + (``save_suspension``, ``load_suspension``, ``save_journal_entry``, + ``load_journal``, ``mark_suspension_resolved``). Detection is based solely + on ``save_suspension``; partial overrides are not detected and will raise + ``NotImplementedError`` at call time. + """ + from burr.core.persistence import ( + AsyncBaseStatePersister, + BaseStatePersister, + ) + + base = AsyncBaseStatePersister if persister.is_async() else BaseStatePersister + return type(persister).save_suspension is not base.save_suspension + + +# --- In-state fallback codec -------------------------------------------------- +# When the persister has no dedicated storage, suspensions and journal entries +# ride inside a reserved State namespace, which the existing PersisterHook saves. 
+ +DURABLE_STATE_KEY = "__burr_durable__" + + +def write_suspension_into_state(state, record: "SuspensionRecord"): + """Return a new State with the suspension record embedded.""" + bucket = dict(state.get(DURABLE_STATE_KEY, {}) or {}) + bucket["suspension"] = dataclasses.asdict(record) + return state.update(**{DURABLE_STATE_KEY: bucket}) + + +def read_suspension_from_state(state, channel: str) -> "Optional[SuspensionRecord]": + bucket = state.get(DURABLE_STATE_KEY, {}) or {} + raw = bucket.get("suspension") + if raw is None or raw.get("channel") != channel: + return None + return SuspensionRecord(**raw) + + +def write_journal_into_state(state, entries: "list"): + """Return a new State with the journal entries embedded. + + .. warning:: + Serializes via ``dataclasses.asdict``, which recursively converts nested + dataclasses to plain dicts. They are NOT reconstructed to their original + types when read back via ``read_journal_from_state``. Keep ``JournalEntry.result`` + and any nested fields as plain JSON-friendly types to avoid type loss. 
+ """ + bucket = dict(state.get(DURABLE_STATE_KEY, {}) or {}) + bucket["journal"] = [dataclasses.asdict(e) for e in entries] + return state.update(**{DURABLE_STATE_KEY: bucket}) + + +def read_journal_from_state(state) -> "list": + bucket = state.get(DURABLE_STATE_KEY, {}) or {} + return [JournalEntry(**raw) for raw in bucket.get("journal", [])] diff --git a/burr/core/persistence.py b/burr/core/persistence.py index c32bf8e96..4dfd311d9 100644 --- a/burr/core/persistence.py +++ b/burr/core/persistence.py @@ -24,7 +24,8 @@ from typing import Any, Dict, Literal, Optional, TypedDict from burr.common.types import BaseCopyable -from burr.core import Action +from burr.core import Action, serde +from burr.core.durable import JournalEntry, SuspensionRecord from burr.core.state import State, logger from burr.lifecycle import PostRunStepHook, PostRunStepHookAsync @@ -130,7 +131,7 @@ def save( sequence_id: int, position: str, state: State, - status: Literal["completed", "failed"], + status: Literal["completed", "failed", "suspended"], **kwargs, ): """Saves the state for a given app_id, sequence_id, position @@ -145,8 +146,8 @@ def save( :param sequence_id: Sequence ID of the last executed step :param position: The action name that was implemented :param state: The current state of the application - :param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was - before the action was applied. + :param status: The status of this state: "completed", "failed", or "suspended". If "failed" the state is what + it was before the action was applied. 
""" pass @@ -175,7 +176,7 @@ async def save( sequence_id: int, position: str, state: State, - status: Literal["completed", "failed"], + status: Literal["completed", "failed", "suspended"], **kwargs, ): """Saves the state for a given app_id, sequence_id, position @@ -190,8 +191,8 @@ async def save( :param sequence_id: Sequence ID of the last executed step :param position: The action name that was implemented :param state: The current state of the application - :param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was - before the action was applied. + :param status: The status of this state: "completed", "failed", or "suspended". If "failed" the state is what + it was before the action was applied. """ pass @@ -204,12 +205,85 @@ class BaseStatePersister(BaseStateLoader, BaseStateSaver, metaclass=ABCMeta): Extend this class if you want an easy way to implement custom state storage. """ + # --- Durable execution: optional. Default raises; the Application falls + # --- back to storing this data inside the State (see burr.core.durable). + + def save_suspension(self, record: SuspensionRecord) -> None: + """Persist a suspension record. Override for dedicated storage.""" + raise NotImplementedError + + def load_suspension( + self, partition_key: Optional[str], app_id: str, channel: str + ) -> Optional[SuspensionRecord]: + """Load the suspension record for (partition_key, app_id, channel). + + Returns the record whether or not it is resolved; callers check + ``record.resolved`` for resume-once idempotency. Returns ``None`` + when no record exists for this combination. + """ + raise NotImplementedError + + def save_journal_entry(self, entry: JournalEntry) -> None: + """Persist one memoized sub-step. 
Override for dedicated storage.""" + raise NotImplementedError + + def load_journal( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> list[JournalEntry]: + """Load journal entries for a suspended action, ordered by call_index.""" + raise NotImplementedError + + def mark_suspension_resolved(self, suspension_id: str) -> bool: + """Marks the suspension as resolved. + + Returns True if a previously-unresolved row was flipped, False otherwise + (already resolved, or unknown id). Callers use this for resume-once idempotency. + """ + raise NotImplementedError + class AsyncBaseStatePersister(AsyncBaseStateLoader, AsyncBaseStateSaver, metaclass=ABCMeta): """Utility interface for an asynchronous state reader/writer. This both persists and initializes state. Extend this class if you want an easy way to implement custom state storage. """ + # --- Durable execution: optional. Default raises; the Application falls + # --- back to storing this data inside the State (see burr.core.durable). + + async def save_suspension(self, record: SuspensionRecord) -> None: + """Persist a suspension record. Override for dedicated storage.""" + raise NotImplementedError + + async def load_suspension( + self, partition_key: Optional[str], app_id: str, channel: str + ) -> Optional[SuspensionRecord]: + """Load the suspension record for (partition_key, app_id, channel). + + Returns the record whether or not it is resolved; callers check + ``record.resolved`` for resume-once idempotency. Returns ``None`` + when no record exists for this combination. + """ + raise NotImplementedError + + async def save_journal_entry(self, entry: JournalEntry) -> None: + """Persist one memoized sub-step. 
Override for dedicated storage.""" + raise NotImplementedError + + async def load_journal( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> list[JournalEntry]: + """Load journal entries for a suspended action, ordered by call_index.""" + raise NotImplementedError + + async def mark_suspension_resolved(self, suspension_id: str) -> bool: + """Marks the suspension as resolved. + + Returns True if a previously-unresolved row was flipped, False otherwise + (already resolved, or unknown id). Callers use this for resume-once idempotency. + First-party SQL persisters do this with a conditional UPDATE. + """ + raise NotImplementedError + class PersisterHook(PostRunStepHook): """Wrapper class for bridging the persistence interface with lifecycle hooks. This is used internally.""" @@ -279,7 +353,7 @@ def save( sequence_id: int, position: str, state: State, - status: Literal["completed", "failed"], + status: Literal["completed", "failed", "suspended"], **kwargs, ): return @@ -303,7 +377,7 @@ async def save( sequence_id: int, position: str, state: State, - status: Literal["completed", "failed"], + status: Literal["completed", "failed", "suspended"], **kwargs, ): return @@ -429,6 +503,7 @@ def initialize(self): """Creates the table if it doesn't exist""" # Usage self.create_table_if_not_exists(self.table_name) + self.create_durable_tables_if_not_exist() self._initialized = True def is_initialized(self) -> bool: @@ -538,7 +613,7 @@ def save( sequence_id: int, position: str, state: State, - status: Literal["completed", "failed"], + status: Literal["completed", "failed", "suspended"], **kwargs, ): """ @@ -554,8 +629,8 @@ def save( :param sequence_id: The state corresponding to a specific point in time. :param position: The position in the sequence of states. :param state: The state to be saved, an instance of the State class. - :param status: The status of this state, either "completed" or "failed". 
def create_durable_tables_if_not_exist(self):
    """Create the durable-execution tables (suspensions + journal) if they don't exist."""
    cursor = self.connection.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS burr_suspensions (
            suspension_id TEXT PRIMARY KEY,
            partition_key TEXT,
            app_id TEXT NOT NULL,
            sequence_id INTEGER NOT NULL,
            position TEXT NOT NULL,
            channel TEXT NOT NULL,
            schema_json TEXT,
            metadata_json TEXT,
            inputs_json TEXT,
            state_json TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            resolved INTEGER NOT NULL DEFAULT 0
        )"""
    )
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS burr_journal (
            partition_key TEXT,
            app_id TEXT NOT NULL,
            sequence_id INTEGER NOT NULL,
            step_key TEXT NOT NULL,
            call_index INTEGER NOT NULL,
            result_json TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (partition_key, app_id, sequence_id, step_key)
        )"""
    )
    self.connection.commit()


def save_suspension(self, record: SuspensionRecord) -> None:
    """Insert or replace ``record`` in the ``burr_suspensions`` table.

    ``schema_json`` is stored as plain JSON; ``metadata``, ``inputs`` and
    ``state`` go through ``serde.serialize`` before being JSON-encoded.

    :param record: The suspension record to persist.
    """
    row = (
        record.suspension_id,
        record.partition_key,
        record.app_id,
        record.sequence_id,
        record.position,
        record.channel,
        json.dumps(record.schema_json),
        json.dumps(serde.serialize(record.metadata, **self.serde_kwargs)),
        json.dumps(serde.serialize(record.inputs, **self.serde_kwargs)),
        json.dumps(serde.serialize(record.state, **self.serde_kwargs)),
        record.created_at,
        1 if record.resolved else 0,
    )
    cursor = self.connection.cursor()
    cursor.execute(
        """INSERT OR REPLACE INTO burr_suspensions
        (suspension_id, partition_key, app_id, sequence_id, position,
         channel, schema_json, metadata_json, inputs_json, state_json,
         created_at, resolved)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        row,
    )
    self.connection.commit()


def load_suspension(
    self, partition_key: Optional[str], app_id: str, channel: str
) -> Optional[SuspensionRecord]:
    """Load the most recent suspension record for (partition_key, app_id, channel).

    The record is returned whether or not it is resolved; callers check
    ``record.resolved`` for resume-once idempotency.

    :return: The newest matching record, or ``None`` when there is none.
    """
    cursor = self.connection.cursor()
    # ``IS`` matches a NULL partition_key; ``rowid`` breaks ties between rows
    # that share the same created_at value so the result is deterministic.
    cursor.execute(
        """SELECT suspension_id, partition_key, app_id, sequence_id, position,
                  channel, schema_json, metadata_json, inputs_json, state_json,
                  created_at, resolved
           FROM burr_suspensions
           WHERE partition_key IS ? AND app_id = ? AND channel = ?
           ORDER BY created_at DESC, rowid DESC LIMIT 1""",
        (partition_key, app_id, channel),
    )
    row = cursor.fetchone()
    if row is None:
        return None
    schema_raw, metadata_raw, inputs_raw, state_raw = row[6:10]
    return SuspensionRecord(
        suspension_id=row[0],
        partition_key=row[1],
        app_id=row[2],
        sequence_id=row[3],
        position=row[4],
        channel=row[5],
        schema_json=None if schema_raw is None else json.loads(schema_raw),
        metadata=(
            None
            if metadata_raw is None
            else serde.deserialize(json.loads(metadata_raw), **self.serde_kwargs)
        ),
        inputs=serde.deserialize(json.loads(inputs_raw), **self.serde_kwargs),
        state=serde.deserialize(json.loads(state_raw), **self.serde_kwargs),
        created_at=row[10],
        resolved=bool(row[11]),
    )


def mark_suspension_resolved(self, suspension_id: str) -> bool:
    """Mark a suspension consumed via a conditional UPDATE (resume-once idempotency).

    :return: True if a row was updated (first call), False if the suspension
        was already resolved or does not exist.
    """
    cursor = self.connection.cursor()
    cursor.execute(
        "UPDATE burr_suspensions SET resolved = 1 "
        "WHERE suspension_id = ? AND resolved = 0",
        (suspension_id,),
    )
    self.connection.commit()
    return cursor.rowcount > 0


def save_journal_entry(self, entry: JournalEntry) -> None:
    """Persist one memoized sub-step into the ``burr_journal`` table.

    An existing row with the same (partition_key, app_id, sequence_id,
    step_key) is replaced.
    """
    identity = (entry.partition_key, entry.app_id, entry.sequence_id, entry.step_key)
    cursor = self.connection.cursor()
    # SQLite treats NULLs as distinct when enforcing this table's PRIMARY KEY,
    # so INSERT OR REPLACE alone never replaces a row whose partition_key is
    # NULL. Delete with the NULL-safe ``IS`` operator first; both statements
    # are committed together below.
    cursor.execute(
        """DELETE FROM burr_journal
           WHERE partition_key IS ? AND app_id = ? AND sequence_id = ? AND step_key = ?""",
        identity,
    )
    cursor.execute(
        """INSERT OR REPLACE INTO burr_journal
        (partition_key, app_id, sequence_id, step_key, call_index, result_json)
        VALUES (?, ?, ?, ?, ?, ?)""",
        (
            *identity,
            entry.call_index,
            json.dumps(serde.serialize(entry.result, **self.serde_kwargs)),
        ),
    )
    self.connection.commit()


def load_journal(
    self, partition_key: Optional[str], app_id: str, sequence_id: int
) -> list[JournalEntry]:
    """Load journal entries for a suspended action, ordered by ``call_index``."""
    cursor = self.connection.cursor()
    cursor.execute(
        """SELECT partition_key, app_id, sequence_id, step_key, call_index,
                  result_json
           FROM burr_journal
           WHERE partition_key IS ? AND app_id = ? AND sequence_id = ?
           ORDER BY call_index ASC""",
        (partition_key, app_id, sequence_id),
    )
    entries = []
    for stored_pk, stored_app, stored_seq, step_key, call_index, result_raw in cursor.fetchall():
        entries.append(
            JournalEntry(
                partition_key=stored_pk,
                app_id=stored_app,
                sequence_id=stored_seq,
                step_key=step_key,
                call_index=call_index,
                result=serde.deserialize(json.loads(result_raw), **self.serde_kwargs),
            )
        )
    return entries
def save_suspension(self, record: SuspensionRecord) -> None:
    """Store ``record`` under its (partition_key, app_id, channel) key, replacing any previous one."""
    slot = (record.partition_key, record.app_id, record.channel)
    self._suspensions[slot] = record


def load_suspension(
    self, partition_key: Optional[str], app_id: str, channel: str
) -> Optional[SuspensionRecord]:
    """Return the record stored for (partition_key, app_id, channel), or ``None``."""
    return self._suspensions.get((partition_key, app_id, channel))


def mark_suspension_resolved(self, suspension_id: str) -> bool:
    """Flag the suspension with ``suspension_id`` as resolved.

    :return: True when the record existed and was unresolved; False when it
        was already resolved or no record has that id.
    """
    # Only the stored records matter here, so iterate values rather than items.
    found = next(
        (rec for rec in self._suspensions.values() if rec.suspension_id == suspension_id),
        None,
    )
    if found is None or found.resolved:
        return False
    found.resolved = True
    return True


def save_journal_entry(self, entry: JournalEntry) -> None:
    """Append ``entry`` to the journal bucket for its (partition_key, app_id, sequence_id)."""
    # NOTE(review): this appends unconditionally, so saving the same step_key
    # twice keeps both entries -- confirm that is the intended behavior.
    bucket_key = (entry.partition_key, entry.app_id, entry.sequence_id)
    self._journal.setdefault(bucket_key, []).append(entry)


def load_journal(
    self, partition_key: Optional[str], app_id: str, sequence_id: int
) -> list[JournalEntry]:
    """Return the journal entries for the given action invocation, sorted by ``call_index``."""
    stored = self._journal.get((partition_key, app_id, sequence_id), [])
    return sorted(stored, key=lambda item: item.call_index)
async def save_suspension(self, record: SuspensionRecord) -> None:
    """Store ``record`` under its (partition_key, app_id, channel) key, replacing any previous one."""
    slot = (record.partition_key, record.app_id, record.channel)
    self._suspensions[slot] = record


async def load_suspension(
    self, partition_key: Optional[str], app_id: str, channel: str
) -> Optional[SuspensionRecord]:
    """Return the record stored for (partition_key, app_id, channel), or ``None``."""
    return self._suspensions.get((partition_key, app_id, channel))


async def mark_suspension_resolved(self, suspension_id: str) -> bool:
    """Flag the suspension with ``suspension_id`` as resolved.

    :return: True when the record existed and was unresolved; False when it
        was already resolved or no record has that id.
    """
    # Only the stored records matter here, so iterate values rather than items.
    found = next(
        (rec for rec in self._suspensions.values() if rec.suspension_id == suspension_id),
        None,
    )
    if found is None or found.resolved:
        return False
    found.resolved = True
    return True


async def save_journal_entry(self, entry: JournalEntry) -> None:
    """Append ``entry`` to the journal bucket for its (partition_key, app_id, sequence_id)."""
    bucket_key = (entry.partition_key, entry.app_id, entry.sequence_id)
    self._journal.setdefault(bucket_key, []).append(entry)


async def load_journal(
    self, partition_key: Optional[str], app_id: str, sequence_id: int
) -> list[JournalEntry]:
    """Return the journal entries for the given action invocation, sorted by ``call_index``."""
    stored = self._journal.get((partition_key, app_id, sequence_id), [])
    return sorted(stored, key=lambda item: item.call_index)
def _load_suspension(persister, partition_key, app_id, channel):
    """Fetch the suspension record for ``channel``.

    Uses the persister's ``load_suspension`` when it supports durable storage;
    otherwise reads the record out of the last persisted state.

    :return: The suspension record, or ``None`` when nothing was found.
    """
    if supports_durable_storage(persister):
        return persister.load_suspension(partition_key, app_id, channel)
    loaded = persister.load(partition_key, app_id)
    if loaded is None:
        return None
    return read_suspension_from_state(loaded["state"], channel)


def _load_journal(persister, partition_key, app_id, sequence_id, state):
    """Fetch the journal from durable storage, or from ``state`` as a fallback."""
    if supports_durable_storage(persister):
        return persister.load_journal(partition_key, app_id, sequence_id)
    return read_journal_from_state(state)


async def _maybe_await(persister, result):
    """Await ``result`` when ``persister`` is async; return it unchanged otherwise."""
    if persister.is_async():
        return await result
    return result


async def _aload_suspension(persister, partition_key, app_id, channel):
    """Async counterpart of ``_load_suspension``; accepts sync or async persisters."""
    if supports_durable_storage(persister):
        return await _maybe_await(
            persister, persister.load_suspension(partition_key, app_id, channel)
        )
    loaded = await _maybe_await(persister, persister.load(partition_key, app_id))
    if loaded is None:
        return None
    return read_suspension_from_state(loaded["state"], channel)


async def _aload_journal(persister, partition_key, app_id, sequence_id, state):
    """Async counterpart of ``_load_journal``; accepts sync or async persisters."""
    if supports_durable_storage(persister):
        return await _maybe_await(
            persister, persister.load_journal(partition_key, app_id, sequence_id)
        )
    return read_journal_from_state(state)


def _make_builder(persister, graph, app_id, partition_key, record, hooks):
    """Configure an ``ApplicationBuilder`` positioned at the suspended action."""
    # Imported lazily, as in the original helpers.
    from burr.core.application import ApplicationBuilder

    builder = (
        ApplicationBuilder()
        .with_graph(graph)
        .with_identifiers(app_id=app_id, partition_key=partition_key)
        .with_entrypoint(record.position)
        .with_state(State(record.state))
        .with_state_persister(persister)
    )
    for hook in hooks or []:
        builder = builder.with_hooks(hook)
    return builder


async def _arebuild(
    persister, graph, app_id, partition_key, record, hooks: Optional[List[LifecycleAdapter]] = None
):
    """Rebuild the application for an async resume.

    Uses ``abuild()`` for async persisters and ``build()`` for sync ones.
    """
    builder = _make_builder(persister, graph, app_id, partition_key, record, hooks)
    if persister.is_async():
        return await builder.abuild()
    return builder.build()


def _validate_payload(schema_json, payload):
    """Validate *payload* against *schema_json* using jsonschema.

    Schema validation requires the optional ``jsonschema`` package. When it is
    absent, validation is skipped and a warning is emitted. A ``None`` schema
    disables validation entirely.
    """
    if schema_json is None:
        return
    try:
        import jsonschema
    except ImportError:
        warnings.warn(
            "jsonschema is not installed; skipping resume payload schema validation. "
            "Install jsonschema to enable validation.",
            stacklevel=3,
        )
        return
    jsonschema.validate(instance=payload, schema=schema_json)


def _rebuild(
    persister, graph, app_id, partition_key, record, hooks: Optional[List[LifecycleAdapter]] = None
):
    """Rebuild the application synchronously, positioned at the suspended action."""
    return _make_builder(persister, graph, app_id, partition_key, record, hooks).build()
def resume(
    *,
    persister,
    graph,
    app_id: str,
    partition_key: Optional[str],
    channel: str,
    payload: Any,
    hooks: Optional[List[LifecycleAdapter]] = None,
):
    """Resume a suspended run by delivering ``payload`` to ``channel``.

    Reloads the suspension, rebuilds the Application from ``graph`` + ``persister``,
    re-runs the suspended action from the top (durable sub-steps replay from the
    journal, ``suspend(channel)`` returns ``payload``), and runs to the next halt,
    suspend, or completion.

    Idempotency: resuming an already-resolved suspension is an idempotent no-op for
    persisters with durable storage (those implementing ``save_suspension`` /
    ``load_suspension`` / ``mark_suspension_resolved``). For persisters without
    durable storage, the suspension lives in ``state['__burr_durable__']`` and is
    overwritten as the resumed run progresses; a second ``resume()`` call after the
    first completes raises ``ValueError``.

    :param hooks: Optional lifecycle adapters to register on the rebuilt application
        before firing ``pre_action_resume``. Pass any adapter implementing
        :class:`~burr.lifecycle.PreActionResumeHook` (or its async variant) here.
    :return: The application state after the resumed run, or the latest
        persisted state when the suspension was already resolved.
    :raises ValueError: When no suspension record is found.
    """
    record = _load_suspension(persister, partition_key, app_id, channel)
    if record is None:
        # The lookup is keyed on all three identifiers, so report all of them.
        raise ValueError(
            f"No suspension found for app_id={app_id!r}, "
            f"partition_key={partition_key!r}, channel={channel!r} "
            f"(never suspended, or already resolved on a persister without durable storage)."
        )
    if record.resolved:
        # Already consumed: return the latest persisted state without re-running.
        loaded = persister.load(partition_key, app_id)
        return loaded["state"] if loaded else State(record.state)

    _validate_payload(record.schema_json, payload)

    app = _rebuild(persister, graph, app_id, partition_key, record, hooks=hooks)
    app._resume_signals = {channel: payload}
    app._loaded_journal = _load_journal(
        persister, partition_key, app_id, record.sequence_id, record.state
    )
    app._suspended = None

    app._adapter_set.call_all_lifecycle_hooks_sync(
        "pre_action_resume",
        app_id=app_id,
        partition_key=partition_key,
        action=graph.get_action(record.position),
        sequence_id=record.sequence_id,
        channel=channel,
    )

    app.run(halt_after=[])  # run to completion or the next suspend

    # In-state fallback does not durably mark suspensions resolved; a second
    # resume will raise (see docstring).
    if supports_durable_storage(persister):
        persister.mark_suspension_resolved(record.suspension_id)

    return app.state
async def aresume(
    *,
    persister,
    graph,
    app_id: str,
    partition_key: Optional[str],
    channel: str,
    payload: Any,
    hooks: Optional[List[LifecycleAdapter]] = None,
):
    """Resume a suspended run by delivering ``payload`` to ``channel``.

    Runs the async action loop (``await app.arun(...)``), so async actions are
    fully supported. Both sync and async persisters are supported (durable or
    non-durable).

    :param persister: A state persister (sync or async, durable or non-durable).
    :param graph: The :class:`~burr.core.graph.Graph` to rebuild the application.
    :param app_id: Identifier of the application run to resume.
    :param partition_key: Partition key used when the run was persisted.
    :param channel: Name of the suspension channel to deliver ``payload`` to.
    :param payload: Value returned by ``suspend(channel)`` inside the action.
    :param hooks: Optional lifecycle adapters to register on the rebuilt application
        before firing ``pre_action_resume``. Pass any adapter implementing
        :class:`~burr.lifecycle.PreActionResumeHook` (or its async variant) here.
    :raises ValueError: When no suspension record is found.

    **Idempotency:**

    * *Durable persisters* -- resuming an already-resolved suspension is an
      idempotent no-op: the call returns the latest persisted state unchanged.
    * *Non-durable persisters* -- the suspension lives in
      ``state['__burr_durable__']`` and is overwritten as the resumed run
      progresses. A second ``aresume()`` call after the first completes raises
      :exc:`ValueError`.
    """
    persister_is_async = persister.is_async
    record = await _aload_suspension(persister, partition_key, app_id, channel)
    if record is None:
        # The lookup is keyed on all three identifiers, so report all of them.
        raise ValueError(
            f"No suspension found for app_id={app_id!r}, "
            f"partition_key={partition_key!r}, channel={channel!r} "
            f"(never suspended, or already resolved on a persister without durable storage)."
        )
    if record.resolved:
        # Already consumed: return the latest persisted state without re-running.
        if persister_is_async():
            loaded = await persister.load(partition_key, app_id)
        else:
            loaded = persister.load(partition_key, app_id)
        return loaded["state"] if loaded else State(record.state)

    _validate_payload(record.schema_json, payload)

    app = await _arebuild(persister, graph, app_id, partition_key, record, hooks=hooks)
    app._resume_signals = {channel: payload}
    app._loaded_journal = await _aload_journal(
        persister, partition_key, app_id, record.sequence_id, record.state
    )
    app._suspended = None

    await app._adapter_set.call_all_lifecycle_hooks_sync_and_async(
        "pre_action_resume",
        app_id=app_id,
        partition_key=partition_key,
        action=graph.get_action(record.position),
        sequence_id=record.sequence_id,
        channel=channel,
    )

    await app.arun(halt_after=[])

    # In-state fallback does not durably mark suspensions resolved; a second
    # resume will raise (see docstring).
    if supports_durable_storage(persister):
        if persister_is_async():
            await persister.mark_suspension_resolved(record.suspension_id)
        else:
            persister.mark_suspension_resolved(record.suspension_id)

    return app.state
async def create_durable_tables_if_not_exist(self):
    """Create the durable-execution tables (suspensions + journal) if they don't exist."""
    await self.connection.execute(
        """
        CREATE TABLE IF NOT EXISTS burr_suspensions (
            suspension_id TEXT PRIMARY KEY,
            partition_key TEXT,
            app_id TEXT NOT NULL,
            sequence_id INTEGER NOT NULL,
            position TEXT NOT NULL,
            channel TEXT NOT NULL,
            schema_json TEXT,
            metadata_json TEXT,
            inputs_json TEXT,
            state_json TEXT NOT NULL,
            created_at TEXT,
            resolved INTEGER NOT NULL DEFAULT 0
        )"""
    )
    await self.connection.execute(
        """
        CREATE TABLE IF NOT EXISTS burr_journal (
            partition_key TEXT,
            app_id TEXT NOT NULL,
            sequence_id INTEGER NOT NULL,
            step_key TEXT NOT NULL,
            call_index INTEGER NOT NULL,
            result_json TEXT NOT NULL,
            created_at TEXT,
            PRIMARY KEY (partition_key, app_id, sequence_id, step_key)
        )"""
    )
    await self.connection.commit()


async def initialize(self):
    """Create the state table and the durable-execution tables if they don't exist."""
    await self.create_table_if_not_exists(self.table_name)
    await self.create_durable_tables_if_not_exist()
    self._initialized = True


async def save_suspension(self, record: SuspensionRecord) -> None:
    """Insert or replace ``record`` in the ``burr_suspensions`` table.

    :param record: The suspension record to persist.
    """
    row = (
        record.suspension_id,
        record.partition_key,
        record.app_id,
        record.sequence_id,
        record.position,
        record.channel,
        json.dumps(record.schema_json),
        json.dumps(serde.serialize(record.metadata, **self.serde_kwargs)),
        json.dumps(serde.serialize(record.inputs, **self.serde_kwargs)),
        json.dumps(serde.serialize(record.state, **self.serde_kwargs)),
        record.created_at,
        1 if record.resolved else 0,
    )
    await self.connection.execute(
        """INSERT OR REPLACE INTO burr_suspensions
        (suspension_id, partition_key, app_id, sequence_id, position,
         channel, schema_json, metadata_json, inputs_json, state_json,
         created_at, resolved)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
        row,
    )
    await self.connection.commit()


async def load_suspension(
    self, partition_key: Optional[str], app_id: str, channel: str
) -> Optional[SuspensionRecord]:
    """Load the most recent suspension record for (partition_key, app_id, channel).

    The record is returned whether or not it is resolved; callers check
    ``record.resolved`` for resume-once idempotency.

    :return: The newest matching record, or ``None`` when there is none.
    """
    # ``IS`` matches a NULL partition_key; ``rowid`` breaks ties between rows
    # that share the same created_at value so the result is deterministic.
    cursor = await self.connection.execute(
        """SELECT suspension_id, partition_key, app_id, sequence_id, position,
                  channel, schema_json, metadata_json, inputs_json, state_json,
                  created_at, resolved
           FROM burr_suspensions
           WHERE partition_key IS ? AND app_id = ? AND channel = ?
           ORDER BY created_at DESC, rowid DESC LIMIT 1""",
        (partition_key, app_id, channel),
    )
    row = await cursor.fetchone()
    if row is None:
        return None
    schema_raw, metadata_raw, inputs_raw, state_raw = row[6:10]
    return SuspensionRecord(
        suspension_id=row[0],
        partition_key=row[1],
        app_id=row[2],
        sequence_id=row[3],
        position=row[4],
        channel=row[5],
        schema_json=None if schema_raw is None else json.loads(schema_raw),
        metadata=(
            None
            if metadata_raw is None
            else serde.deserialize(json.loads(metadata_raw), **self.serde_kwargs)
        ),
        inputs=serde.deserialize(json.loads(inputs_raw), **self.serde_kwargs),
        state=serde.deserialize(json.loads(state_raw), **self.serde_kwargs),
        created_at=row[10],
        resolved=bool(row[11]),
    )


async def mark_suspension_resolved(self, suspension_id: str) -> bool:
    """Mark a suspension consumed via a conditional UPDATE (resume-once idempotency).

    :return: True if a row was updated (first call), False if already resolved (no-op).
    """
    cursor = await self.connection.execute(
        "UPDATE burr_suspensions SET resolved = 1 "
        "WHERE suspension_id = ? AND resolved = 0",
        (suspension_id,),
    )
    await self.connection.commit()
    return cursor.rowcount > 0


async def save_journal_entry(self, entry: JournalEntry) -> None:
    """Persist one memoized sub-step into the ``burr_journal`` table.

    An existing row with the same (partition_key, app_id, sequence_id,
    step_key) is replaced.
    """
    identity = (entry.partition_key, entry.app_id, entry.sequence_id, entry.step_key)
    # SQLite treats NULLs as distinct when enforcing this table's PRIMARY KEY,
    # so INSERT OR REPLACE alone never replaces a row whose partition_key is
    # NULL. Delete with the NULL-safe ``IS`` operator first; both statements
    # are committed together below.
    await self.connection.execute(
        """DELETE FROM burr_journal
           WHERE partition_key IS ? AND app_id = ? AND sequence_id = ? AND step_key = ?""",
        identity,
    )
    await self.connection.execute(
        """INSERT OR REPLACE INTO burr_journal
        (partition_key, app_id, sequence_id, step_key, call_index, result_json)
        VALUES (?, ?, ?, ?, ?, ?)""",
        (
            *identity,
            entry.call_index,
            json.dumps(serde.serialize(entry.result, **self.serde_kwargs)),
        ),
    )
    await self.connection.commit()


async def load_journal(
    self, partition_key: Optional[str], app_id: str, sequence_id: int
) -> list[JournalEntry]:
    """Load journal entries for a suspended action, ordered by ``call_index``."""
    cursor = await self.connection.execute(
        """SELECT partition_key, app_id, sequence_id, step_key, call_index,
                  result_json
           FROM burr_journal
           WHERE partition_key IS ? AND app_id = ? AND sequence_id = ?
           ORDER BY call_index ASC""",
        (partition_key, app_id, sequence_id),
    )
    entries = []
    for stored_pk, stored_app, stored_seq, step_key, call_index, result_raw in await cursor.fetchall():
        entries.append(
            JournalEntry(
                partition_key=stored_pk,
                app_id=stored_app,
                sequence_id=stored_seq,
                step_key=step_key,
                call_index=call_index,
                result=serde.deserialize(json.loads(result_raw), **self.serde_kwargs),
            )
        )
    return entries
async def create_durable_tables_if_not_exist(self):
    """Create the durable-execution tables (suspensions + journal) if they don't exist."""
    # NOTE(review): partition_key is nullable in the column list but part of
    # burr_journal's PRIMARY KEY; PostgreSQL makes primary-key columns NOT NULL,
    # so a None partition_key presumably cannot be journaled -- confirm.
    conn, acquired = await self._get_connection()
    try:
        async with conn.transaction():
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS burr_suspensions (
                    suspension_id TEXT PRIMARY KEY,
                    partition_key TEXT,
                    app_id TEXT NOT NULL,
                    sequence_id INTEGER NOT NULL,
                    position TEXT NOT NULL,
                    channel TEXT NOT NULL,
                    schema_json JSONB,
                    metadata_json JSONB,
                    inputs_json JSONB,
                    state_json JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    resolved BOOLEAN NOT NULL DEFAULT false
                )"""
            )
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS burr_journal (
                    partition_key TEXT,
                    app_id TEXT NOT NULL,
                    sequence_id INTEGER NOT NULL,
                    step_key TEXT NOT NULL,
                    call_index INTEGER NOT NULL,
                    result_json JSONB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (partition_key, app_id, sequence_id, step_key)
                )"""
            )
    finally:
        await self._release_connection(conn, acquired)


async def initialize(self):
    """Create the state table and the durable-execution tables."""
    await self.create_table(self.table_name)
    await self.create_durable_tables_if_not_exist()
    self._initialized = True


async def save_suspension(self, record: SuspensionRecord) -> None:
    """Insert ``record`` into ``burr_suspensions``, updating the row on a suspension_id conflict.

    :param record: The suspension record to persist.
    """
    from datetime import timezone

    # asyncpg requires datetime objects for TIMESTAMP columns;
    # SuspensionRecord.created_at is typed as str so we parse it when needed.
    created_at = record.created_at
    if isinstance(created_at, str):
        created_at = datetime.fromisoformat(created_at)
    # The column is TIMESTAMP (without time zone): store aware values as
    # naive UTC rather than handing asyncpg an aware datetime.
    if isinstance(created_at, datetime) and created_at.tzinfo is not None:
        created_at = created_at.astimezone(timezone.utc).replace(tzinfo=None)

    conn, acquired = await self._get_connection()
    try:
        await conn.execute(
            """INSERT INTO burr_suspensions
            (suspension_id, partition_key, app_id, sequence_id, position,
             channel, schema_json, metadata_json, inputs_json, state_json,
             created_at, resolved)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
            ON CONFLICT (suspension_id) DO UPDATE SET
                partition_key = EXCLUDED.partition_key,
                app_id = EXCLUDED.app_id,
                sequence_id = EXCLUDED.sequence_id,
                position = EXCLUDED.position,
                channel = EXCLUDED.channel,
                schema_json = EXCLUDED.schema_json,
                metadata_json = EXCLUDED.metadata_json,
                inputs_json = EXCLUDED.inputs_json,
                state_json = EXCLUDED.state_json,
                created_at = EXCLUDED.created_at,
                resolved = EXCLUDED.resolved""",
            record.suspension_id,
            record.partition_key,
            record.app_id,
            record.sequence_id,
            record.position,
            record.channel,
            json.dumps(record.schema_json),
            json.dumps(serde.serialize(record.metadata, **self.serde_kwargs)),
            json.dumps(serde.serialize(record.inputs, **self.serde_kwargs)),
            json.dumps(serde.serialize(record.state, **self.serde_kwargs)),
            created_at,
            record.resolved,
        )
    finally:
        await self._release_connection(conn, acquired)


async def load_suspension(
    self, partition_key: Optional[str], app_id: str, channel: str
) -> Optional[SuspensionRecord]:
    """Load the most recent suspension record for (partition_key, app_id, channel).

    The record is returned whether or not it is resolved; callers check
    ``record.resolved`` for resume-once idempotency.

    :return: The newest matching record, or ``None`` when there is none.
    """
    conn, acquired = await self._get_connection()
    try:
        row = await conn.fetchrow(
            """SELECT suspension_id, partition_key, app_id, sequence_id, position,
                      channel, schema_json, metadata_json, inputs_json, state_json,
                      created_at, resolved
               FROM burr_suspensions
               WHERE partition_key IS NOT DISTINCT FROM $1 AND app_id = $2 AND channel = $3
               ORDER BY created_at DESC LIMIT 1""",
            partition_key,
            app_id,
            channel,
        )
        if row is None:
            return None
        # asyncpg returns JSONB columns as strings, so every JSON column is
        # decoded explicitly here.
        schema = None if row[6] is None else json.loads(row[6])
        metadata_raw = None if row[7] is None else json.loads(row[7])
        inputs_raw = json.loads(row[8])
        state_raw = json.loads(row[9])
        created = row[10]
        return SuspensionRecord(
            suspension_id=row[0],
            partition_key=row[1],
            app_id=row[2],
            sequence_id=row[3],
            position=row[4],
            channel=row[5],
            schema_json=schema,
            metadata=(
                None
                if metadata_raw is None
                else serde.deserialize(metadata_raw, **self.serde_kwargs)
            ),
            inputs=serde.deserialize(inputs_raw, **self.serde_kwargs),
            state=serde.deserialize(state_raw, **self.serde_kwargs),
            # asyncpg returns TIMESTAMP as datetime; SuspensionRecord.created_at is str.
            created_at=created.isoformat() if isinstance(created, datetime) else created,
            resolved=bool(row[11]),
        )
    finally:
        await self._release_connection(conn, acquired)


async def mark_suspension_resolved(self, suspension_id: str) -> bool:
    """Mark a suspension consumed via a conditional UPDATE (resume-once idempotency).

    :return: True if a row was updated (first call), False if already resolved (no-op).
    """
    conn, acquired = await self._get_connection()
    try:
        status = await conn.execute(
            "UPDATE burr_suspensions SET resolved = true "
            "WHERE suspension_id = $1 AND resolved = false",
            suspension_id,
        )
        # asyncpg returns a status string like 'UPDATE 1' or 'UPDATE 0'.
        updated_rows = int(status.split()[-1])
        return updated_rows > 0
    finally:
        await self._release_connection(conn, acquired)


async def save_journal_entry(self, entry: JournalEntry) -> None:
    """Persist one memoized sub-step into ``burr_journal`` (upsert on the primary key)."""
    conn, acquired = await self._get_connection()
    try:
        await conn.execute(
            """INSERT INTO burr_journal
            (partition_key, app_id, sequence_id, step_key, call_index, result_json)
            VALUES ($1, $2, $3, $4, $5, $6)
            ON CONFLICT (partition_key, app_id, sequence_id, step_key) DO UPDATE SET
                call_index = EXCLUDED.call_index,
                result_json = EXCLUDED.result_json""",
            entry.partition_key,
            entry.app_id,
            entry.sequence_id,
            entry.step_key,
            entry.call_index,
            json.dumps(serde.serialize(entry.result, **self.serde_kwargs)),
        )
    finally:
        await self._release_connection(conn, acquired)


async def load_journal(
    self, partition_key: Optional[str], app_id: str, sequence_id: int
) -> list[JournalEntry]:
    """Load journal entries for a suspended action, ordered by ``call_index``."""
    conn, acquired = await self._get_connection()
    try:
        rows = await conn.fetch(
            """SELECT partition_key, app_id, sequence_id, step_key, call_index,
                      result_json
               FROM burr_journal
               WHERE partition_key IS NOT DISTINCT FROM $1 AND app_id = $2 AND sequence_id = $3
               ORDER BY call_index ASC""",
            partition_key,
            app_id,
            sequence_id,
        )
        # asyncpg returns JSONB columns as strings, so result_json is decoded
        # explicitly before deserialization.
        entries = []
        for row in rows:
            entries.append(
                JournalEntry(
                    partition_key=row[0],
                    app_id=row[1],
                    sequence_id=row[2],
                    step_key=row[3],
                    call_index=row[4],
                    result=serde.deserialize(json.loads(row[5]), **self.serde_kwargs),
                )
            )
        return entries
    finally:
        await self._release_connection(conn, acquired)
call_index ASC""", + partition_key, + app_id, + sequence_id, + ) + # asyncpg returns JSONB columns as strings — must json.loads() explicitly + # (unlike psycopg2 which auto-parses JSONB to Python objects). + return [ + JournalEntry( + partition_key=row[0], + app_id=row[1], + sequence_id=row[2], + step_key=row[3], + call_index=row[4], + result=serde.deserialize(json.loads(row[5]), **self.serde_kwargs), + ) + for row in rows + ] + finally: + await self._release_connection(conn, acquired) + async def cleanup(self): """Closes the connection to the database.""" if self.connection is not None: diff --git a/burr/integrations/persisters/b_psycopg2.py b/burr/integrations/persisters/b_psycopg2.py index 26425f805..b676f1cc2 100644 --- a/burr/integrations/persisters/b_psycopg2.py +++ b/burr/integrations/persisters/b_psycopg2.py @@ -26,7 +26,8 @@ import logging from typing import Literal, Optional -from burr.core import persistence, state +from burr.core import persistence, serde, state +from burr.core.durable import JournalEntry, SuspensionRecord logger = logging.getLogger(__name__) @@ -130,9 +131,45 @@ def create_table(self, table_name: str): ) self.connection.commit() + def create_durable_tables_if_not_exist(self): + """Creates the durable-execution tables (suspensions + journal) if they don't exist.""" + cursor = self.connection.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS burr_suspensions ( + suspension_id TEXT PRIMARY KEY, + partition_key TEXT, + app_id TEXT NOT NULL, + sequence_id INTEGER NOT NULL, + position TEXT NOT NULL, + channel TEXT NOT NULL, + schema_json JSONB, + metadata_json JSONB, + inputs_json JSONB, + state_json JSONB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + resolved BOOLEAN NOT NULL DEFAULT false + )""" + ) + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS burr_journal ( + partition_key TEXT, + app_id TEXT NOT NULL, + sequence_id INTEGER NOT NULL, + step_key TEXT NOT NULL, + call_index INTEGER NOT NULL, + result_json 
JSONB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (partition_key, app_id, sequence_id, step_key) + )""" + ) + self.connection.commit() + def initialize(self): """Creates the table""" self.create_table(self.table_name) + self.create_durable_tables_if_not_exist() self._initialized = True def is_initialized(self) -> bool: @@ -259,6 +296,143 @@ def save( ) self.connection.commit() + def save_suspension(self, record: SuspensionRecord) -> None: + """Persist a suspension record into the burr_suspensions table.""" + cursor = self.connection.cursor() + cursor.execute( + """INSERT INTO burr_suspensions + (suspension_id, partition_key, app_id, sequence_id, position, + channel, schema_json, metadata_json, inputs_json, state_json, + created_at, resolved) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (suspension_id) DO UPDATE SET + partition_key = EXCLUDED.partition_key, + app_id = EXCLUDED.app_id, + sequence_id = EXCLUDED.sequence_id, + position = EXCLUDED.position, + channel = EXCLUDED.channel, + schema_json = EXCLUDED.schema_json, + metadata_json = EXCLUDED.metadata_json, + inputs_json = EXCLUDED.inputs_json, + state_json = EXCLUDED.state_json, + created_at = EXCLUDED.created_at, + resolved = EXCLUDED.resolved""", + ( + record.suspension_id, + record.partition_key, + record.app_id, + record.sequence_id, + record.position, + record.channel, + json.dumps(record.schema_json), + json.dumps(serde.serialize(record.metadata, **self.serde_kwargs)), + json.dumps(serde.serialize(record.inputs, **self.serde_kwargs)), + json.dumps(serde.serialize(record.state, **self.serde_kwargs)), + record.created_at, + record.resolved, + ), + ) + self.connection.commit() + + def load_suspension( + self, partition_key: Optional[str], app_id: str, channel: str + ) -> Optional[SuspensionRecord]: + """Load the most recent suspension record for (partition_key, app_id, channel). 
+ + Returns the record whether or not it is resolved; callers check + ``record.resolved`` for resume-once idempotency. Returns ``None`` + when no record exists for this combination. + """ + cursor = self.connection.cursor() + cursor.execute( + """SELECT suspension_id, partition_key, app_id, sequence_id, position, + channel, schema_json, metadata_json, inputs_json, state_json, + created_at, resolved + FROM burr_suspensions + WHERE partition_key IS NOT DISTINCT FROM %s AND app_id = %s AND channel = %s + ORDER BY created_at DESC LIMIT 1""", + (partition_key, app_id, channel), + ) + row = cursor.fetchone() + if row is None: + return None + # psycopg2 auto-parses JSONB columns to Python objects — no json.loads needed. + return SuspensionRecord( + suspension_id=row[0], + partition_key=row[1], + app_id=row[2], + sequence_id=row[3], + position=row[4], + channel=row[5], + schema_json=row[6] if row[6] is not None else None, + metadata=serde.deserialize(row[7], **self.serde_kwargs) if row[7] is not None else None, + inputs=serde.deserialize(row[8], **self.serde_kwargs), + state=serde.deserialize(row[9], **self.serde_kwargs), + created_at=row[10], + resolved=bool(row[11]), + ) + + def mark_suspension_resolved(self, suspension_id: str) -> bool: + """Mark a suspension consumed. Conditional UPDATE for resume-once idempotency. + + :return: True if a row was updated (first call), False if already resolved (no-op). 
+ """ + cursor = self.connection.cursor() + cursor.execute( + "UPDATE burr_suspensions SET resolved = true " + "WHERE suspension_id = %s AND resolved = false", + (suspension_id,), + ) + self.connection.commit() + return cursor.rowcount > 0 + + def save_journal_entry(self, entry: JournalEntry) -> None: + """Upsert one memoized sub-step into the burr_journal table, keyed on (partition_key, app_id, sequence_id, step_key). NOTE(review): partition_key is part of that table's PRIMARY KEY, which PostgreSQL forces NOT NULL -- confirm that partition_key=None is accepted.""" + cursor = self.connection.cursor() + cursor.execute( + """INSERT INTO burr_journal + (partition_key, app_id, sequence_id, step_key, call_index, result_json) + VALUES (%s, %s, %s, %s, %s, %s) + ON CONFLICT (partition_key, app_id, sequence_id, step_key) DO UPDATE SET + call_index = EXCLUDED.call_index, + result_json = EXCLUDED.result_json""", + ( + entry.partition_key, + entry.app_id, + entry.sequence_id, + entry.step_key, + entry.call_index, + json.dumps(serde.serialize(entry.result, **self.serde_kwargs)), + ), + ) + self.connection.commit() + + def load_journal( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> list[JournalEntry]: + """Load journal entries for a suspended action, ordered by call_index.""" + cursor = self.connection.cursor() + cursor.execute( + """SELECT partition_key, app_id, sequence_id, step_key, call_index, + result_json + FROM burr_journal + WHERE partition_key IS NOT DISTINCT FROM %s AND app_id = %s AND sequence_id = %s + ORDER BY call_index ASC""", + (partition_key, app_id, sequence_id), + ) + # psycopg2 auto-parses JSONB columns to Python objects — no json.loads needed. 
+ return [ + JournalEntry( + partition_key=row[0], + app_id=row[1], + sequence_id=row[2], + step_key=row[3], + call_index=row[4], + result=serde.deserialize(row[5], **self.serde_kwargs), + ) + for row in cursor.fetchall() + ] + def cleanup(self): """Closes the connection to the database.""" self.connection.close() diff --git a/burr/integrations/persisters/b_pymongo.py b/burr/integrations/persisters/b_pymongo.py index ba02438e9..9a8dc6163 100644 --- a/burr/integrations/persisters/b_pymongo.py +++ b/burr/integrations/persisters/b_pymongo.py @@ -20,9 +20,10 @@ from datetime import datetime, timezone from typing import Literal, Optional -from pymongo import MongoClient +from pymongo import ASCENDING, DESCENDING, MongoClient -from burr.core import persistence, state +from burr.core import persistence, serde, state +from burr.core.durable import JournalEntry, SuspensionRecord logger = logging.getLogger(__name__) @@ -96,6 +97,139 @@ def __init__( self.collection = self.db[collection_name] self.serde_kwargs = serde_kwargs or {} + def initialize(self): + """Creates indexes for the two durable-execution + collections (``burr_suspensions`` and ``burr_journal``). + + Index creation in MongoDB is idempotent — calling this multiple times + is safe. 
+ """ + self.db["burr_suspensions"].create_index( + [ + ("partition_key", ASCENDING), + ("app_id", ASCENDING), + ("channel", ASCENDING), + ("created_at", DESCENDING), + ] + ) + self.db["burr_journal"].create_index( + [ + ("partition_key", ASCENDING), + ("app_id", ASCENDING), + ("sequence_id", ASCENDING), + ("step_key", ASCENDING), + ], + unique=True, + ) + + def save_suspension(self, record: SuspensionRecord) -> None: + """Persist a suspension record into the ``burr_suspensions`` collection.""" + doc = { + "_id": record.suspension_id, + "suspension_id": record.suspension_id, + "partition_key": record.partition_key, + "app_id": record.app_id, + "sequence_id": record.sequence_id, + "position": record.position, + "channel": record.channel, + "schema_json": record.schema_json, + "metadata": serde.serialize(record.metadata, **self.serde_kwargs) + if record.metadata is not None + else None, + "inputs": serde.serialize(record.inputs, **self.serde_kwargs), + "state": serde.serialize(record.state, **self.serde_kwargs), + "created_at": record.created_at, + "resolved": record.resolved, + } + self.db["burr_suspensions"].update_one( + {"_id": record.suspension_id}, + {"$set": doc}, + upsert=True, + ) + + def load_suspension( + self, partition_key: Optional[str], app_id: str, channel: str + ) -> Optional[SuspensionRecord]: + """Load the most recent suspension record for (partition_key, app_id, channel). + + Returns the record whether or not it is resolved; callers check + ``record.resolved`` for resume-once idempotency. Returns ``None`` + when no record exists for this combination. 
+ """ + doc = self.db["burr_suspensions"].find_one( + {"partition_key": partition_key, "app_id": app_id, "channel": channel}, + sort=[("created_at", DESCENDING)], + ) + if doc is None: + return None + return SuspensionRecord( + suspension_id=doc["suspension_id"], + partition_key=doc["partition_key"], + app_id=doc["app_id"], + sequence_id=doc["sequence_id"], + position=doc["position"], + channel=doc["channel"], + schema_json=doc.get("schema_json"), + metadata=serde.deserialize(doc["metadata"], **self.serde_kwargs) + if doc.get("metadata") is not None + else None, + inputs=serde.deserialize(doc["inputs"], **self.serde_kwargs), + state=serde.deserialize(doc["state"], **self.serde_kwargs), + created_at=doc["created_at"], + resolved=bool(doc["resolved"]), + ) + + def mark_suspension_resolved(self, suspension_id: str) -> bool: + """Mark a suspension consumed. Conditional update for resume-once idempotency. + + :return: ``True`` if a document was updated (first call), ``False`` if + already resolved or not found (no-op). 
+ """ + result = self.db["burr_suspensions"].update_one( + {"_id": suspension_id, "resolved": False}, + {"$set": {"resolved": True}}, + ) + return result.modified_count == 1 + + def save_journal_entry(self, entry: JournalEntry) -> None: + """Persist one memoized sub-step into the ``burr_journal`` collection.""" + filter_doc = { + "partition_key": entry.partition_key, + "app_id": entry.app_id, + "sequence_id": entry.sequence_id, + "step_key": entry.step_key, + } + update_doc = { + "$set": { + "partition_key": entry.partition_key, + "app_id": entry.app_id, + "sequence_id": entry.sequence_id, + "step_key": entry.step_key, + "call_index": entry.call_index, + "result": serde.serialize(entry.result, **self.serde_kwargs), + } + } + self.db["burr_journal"].update_one(filter_doc, update_doc, upsert=True) + + def load_journal( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> list[JournalEntry]: + """Load journal entries for a suspended action, ordered by call_index.""" + cursor = self.db["burr_journal"].find( + {"partition_key": partition_key, "app_id": app_id, "sequence_id": sequence_id} + ).sort("call_index", ASCENDING) + return [ + JournalEntry( + partition_key=doc["partition_key"], + app_id=doc["app_id"], + sequence_id=doc["sequence_id"], + step_key=doc["step_key"], + call_index=doc["call_index"], + result=serde.deserialize(doc["result"], **self.serde_kwargs), + ) + for doc in cursor + ] + def __enter__(self): return self diff --git a/burr/integrations/persisters/b_redis.py b/burr/integrations/persisters/b_redis.py index 5f091dc43..517e8c227 100644 --- a/burr/integrations/persisters/b_redis.py +++ b/burr/integrations/persisters/b_redis.py @@ -29,7 +29,8 @@ from datetime import datetime, timezone from typing import Literal, Optional -from burr.core import persistence, state +from burr.core import persistence, serde, state +from burr.core.durable import JournalEntry, SuspensionRecord logger = logging.getLogger(__name__) @@ -191,6 +192,171 @@ def 
save( namespaced_partition_key = add_namespace_to_partition_key(partition_key, self.namespace) self.connection.zadd(namespaced_partition_key, {app_id: sequence_id}) + # ------------------------------------------------------------------ + # Durable-execution helpers + # ------------------------------------------------------------------ + + def _partition_key_safe(self, partition_key: Optional[str]) -> str: + """Return a Redis-key-safe representation of partition_key.""" + return "__none__" if partition_key is None else partition_key + + def _suspension_hash_key(self, partition_key: Optional[str], app_id: str, channel: str) -> str: + pk = self._partition_key_safe(partition_key) + return f"burr:suspension:{pk}:{app_id}:{channel}" + + def _journal_list_key( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> str: + pk = self._partition_key_safe(partition_key) + return f"burr:journal:{pk}:{app_id}:{sequence_id}" + + # ------------------------------------------------------------------ + # Durable-execution methods + # ------------------------------------------------------------------ + + def save_suspension(self, record: SuspensionRecord) -> None: + """Persist a SuspensionRecord to a Redis HASH. + + Also writes a secondary index key so ``mark_suspension_resolved`` + can locate the hash by ``suspension_id`` alone. + + The hash ``resolved`` field stores a literal string and is updated + by ``mark_suspension_resolved``; callers must use ``load_suspension`` + to get the authoritative ``resolved`` state (backed by the SETNX key). 
+ """ + hash_key = self._suspension_hash_key( + record.partition_key, record.app_id, record.channel + ) + self.connection.hset( + hash_key, + mapping={ + "suspension_id": record.suspension_id, + "partition_key": json.dumps(record.partition_key), + "app_id": record.app_id, + "sequence_id": str(record.sequence_id), + "position": record.position, + "channel": record.channel, + "schema_json": json.dumps(record.schema_json), + "metadata_json": json.dumps( + serde.serialize(record.metadata, **self.serde_kwargs) + ), + "inputs_json": json.dumps( + serde.serialize(record.inputs, **self.serde_kwargs) + ), + "state_json": json.dumps( + serde.serialize(record.state, **self.serde_kwargs) + ), + "created_at": record.created_at, + "resolved": "true" if record.resolved else "false", + }, + ) + # Secondary index: suspension_id -> hash key, for mark_suspension_resolved + self.connection.set(f"burr:suspension_id_idx:{record.suspension_id}", hash_key) + + def load_suspension( + self, partition_key: Optional[str], app_id: str, channel: str + ) -> Optional[SuspensionRecord]: + """Load the suspension record for (partition_key, app_id, channel). + + Returns the record whether or not it is resolved; callers check + ``record.resolved`` for resume-once idempotency. Returns ``None`` + when no record exists. + + The ``resolved`` flag is determined by the existence of the SETNX + key ``burr:resolved:{suspension_id}`` rather than the hash field. 
+ """ + hash_key = self._suspension_hash_key(partition_key, app_id, channel) + data = self.connection.hgetall(hash_key) + if not data: + return None + suspension_id = data[b"suspension_id"].decode() + resolved = bool(self.connection.exists(f"burr:resolved:{suspension_id}")) + return SuspensionRecord( + suspension_id=suspension_id, + partition_key=json.loads(data[b"partition_key"].decode()), + app_id=data[b"app_id"].decode(), + sequence_id=int(data[b"sequence_id"].decode()), + position=data[b"position"].decode(), + channel=data[b"channel"].decode(), + schema_json=json.loads(data[b"schema_json"].decode()), + metadata=serde.deserialize( + json.loads(data[b"metadata_json"].decode()), **self.serde_kwargs + ), + inputs=serde.deserialize( + json.loads(data[b"inputs_json"].decode()), **self.serde_kwargs + ), + state=serde.deserialize( + json.loads(data[b"state_json"].decode()), **self.serde_kwargs + ), + created_at=data[b"created_at"].decode(), + resolved=resolved, + ) + + def mark_suspension_resolved(self, suspension_id: str) -> bool: + """Mark a suspension consumed. Atomic SETNX for resume-once idempotency. + + :return: True if this call performed the first flip, False if already + resolved. An unknown suspension_id is not detected: SETNX still succeeds and True is returned. + """ + if self.connection.setnx(f"burr:resolved:{suspension_id}", 1): + # Mirror the flag into the hash for callers who read the hash directly; + # load_suspension itself derives ``resolved`` from the SETNX marker key. + hash_key_bytes = self.connection.get(f"burr:suspension_id_idx:{suspension_id}") + if hash_key_bytes is not None: + self.connection.hset(hash_key_bytes.decode(), "resolved", "true") + return True + return False + + def save_journal_entry(self, entry: JournalEntry) -> None: + """Persist one memoized sub-step to a Redis LIST. 
+ + Upserts by step_key: scans for an existing entry with the same + step_key and replaces it via LSET if found; otherwise appends with + RPUSH. Journals are short so the linear scan is acceptable. 
+ """ + list_key = self._journal_list_key(entry.partition_key, entry.app_id, entry.sequence_id) + serialized = json.dumps( + { + "partition_key": json.dumps(entry.partition_key), + "app_id": entry.app_id, + "sequence_id": entry.sequence_id, + "step_key": entry.step_key, + "call_index": entry.call_index, + "result_json": json.dumps(serde.serialize(entry.result, **self.serde_kwargs)), + } + ) + existing = self.connection.lrange(list_key, 0, -1) + for idx, raw in enumerate(existing): + item = json.loads(raw.decode()) + if item.get("step_key") == entry.step_key: + self.connection.lset(list_key, idx, serialized) + return + self.connection.rpush(list_key, serialized) + + def load_journal( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> list[JournalEntry]: + """Load journal entries for a suspended action, sorted by call_index.""" + list_key = self._journal_list_key(partition_key, app_id, sequence_id) + raw_entries = self.connection.lrange(list_key, 0, -1) + entries = [] + for raw in raw_entries: + item = json.loads(raw.decode()) + entries.append( + JournalEntry( + partition_key=json.loads(item["partition_key"]), + app_id=item["app_id"], + sequence_id=item["sequence_id"], + step_key=item["step_key"], + call_index=item["call_index"], + result=serde.deserialize( + json.loads(item["result_json"]), **self.serde_kwargs + ), + ) + ) + entries.sort(key=lambda e: e.call_index) + return entries + def cleanup(self): """Closes the connection to the database.""" self.connection.close() @@ -372,6 +538,163 @@ async def save( namespaced_partition_key = add_namespace_to_partition_key(partition_key, self.namespace) await self.connection.zadd(namespaced_partition_key, {app_id: sequence_id}) + # ------------------------------------------------------------------ + # Durable-execution helpers (async) + # ------------------------------------------------------------------ + + def _partition_key_safe(self, partition_key: Optional[str]) -> str: + """Return a 
Redis-key-safe representation of partition_key.""" + return "__none__" if partition_key is None else partition_key + + def _suspension_hash_key(self, partition_key: Optional[str], app_id: str, channel: str) -> str: + pk = self._partition_key_safe(partition_key) + return f"burr:suspension:{pk}:{app_id}:{channel}" + + def _journal_list_key( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> str: + pk = self._partition_key_safe(partition_key) + return f"burr:journal:{pk}:{app_id}:{sequence_id}" + + # ------------------------------------------------------------------ + # Durable-execution methods (async) + # ------------------------------------------------------------------ + + async def save_suspension(self, record: SuspensionRecord) -> None: + """Persist a SuspensionRecord to a Redis HASH (async). + + Also writes a secondary index key so ``mark_suspension_resolved`` + can locate the hash by ``suspension_id`` alone. + """ + hash_key = self._suspension_hash_key( + record.partition_key, record.app_id, record.channel + ) + await self.connection.hset( + hash_key, + mapping={ + "suspension_id": record.suspension_id, + "partition_key": json.dumps(record.partition_key), + "app_id": record.app_id, + "sequence_id": str(record.sequence_id), + "position": record.position, + "channel": record.channel, + "schema_json": json.dumps(record.schema_json), + "metadata_json": json.dumps( + serde.serialize(record.metadata, **self.serde_kwargs) + ), + "inputs_json": json.dumps( + serde.serialize(record.inputs, **self.serde_kwargs) + ), + "state_json": json.dumps( + serde.serialize(record.state, **self.serde_kwargs) + ), + "created_at": record.created_at, + "resolved": "true" if record.resolved else "false", + }, + ) + await self.connection.set( + f"burr:suspension_id_idx:{record.suspension_id}", hash_key + ) + + async def load_suspension( + self, partition_key: Optional[str], app_id: str, channel: str + ) -> Optional[SuspensionRecord]: + """Load the suspension 
record for (partition_key, app_id, channel) (async). + + Returns the record whether or not it is resolved; callers check + ``record.resolved`` for resume-once idempotency. Returns ``None`` + when no record exists. + """ + hash_key = self._suspension_hash_key(partition_key, app_id, channel) + data = await self.connection.hgetall(hash_key) + if not data: + return None + suspension_id = data[b"suspension_id"].decode() + resolved = bool(await self.connection.exists(f"burr:resolved:{suspension_id}")) + return SuspensionRecord( + suspension_id=suspension_id, + partition_key=json.loads(data[b"partition_key"].decode()), + app_id=data[b"app_id"].decode(), + sequence_id=int(data[b"sequence_id"].decode()), + position=data[b"position"].decode(), + channel=data[b"channel"].decode(), + schema_json=json.loads(data[b"schema_json"].decode()), + metadata=serde.deserialize( + json.loads(data[b"metadata_json"].decode()), **self.serde_kwargs + ), + inputs=serde.deserialize( + json.loads(data[b"inputs_json"].decode()), **self.serde_kwargs + ), + state=serde.deserialize( + json.loads(data[b"state_json"].decode()), **self.serde_kwargs + ), + created_at=data[b"created_at"].decode(), + resolved=resolved, + ) + + async def mark_suspension_resolved(self, suspension_id: str) -> bool: + """Mark a suspension consumed. Atomic SETNX for resume-once idempotency (async). + + :return: True if this call performed the first flip, False if already resolved. + """ + if await self.connection.setnx(f"burr:resolved:{suspension_id}", 1): + hash_key_bytes = await self.connection.get( + f"burr:suspension_id_idx:{suspension_id}" + ) + if hash_key_bytes is not None: + await self.connection.hset(hash_key_bytes.decode(), "resolved", "true") + return True + return False + + async def save_journal_entry(self, entry: JournalEntry) -> None: + """Persist one memoized sub-step to a Redis LIST (async). 
+ + Upserts by step_key: scans for an existing entry with the same + step_key and replaces it via LSET if found; otherwise appends. + """ + list_key = self._journal_list_key(entry.partition_key, entry.app_id, entry.sequence_id) + serialized = json.dumps( + { + "partition_key": json.dumps(entry.partition_key), + "app_id": entry.app_id, + "sequence_id": entry.sequence_id, + "step_key": entry.step_key, + "call_index": entry.call_index, + "result_json": json.dumps(serde.serialize(entry.result, **self.serde_kwargs)), + } + ) + existing = await self.connection.lrange(list_key, 0, -1) + for idx, raw in enumerate(existing): + item = json.loads(raw.decode()) + if item.get("step_key") == entry.step_key: + await self.connection.lset(list_key, idx, serialized) + return + await self.connection.rpush(list_key, serialized) + + async def load_journal( + self, partition_key: Optional[str], app_id: str, sequence_id: int + ) -> list[JournalEntry]: + """Load journal entries for a suspended action, sorted by call_index (async).""" + list_key = self._journal_list_key(partition_key, app_id, sequence_id) + raw_entries = await self.connection.lrange(list_key, 0, -1) + entries = [] + for raw in raw_entries: + item = json.loads(raw.decode()) + entries.append( + JournalEntry( + partition_key=json.loads(item["partition_key"]), + app_id=item["app_id"], + sequence_id=item["sequence_id"], + step_key=item["step_key"], + call_index=item["call_index"], + result=serde.deserialize( + json.loads(item["result_json"]), **self.serde_kwargs + ), + ) + ) + entries.sort(key=lambda e: e.call_index) + return entries + async def cleanup(self): """Closes the connection to the database.""" await self.connection.aclose() diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py index 4ae24073a..bd8f29537 100644 --- a/burr/lifecycle/__init__.py +++ b/burr/lifecycle/__init__.py @@ -17,12 +17,16 @@ from burr.lifecycle.base import ( LifecycleAdapter, + PostActionSuspendHook, + PostActionSuspendHookAsync, 
PostApplicationCreateHook, PostApplicationExecuteCallHook, PostApplicationExecuteCallHookAsync, PostEndSpanHook, PostRunStepHook, PostRunStepHookAsync, + PreActionResumeHook, + PreActionResumeHookAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PreRunStepHook, @@ -32,17 +36,21 @@ from burr.lifecycle.default import StateAndResultsFullLogger __all__ = [ - "PreRunStepHook", - "PreRunStepHookAsync", + "LifecycleAdapter", + "PostActionSuspendHook", + "PostActionSuspendHookAsync", + "PostApplicationCreateHook", + "PostApplicationExecuteCallHook", + "PostApplicationExecuteCallHookAsync", + "PostEndSpanHook", "PostRunStepHook", "PostRunStepHookAsync", + "PreActionResumeHook", + "PreActionResumeHookAsync", "PreApplicationExecuteCallHook", "PreApplicationExecuteCallHookAsync", - "PostApplicationExecuteCallHook", - "PostApplicationExecuteCallHookAsync", - "LifecycleAdapter", - "StateAndResultsFullLogger", - "PostApplicationCreateHook", - "PostEndSpanHook", + "PreRunStepHook", + "PreRunStepHookAsync", "PreStartSpanHook", + "StateAndResultsFullLogger", ] diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index 66d8bd7e6..e2a754698 100644 --- a/burr/lifecycle/base.py +++ b/burr/lifecycle/base.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: # type-checking-only for a circular import from burr.core import State, Action, ApplicationGraph + from burr.core.durable import SuspensionRecord from burr.visibility import ActionSpan from burr.lifecycle.internal import lifecycle @@ -142,6 +143,96 @@ async def post_run_step( pass +@lifecycle.base_hook("post_action_suspend") +class PostActionSuspendHook(abc.ABC): + """Hook that runs after an action suspends the run and the suspension is persisted.""" + + @abc.abstractmethod + def post_action_suspend( + self, + *, + app_id: str, + partition_key: Optional[str], + action: "Action", + sequence_id: int, + suspension: "SuspensionRecord", + **future_kwargs: Any, + ): + """Run after a step suspends. 
+ + :param app_id: Application ID + :param partition_key: Partition key of the run (may be None) + :param action: Action that suspended + :param sequence_id: Sequence ID of the suspended step + :param suspension: SuspensionRecord that was persisted + :param future_kwargs: Future keyword arguments + """ + pass + + +@lifecycle.base_hook("post_action_suspend") +class PostActionSuspendHookAsync(abc.ABC): + """Async hook that runs after an action suspends the run and the suspension is persisted.""" + + @abc.abstractmethod + async def post_action_suspend( + self, + *, + app_id: str, + partition_key: Optional[str], + action: "Action", + sequence_id: int, + suspension: "SuspensionRecord", + **future_kwargs: Any, + ): + pass + + +@lifecycle.base_hook("pre_action_resume") +class PreActionResumeHook(abc.ABC): + """Hook that runs just before a suspended action is re-executed on resume.""" + + @abc.abstractmethod + def pre_action_resume( + self, + *, + app_id: str, + partition_key: Optional[str], + action: "Action", + sequence_id: int, + channel: str, + **future_kwargs: Any, + ): + """Run just before a suspended action re-executes. 
+ + :param app_id: Application ID + :param partition_key: Partition key (may be None) + :param action: Action being re-entered + :param sequence_id: Sequence ID of the suspended step + :param channel: Suspension channel being resumed + :param future_kwargs: Future keyword arguments + """ + pass + + +@lifecycle.base_hook("pre_action_resume") +class PreActionResumeHookAsync(abc.ABC): + """Async hook that runs just before a suspended action is re-executed on resume.""" + + @abc.abstractmethod + async def pre_action_resume( + self, + *, + app_id: str, + partition_key: Optional[str], + action: "Action", + sequence_id: int, + channel: str, + **future_kwargs: Any, + ): + pass + + @lifecycle.base_hook("post_application_create") class PostApplicationCreateHook(abc.ABC): """Synchronous hook that runs post instantiation of an ``Application`` @@ -500,6 +591,10 @@ async def post_end_stream( PreRunStepHookAsync, PostRunStepHook, PostRunStepHookAsync, + PostActionSuspendHook, + PostActionSuspendHookAsync, + PreActionResumeHook, + PreActionResumeHookAsync, PreApplicationExecuteCallHook, PreApplicationExecuteCallHookAsync, PostApplicationExecuteCallHook, diff --git a/burr/tracking/client.py b/burr/tracking/client.py index 44919aed5..f70f0bb27 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -49,19 +49,23 @@ def flock(*args, **kwargs): import re import traceback from abc import ABC -from typing import Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple try: from typing import Self except ImportError: Self = "Self" +if TYPE_CHECKING: + from burr.core.durable import SuspensionRecord + from burr import system from burr.common import types as burr_types from burr.core import Action, ApplicationGraph, State, serde from burr.core.persistence import BaseStateLoader, PersistedStateData from burr.integrations.base import require_plugin from burr.lifecycle import ( + PostActionSuspendHook, PostApplicationCreateHook, PostEndSpanHook, 
PostRunStepHook, @@ -81,6 +85,7 @@ def flock(*args, **kwargs): FirstItemStreamModel, InitializeStreamModel, PointerModel, + SuspendEntryModel, ) from burr.visibility import ActionSpan @@ -131,6 +136,7 @@ class SyncTrackingClient( PostApplicationCreateHook, PreRunStepHook, PostRunStepHook, + PostActionSuspendHook, PreStartSpanHook, PostEndSpanHook, DoLogAttributeHook, @@ -478,6 +484,26 @@ def post_run_step( ) self._append_write_line(post_run_entry) + def post_action_suspend( + self, + *, + app_id: str, + partition_key: Optional[str], + action: Action, + sequence_id: int, + suspension: "SuspensionRecord", + **future_kwargs: Any, + ): + suspend_entry = SuspendEntryModel( + suspend_time=datetime.datetime.now(), + action=action.name, + sequence_id=sequence_id, + channel=suspension.channel, + metadata=suspension.metadata if suspension.metadata is not None else {}, + suspension_id=suspension.suspension_id, + ) + self._append_write_line(suspend_entry) + def pre_start_span( self, *, diff --git a/burr/tracking/common/models.py b/burr/tracking/common/models.py index 5980bf9df..9cc0c0193 100644 --- a/burr/tracking/common/models.py +++ b/burr/tracking/common/models.py @@ -180,6 +180,18 @@ class EndEntryModel(IdentifyingModel): type: str = "end_entry" +class SuspendEntryModel(IdentifyingModel): + """Pydantic model that represents a step that suspended the run.""" + + suspend_time: datetime.datetime + action: str + sequence_id: int + channel: str + metadata: Dict[str, Any] + suspension_id: str + type: str = "suspend_entry" + + class BeginSpanModel(IdentifyingModel): """Pydantic model that represents an entry for the beginning of a span""" diff --git a/docs/concepts/durable-execution.rst b/docs/concepts/durable-execution.rst new file mode 100644 index 000000000..bca371718 --- /dev/null +++ b/docs/concepts/durable-execution.rst @@ -0,0 +1,321 @@ +.. + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +================= +Durable Execution +================= + +.. _durable-execution: + +.. note:: + + Durable execution lets an action pause mid-run (``__context.suspend()``), + survive process restarts, and resume exactly where it stopped when an + external event arrives. Sub-steps memoized with ``__context.durable()`` + are replayed from a journal on re-run so they never execute twice. + +What is durable execution? +-------------------------- + +Some workflows cannot finish in a single continuous execution: + +* **Human-in-the-loop.** An action drafts content and waits for a human to + approve it before the workflow proceeds. The process may be restarted many + times while waiting. +* **External-event wait.** An action triggers a webhook and must wait for the + callback, which could arrive seconds or days later. +* **Crash resilience.** Long-running inference or IO inside an action can be + checkpointed so that a restart does not redo expensive work already completed + before the crash. + +Burr addresses all three scenarios with two primitives on +:py:class:`ApplicationContext <burr.core.application.ApplicationContext>`: +``suspend()`` for pausing the run and ``durable()`` / ``adurable()`` for +memoizing sub-steps so they survive a suspend/resume boundary. 
+ +Suspending a run: ``suspend()`` +-------------------------------- + +Call ``__context.suspend(channel)`` inside an action to pause the entire run +and wait for an external payload on the named *channel*. The call raises an +internal control-flow signal that the run loop catches; the run stops and a +:py:class:`SuspensionRecord <burr.core.durable.SuspensionRecord>` is persisted. + +On resume the action **re-runs from the top**. When execution reaches the same +``suspend(channel)`` call again, the already-delivered payload is returned +instead of raising the signal. + +.. code-block:: python + + from burr.core import action, State + + @action(reads=["draft"], writes=["review_decision"]) + def review(state: State, __context) -> State: + # Optional: memoize an expensive step before suspending (see below). + summary = __context.durable( + "summarize", + lambda d: f"SUMMARY: {d[:50]}...", + state["draft"], + ) + # Suspend until a human posts a payload to "human_approval". + payload = __context.suspend( + "human_approval", + metadata={"summary": summary}, + ) + return state.update(review_decision=payload) + +Signature:: + + __context.suspend( + channel: str, + *, + schema: Optional[type] = None, + metadata: Optional[dict] = None, + ) -> Any + +* ``channel`` -- a stable name for this suspension point. +* ``schema`` -- optional Pydantic model or dataclass; when supplied, a dict + payload is coerced via ``schema(**payload)`` before being returned. +* ``metadata`` -- free-form dict stored with the suspension record; useful for + surfacing context to the UI or a webhook handler. +* **Return value** -- the ``payload`` delivered by :py:func:`resume() + <burr.core.resume.resume>` when the run is resumed. + +.. warning:: + + ``_Suspended`` inherits from ``BaseException``, not ``Exception``. 
Do not + wrap ``__context.suspend()`` calls inside ``asyncio.shield()``, + ``try/except BaseException``, or any other guard that catches + ``BaseException`` -- doing so will swallow the signal and prevent the run + from suspending correctly. 
+ + +Memoizing sub-steps: ``durable()`` and ``adurable()`` +------------------------------------------------------ + +Because the action re-runs from the top on resume, any side-effectful or +expensive work executed *before* the ``suspend()`` call will execute again. +Use ``__context.durable(key, fn, *args, **kwargs)`` to memoize a sub-step so +it runs exactly once regardless of how many times the action is re-executed. + +On the first run ``fn`` is called and its result is written to an append-only +*journal*. On re-run the entry recorded at the same call position is replayed +(its ``key`` must match) and the cached result is returned without calling +``fn`` again. + +.. code-block:: python + + @action(reads=["content"], writes=["result"]) + def process(state: State, __context) -> State: + # Expensive LLM call -- runs once, replayed on resume. + summary = __context.durable( + "llm_summarize", + call_llm, # fn + state["content"], # *args + ) + + # Async variant -- use inside async actions. + # embedding = await __context.adurable("embed", fetch_embedding, summary) + + payload = __context.suspend("approval", metadata={"summary": summary}) + return state.update(result=payload) + +Signatures:: + + __context.durable(key: str, fn: Callable, *args, **kwargs) -> Any + await __context.adurable(key: str, fn: Callable, *args, **kwargs) -> Any + +* ``key`` -- a stable, unique identifier for this sub-step (see determinism + contract below). +* ``fn`` -- a callable (or coroutine function for ``adurable``) whose result + should be memoized. +* ``*args, **kwargs`` -- forwarded to ``fn`` on first execution only. + +.. note:: + + Do not call ``suspend()`` from inside a ``durable()`` fn. The fn must be + a self-contained step that runs to completion and returns a value. + + +The determinism contract +------------------------- + +The run loop identifies journal entries by position (call index) and key.
For +replay to work correctly, every re-run of the same action invocation must call +``durable()`` / ``adurable()`` in **exactly the same order** with **exactly +the same keys**. Violations raise :py:exc:`DeterminismError +<burr.core.durable.DeterminismError>` immediately (fail-loud). + +Rules: + +1. **Stable key per call site.** Use a string literal, not a runtime value + that may change (e.g., a timestamp or UUID). + +2. **Stable call order.** The set and order of ``durable()`` calls must be + identical on every re-run of the same invocation. + +3. **No non-deterministic branching.** Do not gate a ``durable()`` call on + a condition that may differ between the first run and the re-run: + + .. code-block:: python + + # BAD -- the branch may not be taken on resume. + if random.random() > 0.5: + __context.durable("step", fn) + + # GOOD -- the call is unconditional. + result = __context.durable("step", fn) + +4. **No ``suspend()`` inside ``durable()`` fn.** The fn must return a plain + value; calling ``suspend()`` inside it raises ``_Suspended`` before the + result is recorded and corrupts the journal. + +5. **Mismatch raises ``DeterminismError``.** If ``key`` or call order differs + between runs, a :py:exc:`DeterminismError <burr.core.durable.DeterminismError>` + is raised, converting a silent footgun into a loud failure. + + +Resuming a suspended run: ``resume()`` +--------------------------------------- + +When the external event arrives (webhook, form POST, timer, etc.), call +:py:func:`resume() <burr.core.resume.resume>` (sync) or +:py:func:`aresume() <burr.core.resume.aresume>` (async). Both helpers +reload the suspension from the persister, rebuild the Application, set the +resume payload, and run the graph to the next halt, suspend, or completion. + +..
code-block:: python + + from burr.core import resume + + # Synchronous resume (e.g., inside a Flask route handler): + final_state = resume( + persister=persister, + graph=graph, + app_id="my-app-run-001", + partition_key=None, + channel="human_approval", + payload={"approved": True}, + ) + + # Asynchronous resume (e.g., inside a FastAPI route handler): + from burr.core.resume import aresume + + final_state = await aresume( + persister=persister, + graph=graph, + app_id="my-app-run-001", + partition_key=None, + channel="human_approval", + payload={"approved": True}, + ) + +Both functions return the final :py:class:`State <burr.core.state.State>` after +the resumed run completes or reaches the next suspension. + +**Idempotency.** For persisters with durable storage (see below), resuming an +already-resolved suspension is an idempotent no-op: the call returns the +latest persisted state unchanged. The ``resolved`` flag on the +:py:class:`SuspensionRecord <burr.core.durable.SuspensionRecord>` prevents +double-execution. For custom persisters without durable storage, a second +``resume()`` call after the first completes raises ``ValueError``. + + +Persister support +----------------- + +First-party persisters ship with dedicated storage tables or collections for +suspension records and journal entries, providing strong resume-once semantics: + +..
list-table:: + :header-rows: 1 + :widths: 25 20 55 + + * - Backend + - Driver + - Class + * - SQLite (sync) + - sqlite3 + - :ref:`SQLitePersister ` + * - SQLite (async) + - aiosqlite + - :ref:`AsyncSQLitePersister ` + * - PostgreSQL (sync) + - psycopg2 + - :ref:`PostgreSQLPersister ` + * - PostgreSQL (async) + - asyncpg + - :ref:`AsyncPostgreSQLPersister ` + * - Redis (sync) + - redis + - :ref:`RedisBasePersister ` + * - Redis (async) + - redis.asyncio + - :ref:`AsyncRedisBasePersister ` + * - MongoDB + - pymongo + - :ref:`MongoDBBasePersister ` + +**Custom persisters** work transparently through an in-state fallback: the +:py:class:`SuspensionRecord <burr.core.durable.SuspensionRecord>` and journal +entries are embedded inside the reserved ``__burr_durable__`` key in +:py:class:`State <burr.core.state.State>`, which the existing persister hook +saves automatically. This is correct and requires no code changes, but it does +not provide the idempotency guarantees of the dedicated durable-storage +methods. + +To opt in to durable storage for a custom persister, override all five methods +on :py:class:`BaseStatePersister <burr.core.persistence.BaseStatePersister>`: +``save_suspension``, ``load_suspension``, ``save_journal_entry``, +``load_journal``, and ``mark_suspension_resolved``. + + +``_Suspended`` and ``BaseException`` +-------------------------------------- + +The internal control-flow signal ``_Suspended`` inherits from ``BaseException`` +so that a user ``try/except Exception`` block inside an action does not +accidentally catch it. The run loop catches it explicitly. It is never logged +as a failure. + +This means you must not wrap ``__context.suspend()`` calls in constructs that +catch ``BaseException``: + +.. code-block:: python + + # BAD -- asyncio.shield catches BaseException and re-raises CancelledError; + # _Suspended will be swallowed or mishandled. + result = await asyncio.shield(__context.suspend("ch")) + + # GOOD -- call suspend directly.
+ result = __context.suspend("ch") + + +Example +------- + +A complete human-in-the-loop draft-review-finalize workflow is available in +the ``examples/durable-execution/`` directory of the repository. It +demonstrates: + +* Using ``durable()`` to memoize an LLM summary before suspending. +* Calling ``suspend("human_approval")`` to pause the workflow. +* Using ``resume()`` to deliver the human's decision and finish the run. + +See :ref:`available persisters here ` for the full list of +backends that support the durable-storage APIs. diff --git a/docs/concepts/index.rst b/docs/concepts/index.rst index 3a58bd012..301698d79 100644 --- a/docs/concepts/index.rst +++ b/docs/concepts/index.rst @@ -36,6 +36,7 @@ Overview of the concepts -- read these to get a mental model for how Burr works. transitions tracking state-persistence + durable-execution serde streaming-actions state-typing diff --git a/examples/durable-execution/.gitignore b/examples/durable-execution/.gitignore new file mode 100644 index 000000000..141da33f0 --- /dev/null +++ b/examples/durable-execution/.gitignore @@ -0,0 +1,3 @@ +*.db +__pycache__/ +*.py[cod] diff --git a/examples/durable-execution/README.md b/examples/durable-execution/README.md new file mode 100644 index 000000000..2e6564d2e --- /dev/null +++ b/examples/durable-execution/README.md @@ -0,0 +1,51 @@ +# Durable Execution: Human-in-the-Loop + +This example demonstrates Burr's suspend/resume primitives through a three-step +draft-review-finalize workflow. The `review` action suspends the workflow and +waits for a human to approve or reject a draft. While suspended, the process can +die and restart without losing progress — the `durable()` call memoizes expensive +sub-steps (like an LLM summary) in a journal, so they are not re-executed on +resume. 
+ +The same pattern covers three production use-cases: human-in-the-loop approval +gates, waiting for an external event (webhook, queue message, IoT sensor), and +crash resilience where a long-running action is interrupted mid-flight. + +## How to run + +```bash +pip install burr +python application.py +``` + +The script runs the workflow to the `review` suspension, prints the suspended +channel and metadata, then immediately simulates the human responding with +`{"approved": True}` via `resume()` and prints the final state. + +## The `human_approval` channel + +In production you would expose the `resume()` call through a webhook or UI +button. When the workflow suspends, store the `app_id` and `partition_key` +alongside the suspension metadata (available as `app.suspended.metadata`). Your +webhook handler then calls: + +```python +from burr.core import resume +from burr.core.persistence import SQLitePersister + +persister = SQLitePersister.from_values("durable.db") +persister.initialize() +final_state = resume( + persister=persister, + graph=graph, # same Graph object (or rebuild it) + app_id=app_id, + partition_key=partition_key, + channel="human_approval", + payload={"approved": True}, +) +``` + +## Further reading + +- [Durable Execution concepts](../../docs/concepts/durable-execution.rst) +- [Burr documentation](https://burr.dagworks.io) diff --git a/examples/durable-execution/__init__.py b/examples/durable-execution/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/durable-execution/application.py b/examples/durable-execution/application.py new file mode 100644 index 000000000..ce6360df5 --- /dev/null +++ b/examples/durable-execution/application.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Human-in-the-loop durable execution example. + +Demonstrates a draft -> review -> finalize workflow where the "review" step: +1. Uses ``durable()`` to memoize an expensive summary so it is computed once + even if the process is restarted before the human responds. +2. Calls ``suspend("human_approval")`` to pause the workflow and wait for an + external signal (webhook, UI button, etc.). +3. Continues to finalization once the human payload arrives via ``resume()``. +""" + +import pathlib +from typing import Optional, Tuple + +from burr.core import ApplicationBuilder, GraphBuilder, State, action, resume +from burr.core.application import Application +from burr.core.graph import Graph +from burr.core.persistence import SQLitePersister + +# --------------------------------------------------------------------------- +# Actions +# --------------------------------------------------------------------------- + + +@action(reads=[], writes=["draft"]) +def draft(state: State) -> State: + """Produce the initial draft content.""" + return state.update(draft="This is the initial draft content for review.") + + +@action(reads=["draft"], writes=["review_decision"]) +def review(state: State, __context) -> State: + """Memoize a summary, then suspend waiting for human approval. 
+ + The ``durable()`` call ensures the summarizer runs exactly once across + the suspend/resume boundary — the result is replayed from the journal on + resume instead of being recomputed. + """ + summary = __context.durable( + "summarize", + lambda d: f"SUMMARY: {d[:20]}...", + state["draft"], + ) + # Suspend until a human delivers a payload over the "human_approval" channel. + # The payload is expected to be a dict with key "approved" (bool). + payload = __context.suspend( + "human_approval", + metadata={"summary": summary}, + ) + return state.update(review_decision=payload) + + +@action(reads=["review_decision"], writes=["approved"]) +def finalize(state: State) -> State: + """Record the human's decision.""" + decision = state["review_decision"] + return state.update(approved=decision.get("approved", False)) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def build_application( + app_id: str, + partition_key: Optional[str] = None, + db_path: Optional[str] = None, +) -> Tuple[Application, Graph, SQLitePersister]: + """Build and return the application, its graph, and the persister. + + :param app_id: Unique identifier for this run. + :param partition_key: Optional partition key (e.g. tenant / user id). + :param db_path: Path for the SQLite database. Defaults to a file next to + this script so re-runs pick up where they left off. + :return: Tuple of (application, graph, persister). 
+ """ + if db_path is None: + db_path = str(pathlib.Path(__file__).parent / "durable.db") + + persister = SQLitePersister.from_values(db_path) + persister.initialize() + + graph = ( + GraphBuilder() + .with_actions(draft=draft, review=review, finalize=finalize) + .with_transitions(("draft", "review"), ("review", "finalize")) + .build() + ) + + app = ( + ApplicationBuilder() + .with_graph(graph) + .with_entrypoint("draft") + .with_state(State({})) + .with_identifiers(app_id=app_id, partition_key=partition_key) + .with_state_persister(persister) + .build() + ) + + return app, graph, persister + + +# --------------------------------------------------------------------------- +# Main: run the full suspend/resume cycle for demonstration +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import os + + _DB_PATH = str(pathlib.Path(__file__).parent / "durable.db") + # Remove stale DB so the demo always starts fresh. + if os.path.exists(_DB_PATH): + os.remove(_DB_PATH) + + # --- First half: run until the workflow suspends at "review" --- + app, graph, persister = build_application(app_id="demo-1", db_path=_DB_PATH) + + # Generate the state-machine diagram (graphviz binary required). 
+ try: + app.visualize( + output_file_path=str(pathlib.Path(__file__).parent / "statemachine"), + include_conditions=False, + view=False, + format="png", + ) + print("State machine saved to statemachine.png") + except Exception as exc: + print(f"visualize skipped: {exc}") + + app.run(halt_after=["review"]) + print("Suspended on channel:", app.suspended.channel) + print("Suspension metadata:", app.suspended.metadata) + + # --- Second half: simulate the human approving the draft --- + final_state = resume( + persister=persister, + graph=graph, + app_id="demo-1", + partition_key=None, + channel="human_approval", + payload={"approved": True}, + ) + print("Final approved:", final_state["approved"]) diff --git a/examples/durable-execution/notebook.ipynb b/examples/durable-execution/notebook.ipynb new file mode 100644 index 000000000..0b29710ee --- /dev/null +++ b/examples/durable-execution/notebook.ipynb @@ -0,0 +1,80 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a1b2c3d4e5f60001", + "metadata": {}, + "outputs": [], + "source": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements. See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership. The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License. You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License." 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1b2c3d4e5f60002", + "metadata": {}, + "outputs": [], + "source": "!pip install burr" + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1b2c3d4e5f60003", + "metadata": {}, + "outputs": [], + "source": [ + "# Build the durable-execution application and run until it suspends.\n", + "import sys, pathlib\n", + "sys.path.insert(0, str(pathlib.Path(\".\").resolve()))\n", + "\n", + "from application import build_application\n", + "from burr.core import resume\n", + "\n", + "app, graph, persister = build_application(app_id=\"notebook-demo-1\", db_path=\":memory:\")\n", + "app.run(halt_after=[\"review\"])\n", + "print(\"Suspended channel:\", app.suspended.channel)\n", + "print(\"Metadata:\", app.suspended.metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1b2c3d4e5f60004", + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate the human approving and resume the workflow.\n", + "final_state = resume(\n", + " persister=persister,\n", + " graph=graph,\n", + " app_id=\"notebook-demo-1\",\n", + " partition_key=None,\n", + " channel=\"human_approval\",\n", + " payload={\"approved\": True},\n", + ")\n", + "print(\"Approved:\", final_state[\"approved\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/durable-execution/requirements.txt b/examples/durable-execution/requirements.txt new file mode 100644 index 000000000..a78cac9dd --- /dev/null +++ b/examples/durable-execution/requirements.txt @@ -0,0 +1 @@ +burr diff --git 
a/examples/durable-execution/statemachine.png b/examples/durable-execution/statemachine.png new file mode 100644 index 000000000..f03f9e03e Binary files /dev/null and b/examples/durable-execution/statemachine.png differ diff --git a/tests/core/test_durable.py b/tests/core/test_durable.py new file mode 100644 index 000000000..76881e6a2 --- /dev/null +++ b/tests/core/test_durable.py @@ -0,0 +1,1007 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import dataclasses +import datetime +from collections import defaultdict +from typing import Literal, Optional + +import pytest + +from burr.core.durable import ( + DeterminismError, + JournalEntry, + SuspensionRecord, + _Suspended, +) +from burr.core.persistence import BaseStatePersister +from burr.core.state import State + + +class NonDurablePersister(BaseStatePersister): + """Dict-backed persister that does NOT override any durable-storage methods. + + ``supports_durable_storage(NonDurablePersister())`` returns False because + ``save_suspension`` is inherited unchanged from ``BaseStatePersister``. + The Application therefore stores suspensions and journal entries inside + the State blob (in-state fallback path). 
+ """ + + def __init__(self): + self._storage = defaultdict(lambda: defaultdict(list)) + + def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: "State", + status: Literal["completed", "failed", "suspended"], + **kwargs, + ): + record = { + "partition_key": partition_key or "", + "app_id": app_id, + "sequence_id": sequence_id, + "position": position, + "state": state, + "created_at": datetime.datetime.now().isoformat(), + "status": status, + } + self._storage[partition_key][app_id].append(record) + + def load( + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, + ): + if app_id is None: + return None + states = self._storage[partition_key][app_id] + if not states: + return None + if sequence_id is None: + return states[-1] + matching = [s for s in states if s["sequence_id"] == sequence_id] + return matching[-1] if matching else None + + def list_app_ids(self, partition_key: str, **kwargs): + return list(self._storage[partition_key].keys()) + + +def test_suspended_is_base_exception_not_exception(): + assert issubclass(_Suspended, BaseException) + assert not issubclass(_Suspended, Exception) + + +def test_suspended_carries_channel_schema_metadata(): + sig = _Suspended(channel="approval", schema_json={"type": "object"}, metadata={"k": "v"}) + assert sig.channel == "approval" + assert sig.schema_json == {"type": "object"} + assert sig.metadata == {"k": "v"} + + +def test_determinism_error_is_exception(): + assert issubclass(DeterminismError, Exception) + + +def test_suspension_record_fields(): + record = SuspensionRecord( + suspension_id="s1", + partition_key="p", + app_id="a", + sequence_id=3, + position="review", + channel="approval", + schema_json=None, + metadata={"summary": "hi"}, + inputs={"x": 1}, + state={"draft": "text"}, + created_at="2026-05-22T00:00:00", + resolved=False, + ) + assert dataclasses.is_dataclass(record) + assert record.resolved is False 
+ assert record.state == {"draft": "text"} + + +def test_journal_entry_fields(): + entry = JournalEntry( + partition_key="p", + app_id="a", + sequence_id=3, + step_key="summarize", + call_index=0, + result="cached value", + ) + assert dataclasses.is_dataclass(entry) + assert entry.call_index == 0 + assert entry.result == "cached value" + + +def test_save_status_literal_includes_suspended(): + import typing + + from burr.core.persistence import BaseStateSaver + + hints = typing.get_type_hints(BaseStateSaver.save) + status_arg = hints["status"] + assert "suspended" in typing.get_args(status_arg) + + +def test_durable_symbols_exported_from_burr_core(): + import burr.core as core + + assert hasattr(core, "DeterminismError") + assert hasattr(core, "SuspensionRecord") + + +def test_base_persister_durable_methods_raise_not_implemented(): + from burr.core.durable import JournalEntry, SuspensionRecord + from burr.core.persistence import BaseStatePersister + + # Use DevNullPersister which satisfies the abstract methods but does not + # override the durable methods, so all five should raise NotImplementedError. 
+ from burr.core.persistence import DevNullPersister + + p = DevNullPersister() + + dummy_record = SuspensionRecord( + suspension_id="s1", + partition_key="p", + app_id="a", + sequence_id=1, + position="action", + channel="ch", + schema_json=None, + metadata=None, + inputs={}, + state={}, + created_at="2026-05-22T00:00:00", + resolved=False, + ) + dummy_entry = JournalEntry( + partition_key="p", + app_id="a", + sequence_id=1, + step_key="k", + call_index=0, + result=None, + ) + + with pytest.raises(NotImplementedError): + p.save_suspension(dummy_record) + + with pytest.raises(NotImplementedError): + p.load_suspension("p", "a", "ch") + + with pytest.raises(NotImplementedError): + p.save_journal_entry(dummy_entry) + + with pytest.raises(NotImplementedError): + p.load_journal("p", "a", 1) + + with pytest.raises(NotImplementedError): + p.mark_suspension_resolved("s1") + + +def test_supports_durable_storage_false_for_non_durable_persister(): + from burr.core.durable import supports_durable_storage + + persister = NonDurablePersister() + assert supports_durable_storage(persister) is False + + +def test_supports_durable_storage_true_for_in_memory(): + from burr.core.durable import supports_durable_storage + from burr.core.persistence import InMemoryPersister + + assert supports_durable_storage(InMemoryPersister()) is True + + +def test_in_memory_persister_suspension_round_trip(): + from burr.core.durable import SuspensionRecord + from burr.core.persistence import InMemoryPersister + + persister = InMemoryPersister() + record = SuspensionRecord( + suspension_id="s1", partition_key="p", app_id="a", sequence_id=2, + position="review", channel="approval", schema_json=None, + metadata=None, inputs={}, state={"draft": "d"}, + created_at="2026-05-22T00:00:00", resolved=False, + ) + persister.save_suspension(record) + loaded = persister.load_suspension("p", "a", "approval") + assert loaded.suspension_id == "s1" + assert loaded.state == {"draft": "d"} + assert loaded.resolved is 
False + + +def test_in_memory_persister_journal_round_trip(): + from burr.core.durable import JournalEntry + from burr.core.persistence import InMemoryPersister + + persister = InMemoryPersister() + entry = JournalEntry( + partition_key="p", app_id="a", sequence_id=2, + step_key="summarize", call_index=0, result="cached", + ) + persister.save_journal_entry(entry) + journal = persister.load_journal("p", "a", 2) + assert len(journal) == 1 + assert journal[0].result == "cached" + + +# --- In-state fallback codec tests ------------------------------------------- + + +def test_suspension_codec_round_trip(): + from burr.core.durable import ( + SuspensionRecord, + read_suspension_from_state, + write_suspension_into_state, + ) + from burr.core.state import State + + record = SuspensionRecord( + suspension_id="s42", + partition_key="p", + app_id="a", + sequence_id=5, + position="review", + channel="approval", + schema_json=None, + metadata={"note": "hi"}, + inputs={"x": 1}, + state={"draft": "text"}, + created_at="2026-05-22T00:00:00", + resolved=False, + ) + state = State() + new_state = write_suspension_into_state(state, record) + result = read_suspension_from_state(new_state, "approval") + + assert result is not None + assert result.suspension_id == record.suspension_id + assert result.channel == record.channel + assert result.state == record.state + assert result.resolved == record.resolved + + +def test_read_suspension_from_state_channel_mismatch(): + from burr.core.durable import ( + SuspensionRecord, + read_suspension_from_state, + write_suspension_into_state, + ) + from burr.core.state import State + + record = SuspensionRecord( + suspension_id="s1", + partition_key="p", + app_id="a", + sequence_id=1, + position="act", + channel="approval", + schema_json=None, + metadata=None, + inputs={}, + state={}, + created_at="2026-05-22T00:00:00", + resolved=False, + ) + state = write_suspension_into_state(State(), record) + + assert read_suspension_from_state(state, 
"other_channel") is None + assert read_suspension_from_state(State(), "approval") is None + + +def test_journal_codec_round_trip(): + from burr.core.durable import ( + JournalEntry, + read_journal_from_state, + write_journal_into_state, + ) + from burr.core.state import State + + entries = [ + JournalEntry( + partition_key="p", app_id="a", sequence_id=3, + step_key="step_a", call_index=0, result="first", + ), + JournalEntry( + partition_key="p", app_id="a", sequence_id=3, + step_key="step_b", call_index=1, result="second", + ), + ] + state = write_journal_into_state(State(), entries) + loaded = read_journal_from_state(state) + + assert len(loaded) == 2 + call_indices = {e.call_index for e in loaded} + assert call_indices == {0, 1} + results = {e.call_index: e.result for e in loaded} + assert results[0] == "first" + assert results[1] == "second" + + +def test_journal_codec_preserves_json_friendly_result(): + from burr.core.durable import ( + JournalEntry, + read_journal_from_state, + write_journal_into_state, + ) + from burr.core.state import State + + original_result = {"k": [1, 2]} + entry = JournalEntry( + partition_key="p", app_id="a", sequence_id=7, + step_key="fetch", call_index=0, result=original_result, + ) + state = write_journal_into_state(State(), [entry]) + loaded = read_journal_from_state(state) + + assert len(loaded) == 1 + assert loaded[0].result == original_result + + +# --- InMemoryPersister: mark_suspension_resolved tests ----------------------- + + +def test_in_memory_persister_mark_suspension_resolved_flips_flag(): + from burr.core.durable import SuspensionRecord + from burr.core.persistence import InMemoryPersister + + persister = InMemoryPersister() + record = SuspensionRecord( + suspension_id="s99", + partition_key="p", + app_id="a", + sequence_id=1, + position="review", + channel="approval", + schema_json=None, + metadata=None, + inputs={}, + state={}, + created_at="2026-05-22T00:00:00", + resolved=False, + ) + 
persister.save_suspension(record) + persister.mark_suspension_resolved("s99") + loaded = persister.load_suspension("p", "a", "approval") + assert loaded is not None + assert loaded.resolved is True + + +def test_in_memory_persister_mark_suspension_resolved_unknown_id_is_noop(): + from burr.core.persistence import InMemoryPersister + + persister = InMemoryPersister() + # Must not raise for an id that was never stored. + persister.mark_suspension_resolved("does-not-exist") + + +def test_in_memory_persister_mark_suspension_resolved_is_conditional(): + """mark_suspension_resolved must match SQLite semantics: True on first call, + False on second call (already resolved), and False for an unknown id.""" + from burr.core.durable import SuspensionRecord + from burr.core.persistence import InMemoryPersister + + persister = InMemoryPersister() + record = SuspensionRecord( + suspension_id="s-cond", + partition_key="p", + app_id="a", + sequence_id=1, + position="review", + channel="approval", + schema_json=None, + metadata=None, + inputs={}, + state={}, + created_at="2026-05-22T00:00:00", + resolved=False, + ) + persister.save_suspension(record) + + # First call: row was unresolved, should flip and return True. + first = persister.mark_suspension_resolved("s-cond") + assert first is True + + # Second call: already resolved, should be a no-op and return False. + second = persister.mark_suspension_resolved("s-cond") + assert second is False + + # Unknown id: nothing to flip, must return False. 
+ unknown = persister.mark_suspension_resolved("does-not-exist") + assert unknown is False + + +# --- InMemoryPersister: load_journal ordering test --------------------------- + + +def test_in_memory_persister_journal_ordered_by_call_index(): + from burr.core.durable import JournalEntry + from burr.core.persistence import InMemoryPersister + + persister = InMemoryPersister() + # Insert out of order: 2, 0, 1 + for idx in (2, 0, 1): + persister.save_journal_entry( + JournalEntry( + partition_key="p", + app_id="a", + sequence_id=5, + step_key=f"step_{idx}", + call_index=idx, + result=f"result_{idx}", + ) + ) + journal = persister.load_journal("p", "a", 5) + assert [e.call_index for e in journal] == [0, 1, 2] + + +# --- ApplicationContext.suspend() tests --------------------------------------- + + +def _make_context(resume_signals=None, state_persister=None): + from burr.core.application import ApplicationContext + + return ApplicationContext( + app_id="a", partition_key="p", sequence_id=1, tracker=None, + parallel_executor_factory=lambda: None, state_initializer=None, + state_persister=state_persister, action_name="review", + _resume_signals=resume_signals or {}, + _loaded_journal=[], _journal_sink=[], + ) + + +def test_suspend_raises_on_first_call(): + from burr.core.durable import _Suspended + + ctx = _make_context() + with pytest.raises(_Suspended) as excinfo: + ctx.suspend("approval", metadata={"summary": "hi"}) + assert excinfo.value.channel == "approval" + assert excinfo.value.metadata == {"summary": "hi"} + + +def test_suspend_returns_payload_when_signal_present(): + ctx = _make_context(resume_signals={"approval": {"approved": True}}) + result = ctx.suspend("approval") + assert result == {"approved": True} + + +def test_suspend_validates_payload_against_live_schema(): + pydantic = pytest.importorskip("pydantic") + + class Approval(pydantic.BaseModel): + approved: bool + + ctx = _make_context(resume_signals={"approval": {"approved": True}}) + result = 
def test_suspend_first_call_schema_json_populated():
    """A first-call ``suspend()`` with a pydantic schema carries that model's JSON schema.

    With no resume signal present, ``suspend()`` raises the internal suspend
    signal; its ``schema_json`` must equal ``Approval.model_json_schema()``.
    Skipped when pydantic is not installed.
    """
    pydantic = pytest.importorskip("pydantic")
    # Import the private signal class here so the test is self-contained and
    # does not rely on a module-level import being present.
    from burr.core.durable import _Suspended

    class Approval(pydantic.BaseModel):
        approved: bool

    ctx = _make_context()
    with pytest.raises(_Suspended) as excinfo:
        ctx.suspend("approval", schema=Approval)
    assert excinfo.value.schema_json == Approval.model_json_schema()
@pytest.mark.asyncio
async def test_arun_stops_and_records_suspension():
    """The async run loop stops at a suspending action and persists a suspension record.

    Mirrors the sync run-loop test: ``astart`` writes ``seen``, ``agate`` calls
    ``suspend("approval")``, and after ``arun`` the application exposes the
    suspension while the persister holds an unresolved record with the
    pre-suspension state.
    """
    from burr.core import ApplicationBuilder, State, action
    from burr.core.persistence import InMemoryPersister

    @action(reads=[], writes=["seen"])
    async def astart(state):
        return state.update(seen=True)

    @action(reads=["seen"], writes=["done"])
    async def agate(state, __context):
        decision = __context.suspend("approval")
        return state.update(done=decision)

    persister = InMemoryPersister()
    app = (
        ApplicationBuilder()
        .with_actions(astart=astart, agate=agate)
        .with_transitions(("astart", "agate"))
        .with_entrypoint("astart")
        .with_state(State({}))
        .with_identifiers(app_id="app2", partition_key="pk2")
        .with_state_persister(persister)
        .build()
    )
    await app.arun(halt_after=["agate"])

    assert app.suspended is not None
    # Same in-memory checks as the sync run-loop test.
    assert app.suspended.channel == "approval"
    assert app.suspended.position == "agate"
    record = persister.load_suspension("pk2", "app2", "approval")
    assert record is not None
    assert record.channel == "approval"
    assert record.resolved is False
    assert record.state.get("seen") is True
def test_durable_replays_from_loaded_journal_without_executing_fn():
    """A journaled result is returned on replay and the wrapped function is never called."""
    from burr.core.durable import JournalEntry

    invocations = []

    def doubler(value):
        invocations.append(value)
        return value * 2

    ctx = _make_context()
    # Pre-load one recorded entry for key "double" at call index 0 with result 42.
    ctx._loaded_journal = [JournalEntry("p", "a", 1, "double", 0, 42)]

    replayed = ctx.durable("double", doubler, 21)

    assert replayed == 42
    assert invocations == []  # fn must NOT run on replay
def test_journal_no_double_count_via_stream_result():
    """Regression guard: step_a's journal entry must not be double-counted
    when stream_result() fast-forwards through it and then executes a
    non-streaming step_b.

    The guard in application.py is the reset of ``self._journal_sink`` made
    immediately before the direct ``self._step()`` call in the non-streaming
    branch of ``stream_result()``. Without that reset this test observes 3
    journal entries instead of 2.
    """
    from burr.core import ApplicationBuilder, State, action
    from burr.core.durable import read_journal_from_state

    @action(reads=[], writes=["a"])
    def step_a(state, __context):
        v = __context.durable("a_calc", lambda: 1)
        return state.update(a=v)

    @action(reads=["a"], writes=["b"])
    def step_b(state, __context):
        v = __context.durable("b_calc", lambda: 2)
        return state.update(b=v)

    persister = NonDurablePersister()
    app = (
        ApplicationBuilder()
        .with_actions(step_a=step_a, step_b=step_b)
        .with_transitions(("step_a", "step_b"))
        .with_entrypoint("step_a")
        .with_state(State({}))
        .with_identifiers(app_id="j3", partition_key="pk")
        .with_state_persister(persister)
        .build()
    )
    # step_a is NOT in halt_after, so stream_result fast-forwards through it
    # via self.run(), then hits the non-streaming branch for step_b.
    # The fix resets _journal_sink before that branch so step_a's entry is
    # not accumulated a second time into the persisted state.
    _, container = app.stream_result(halt_after=["step_b"])
    # Drain the container; only the persisted state is inspected below.
    container.get()

    # Verify via the persisted state (the source of truth for the bug).
    loaded = persister.load("pk", "j3")
    journal = read_journal_from_state(loaded["state"])
    assert len(journal) == 2, (
        f"Expected 2 journal entries (a_calc + b_calc), got {len(journal)}: "
        f"{[e.step_key for e in journal]}"
    )
    assert {e.step_key for e in journal} == {"a_calc", "b_calc"}
def test_post_action_suspend_hook_fires():
    """A registered ``PostActionSuspendHook`` is invoked once when an action suspends."""
    from burr.core import ApplicationBuilder, State, action
    from burr.core.persistence import InMemoryPersister
    from burr.lifecycle import PostActionSuspendHook

    observed = []

    class Recorder(PostActionSuspendHook):
        """Collects (channel, app_id, partition_key) for every suspend notification."""

        def post_action_suspend(self, *, app_id, partition_key, action,
                                sequence_id, suspension, **kw):
            observed.append((suspension.channel, app_id, partition_key))

    @action(reads=[], writes=["done"])
    def gate(state, __context):
        return state.update(done=__context.suspend("approval"))

    builder = ApplicationBuilder().with_actions(gate=gate)
    builder = builder.with_entrypoint("gate")
    builder = builder.with_state(State({}))
    builder = builder.with_identifiers(app_id="h1", partition_key="pk")
    builder = builder.with_state_persister(InMemoryPersister())
    builder = builder.with_hooks(Recorder())
    application = builder.build()

    application.run(halt_after=["gate"])

    assert observed == [("approval", "h1", "pk")]
class NonDurablePersister(BaseStatePersister):
    """Dict-backed persister that does NOT override any durable-storage methods.

    ``supports_durable_storage(NonDurablePersister())`` returns False because
    ``save_suspension`` is inherited unchanged from ``BaseStatePersister``.
    The Application therefore stores suspensions and journal entries inside
    the State blob (in-state fallback path).
    """

    def __init__(self):
        # partition_key -> app_id -> list of saved records, in save order.
        self._storage = defaultdict(lambda: defaultdict(list))

    def save(
        self,
        partition_key: Optional[str],
        app_id: str,
        sequence_id: int,
        position: str,
        state: "_State",
        status: Literal["completed", "failed", "suspended"],
        **kwargs,
    ):
        """Append one state record for ``(partition_key, app_id)``."""
        record = {
            "partition_key": partition_key or "",
            "app_id": app_id,
            "sequence_id": sequence_id,
            "position": position,
            "state": state,
            "created_at": datetime.datetime.now().isoformat(),
            "status": status,
        }
        self._storage[partition_key][app_id].append(record)

    def load(
        self,
        partition_key: str,
        app_id: Optional[str],
        sequence_id: Optional[int] = None,
        **kwargs,
    ):
        """Return the latest saved record, or the latest one matching ``sequence_id``.

        Returns ``None`` when ``app_id`` is ``None``, nothing was saved for the
        key pair, or no record matches ``sequence_id``.
        """
        if app_id is None:
            return None
        # Use .get() so a lookup miss does not create empty entries in the
        # nested defaultdicts; those would show up later in list_app_ids().
        states = self._storage.get(partition_key, {}).get(app_id)
        if not states:
            return None
        if sequence_id is None:
            return states[-1]
        matching = [s for s in states if s["sequence_id"] == sequence_id]
        return matching[-1] if matching else None

    def list_app_ids(self, partition_key: str, **kwargs):
        """Return the app ids that have at least one saved record in ``partition_key``."""
        # Read-only lookup: do not create an entry for an unknown partition.
        return list(self._storage.get(partition_key, {}))
def test_resume_is_idempotent():
    """Delivering the same resume signal twice yields the same ``done`` value both times."""
    store = InMemoryPersister()
    run_graph = _graph()
    suspended_app = _build(store, run_graph)
    suspended_app.run(halt_after=["gate"])

    def deliver():
        # A fresh payload dict per delivery, exactly like two separate webhook calls.
        return resume(
            persister=store, graph=run_graph, app_id="run1", partition_key="pk1",
            channel="approval", payload={"approved": True},
        )

    first = deliver()
    # Webhook retries are real: a second resume is a no-op.
    second = deliver()

    assert first["done"] == second["done"]
    assert second["done"] is True
def test_resume_through_in_state_fallback():
    """Resume uses the in-state fallback path when the persister does not support
    dedicated durable storage (supports_durable_storage() is False). NonDurablePersister
    does not override save_suspension, so it triggers the fallback path where
    suspension data rides inside the State blob."""
    persister = NonDurablePersister()

    graph = _graph()

    # First process: build app, run until it suspends at 'gate'.
    app = _build(persister, graph)
    app.run(halt_after=["gate"])
    assert app.suspended is not None

    # Reuse the same persister instance: NonDurablePersister keeps its records in
    # a per-instance dict, so a fresh instance would have nothing to resume from.
    final_state = resume(
        persister=persister,
        graph=graph,
        app_id="run1",
        partition_key="pk1",
        channel="approval",
        payload={"approved": True},
    )
    assert final_state["done"] is True
async def test_async_suspend_resume_with_async_durable_persister():
    """aresume() completes a suspended run end-to-end on an async durable persister.

    AsyncInMemoryPersister is used for the build, the suspending ``arun``, the
    ``aresume`` call, and the final lookup that confirms the suspension record
    was marked resolved.
    """
    from burr.core import aresume

    @action(reads=[], writes=["seen"])
    async def astart(state):
        return state.update(seen=True)

    @action(reads=["seen"], writes=["done"])
    async def agate(state, __context):
        decision = __context.suspend("approval")
        return state.update(done=decision["approved"])

    app_id = "async_durable_run1"
    partition = "pk1"

    graph_builder = GraphBuilder().with_actions(astart=astart, agate=agate)
    graph = graph_builder.with_transitions(("astart", "agate")).build()

    store = AsyncInMemoryPersister()
    app_builder = (
        ApplicationBuilder()
        .with_graph(graph)
        .with_entrypoint("astart")
        .with_state(State({}))
        .with_identifiers(app_id=app_id, partition_key=partition)
        .with_state_persister(store)
    )
    app = await app_builder.abuild()

    await app.arun(halt_after=["agate"])
    assert app.suspended is not None

    final_state = await aresume(
        persister=store,
        graph=graph,
        app_id=app_id,
        partition_key=partition,
        channel="approval",
        payload={"approved": True},
    )
    assert final_state["done"] is True

    # Confirm async mark_suspension_resolved was called.
    record = await store.load_suspension(partition, app_id, "approval")
    assert record.resolved is True
def test_nondeterministic_branch_raises_determinism_error():
    """Flipping the durable() branch between first run and resume raises DeterminismError.

    The first run journals key ``branch_a``; the toggle is flipped before
    ``resume`` so the re-run asks for ``branch_b`` at the same call index.
    The module-level toggle is always restored to ``True`` afterwards.
    """
    from burr.core.durable import DeterminismError

    _branch_toggle["value"] = True
    try:
        flaky_graph = (
            GraphBuilder()
            .with_actions(nondeterministic=nondeterministic)
            .with_transitions()
            .build()
        )
        store = InMemoryPersister()
        builder = ApplicationBuilder().with_graph(flaky_graph)
        builder = builder.with_entrypoint("nondeterministic")
        builder = builder.with_state(State({}))
        builder = builder.with_identifiers(app_id="det1", partition_key="pk")
        first_run = builder.with_state_persister(store).build()
        first_run.run(halt_after=["nondeterministic"])

        # Flip the branch before resume: the re-run takes branch_b.
        _branch_toggle["value"] = False
        with pytest.raises(DeterminismError):
            resume(
                persister=store, graph=flaky_graph, app_id="det1", partition_key="pk",
                channel="approval", payload={"ok": True},
            )
    finally:
        _branch_toggle["value"] = True
def test_example_application_suspends_and_resumes(tmp_path):
    """The shipped durable-execution example suspends at ``review`` and resumes.

    The example module is loaded from ``examples/durable-execution/application.py``
    under the temporary module name ``durable_example``; that name is removed
    from ``sys.modules`` again even when loading or the assertions fail.
    """
    import importlib.util
    import pathlib
    import sys

    path = (
        pathlib.Path(__file__).parents[2]
        / "examples" / "durable-execution" / "application.py"
    )
    spec = importlib.util.spec_from_file_location("durable_example", path)
    module = importlib.util.module_from_spec(spec)
    try:
        # Register and execute inside the try so a failing example import
        # cannot leave a half-initialized module in sys.modules.
        sys.modules["durable_example"] = module
        spec.loader.exec_module(module)

        # The example exposes build_application() and the graph it used.
        # Use tmp_path for the SQLite DB so the test does not pollute ~/.
        app, graph, persister = module.build_application(
            app_id="ex1", db_path=str(tmp_path / "durable.db")
        )
        app.run(halt_after=["review"])
        assert app.suspended is not None
        assert app.suspended.channel == "human_approval"

        final = resume(
            persister=persister, graph=graph, app_id="ex1", partition_key=None,
            channel="human_approval", payload={"approved": True},
        )
        assert "approved" in final
    finally:
        sys.modules.pop("durable_example", None)
@pytest.fixture
def sqlite_persister():
    """Yield an initialized in-memory ``SQLitePersister`` and close it on teardown."""
    persister = SQLitePersister.from_values(":memory:")
    persister.initialize()
    yield persister
    # Teardown: close the connection explicitly instead of leaving it to
    # garbage collection, so each test releases its database handle.
    persister.connection.close()
+ sqlite_persister.save_suspension(_record()) + sqlite_persister.mark_suspension_resolved("sus-1") + loaded = sqlite_persister.load_suspension("pk", "app", "approval") + assert loaded is not None + assert loaded.resolved is True + + +def test_sqlite_mark_resolved_is_conditional(sqlite_persister): + sqlite_persister.save_suspension(_record()) + first = sqlite_persister.mark_suspension_resolved("sus-1") + second = sqlite_persister.mark_suspension_resolved("sus-1") + # First call resolves a row; second call resolves nothing (resume-once). + assert first is True + assert second is False + + +def test_sqlite_journal_round_trip(sqlite_persister): + sqlite_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "summarize", 0, "result-a") + ) + sqlite_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "translate", 1, "result-b") + ) + journal = sqlite_persister.load_journal("pk", "app", 4) + assert [e.call_index for e in journal] == [0, 1] + assert [e.result for e in journal] == ["result-a", "result-b"] + + +# --------------------------------------------------------------------------- +# PostgreSQL durable storage tests — skipped unless BURR_CI_INTEGRATION_TESTS=true +# --------------------------------------------------------------------------- + + +@pytest.fixture +def pg_persister(): + from burr.integrations.persisters.b_psycopg2 import PostgreSQLPersister + + persister = PostgreSQLPersister.from_values( + db_name=os.environ.get("POSTGRES_DB", "postgres"), + user=os.environ.get("POSTGRES_USER", "postgres"), + password=os.environ.get("POSTGRES_PASSWORD", "postgres"), + host=os.environ.get("POSTGRES_HOST", "localhost"), + port=int(os.environ.get("POSTGRES_PORT", "5432")), + table_name="burr_state_durable_test", + ) + persister.initialize() + yield persister + # Teardown: drop durable + state tables so the next run starts clean. 
+ cursor = persister.connection.cursor() + cursor.execute("DROP TABLE IF EXISTS burr_suspensions") + cursor.execute("DROP TABLE IF EXISTS burr_journal") + cursor.execute("DROP TABLE IF EXISTS burr_state_durable_test") + persister.connection.commit() + persister.cleanup() + + +@_pg_integration +def test_postgres_supports_durable_storage(pg_persister): + assert supports_durable_storage(pg_persister) is True + + +@_pg_integration +def test_postgres_suspension_round_trip(pg_persister): + pg_persister.save_suspension(_record()) + loaded = pg_persister.load_suspension("pk", "app", "approval") + assert loaded.suspension_id == "sus-1" + assert loaded.state == {"draft": "d"} + assert loaded.inputs == {"x": 1} + assert loaded.schema_json == {"type": "object"} + assert loaded.resolved is False + + +@_pg_integration +def test_postgres_load_suspension_returns_resolved_record(pg_persister): + # Contract: load_suspension returns the record whether or not it is + # resolved; the caller checks record.resolved for resume-once idempotency. + pg_persister.save_suspension(_record()) + pg_persister.mark_suspension_resolved("sus-1") + loaded = pg_persister.load_suspension("pk", "app", "approval") + assert loaded is not None + assert loaded.resolved is True + + +@_pg_integration +def test_postgres_mark_resolved_is_conditional(pg_persister): + pg_persister.save_suspension(_record()) + first = pg_persister.mark_suspension_resolved("sus-1") + second = pg_persister.mark_suspension_resolved("sus-1") + # First call resolves a row; second call resolves nothing (resume-once). 
+ assert first is True + assert second is False + + +@_pg_integration +def test_postgres_journal_round_trip(pg_persister): + pg_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "summarize", 0, "result-a") + ) + pg_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "translate", 1, "result-b") + ) + journal = pg_persister.load_journal("pk", "app", 4) + assert [e.call_index for e in journal] == [0, 1] + assert [e.result for e in journal] == ["result-a", "result-b"] + + +# --------------------------------------------------------------------------- +# asyncpg durable storage tests — skipped unless BURR_CI_INTEGRATION_TESTS=true +# --------------------------------------------------------------------------- + +import pytest_asyncio + + +@pytest_asyncio.fixture +async def asyncpg_persister(): + from burr.integrations.persisters.b_asyncpg import AsyncPostgreSQLPersister + + persister = await AsyncPostgreSQLPersister.from_values( + db_name=os.environ.get("POSTGRES_DB", "postgres"), + user=os.environ.get("POSTGRES_USER", "postgres"), + password=os.environ.get("POSTGRES_PASSWORD", "postgres"), + host=os.environ.get("POSTGRES_HOST", "localhost"), + port=int(os.environ.get("POSTGRES_PORT", "5432")), + table_name="burr_state_asyncpg_durable_test", + ) + await persister.initialize() + yield persister + conn, acquired = await persister._get_connection() + try: + await conn.execute("DROP TABLE IF EXISTS burr_suspensions") + await conn.execute("DROP TABLE IF EXISTS burr_journal") + await conn.execute("DROP TABLE IF EXISTS burr_state_asyncpg_durable_test") + finally: + await persister._release_connection(conn, acquired) + await persister.cleanup() + + +@_pg_integration +@pytest.mark.asyncio +async def test_asyncpg_supports_durable_storage(asyncpg_persister): + assert supports_durable_storage(asyncpg_persister) is True + + +@_pg_integration +@pytest.mark.asyncio +async def test_asyncpg_suspension_round_trip(asyncpg_persister): + await 
asyncpg_persister.save_suspension(_record()) + loaded = await asyncpg_persister.load_suspension("pk", "app", "approval") + assert loaded.suspension_id == "sus-1" + assert loaded.state == {"draft": "d"} + assert loaded.inputs == {"x": 1} + assert loaded.schema_json == {"type": "object"} + assert loaded.resolved is False + + +@_pg_integration +@pytest.mark.asyncio +async def test_asyncpg_load_suspension_returns_resolved_record(asyncpg_persister): + # Contract: load_suspension returns the record whether or not it is + # resolved; the caller checks record.resolved for resume-once idempotency. + await asyncpg_persister.save_suspension(_record()) + await asyncpg_persister.mark_suspension_resolved("sus-1") + loaded = await asyncpg_persister.load_suspension("pk", "app", "approval") + assert loaded is not None + assert loaded.resolved is True + + +@_pg_integration +@pytest.mark.asyncio +async def test_asyncpg_mark_resolved_is_conditional(asyncpg_persister): + await asyncpg_persister.save_suspension(_record()) + first = await asyncpg_persister.mark_suspension_resolved("sus-1") + second = await asyncpg_persister.mark_suspension_resolved("sus-1") + # First call resolves a row; second call resolves nothing (resume-once). 
+ assert first is True + assert second is False + + +@_pg_integration +@pytest.mark.asyncio +async def test_asyncpg_journal_round_trip(asyncpg_persister): + await asyncpg_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "summarize", 0, "result-a") + ) + await asyncpg_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "translate", 1, "result-b") + ) + journal = await asyncpg_persister.load_journal("pk", "app", 4) + assert [e.call_index for e in journal] == [0, 1] + assert [e.result for e in journal] == ["result-a", "result-b"] + + +# --------------------------------------------------------------------------- +# aiosqlite durable storage tests — no integration marker, uses :memory: DB +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def aiosqlite_persister(): + from burr.integrations.persisters.b_aiosqlite import AsyncSQLitePersister + + persister = await AsyncSQLitePersister.from_values(db_path=":memory:") + await persister.initialize() + yield persister + await persister.connection.close() + + +@pytest.mark.asyncio +async def test_aiosqlite_supports_durable_storage(aiosqlite_persister): + from burr.core.durable import supports_durable_storage + + assert supports_durable_storage(aiosqlite_persister) is True + + +@pytest.mark.asyncio +async def test_aiosqlite_suspension_round_trip(aiosqlite_persister): + await aiosqlite_persister.save_suspension(_record()) + loaded = await aiosqlite_persister.load_suspension("pk", "app", "approval") + assert loaded.suspension_id == "sus-1" + assert loaded.state == {"draft": "d"} + assert loaded.inputs == {"x": 1} + assert loaded.schema_json == {"type": "object"} + assert loaded.resolved is False + + +@pytest.mark.asyncio +async def test_aiosqlite_load_suspension_returns_resolved_record(aiosqlite_persister): + # Contract: load_suspension returns the record whether or not it is + # resolved; the caller checks record.resolved for resume-once 
idempotency. + await aiosqlite_persister.save_suspension(_record()) + await aiosqlite_persister.mark_suspension_resolved("sus-1") + loaded = await aiosqlite_persister.load_suspension("pk", "app", "approval") + assert loaded is not None + assert loaded.resolved is True + + +@pytest.mark.asyncio +async def test_aiosqlite_mark_resolved_is_conditional(aiosqlite_persister): + await aiosqlite_persister.save_suspension(_record()) + first = await aiosqlite_persister.mark_suspension_resolved("sus-1") + second = await aiosqlite_persister.mark_suspension_resolved("sus-1") + # First call resolves a row; second call resolves nothing (resume-once). + assert first is True + assert second is False + + +@pytest.mark.asyncio +async def test_aiosqlite_journal_round_trip(aiosqlite_persister): + await aiosqlite_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "summarize", 0, "result-a") + ) + await aiosqlite_persister.save_journal_entry( + JournalEntry("pk", "app", 4, "translate", 1, "result-b") + ) + journal = await aiosqlite_persister.load_journal("pk", "app", 4) + assert [e.call_index for e in journal] == [0, 1] + assert [e.result for e in journal] == ["result-a", "result-b"] + + +# --------------------------------------------------------------------------- +# Redis durable storage tests — skipped unless BURR_CI_INTEGRATION_TESTS=true +# --------------------------------------------------------------------------- + + +@pytest.fixture +def redis_persister(): + from burr.integrations.persisters.b_redis import RedisBasePersister + + persister = RedisBasePersister.from_values( + host=os.environ.get("REDIS_HOST", "localhost"), + port=int(os.environ.get("REDIS_PORT", "6379")), + db=int(os.environ.get("REDIS_DB", "15")), + ) + persister.connection.flushdb() + yield persister + persister.connection.flushdb() + persister.connection.close() + + +@_pg_integration +def test_redis_supports_durable_storage(redis_persister): + assert supports_durable_storage(redis_persister) is True + + 
+@_pg_integration
+def test_redis_suspension_round_trip(redis_persister):
+    """Store a suspension in Redis and read each field back unchanged."""
+    redis_persister.save_suspension(_record())
+    restored = redis_persister.load_suspension("pk", "app", "approval")
+    expected_fields = {
+        "suspension_id": "sus-1",
+        "state": {"draft": "d"},
+        "inputs": {"x": 1},
+        "schema_json": {"type": "object"},
+    }
+    for field_name, expected in expected_fields.items():
+        assert getattr(restored, field_name) == expected
+    assert restored.resolved is False
+
+
+@_pg_integration
+def test_redis_load_suspension_returns_resolved_record(redis_persister):
+    """A resolved suspension is still returned by ``load_suspension``.
+
+    The persister does not filter on ``resolved``; callers inspect
+    ``record.resolved`` themselves to enforce resume-once idempotency.
+    """
+    redis_persister.save_suspension(_record())
+    redis_persister.mark_suspension_resolved("sus-1")
+    restored = redis_persister.load_suspension("pk", "app", "approval")
+    assert restored is not None
+    assert restored.resolved is True
+
+
+@_pg_integration
+def test_redis_mark_resolved_is_conditional(redis_persister):
+    """Resolving twice: the first call resolves, the second is a no-op (resume-once)."""
+    redis_persister.save_suspension(_record())
+    outcomes = [redis_persister.mark_suspension_resolved("sus-1") for _ in range(2)]
+    assert outcomes[0] is True
+    assert outcomes[1] is False
+
+
+@_pg_integration
+def test_redis_journal_round_trip(redis_persister):
+    """Save two journal entries and load them back with call indexes and results."""
+    steps = [("summarize", 0, "result-a"), ("translate", 1, "result-b")]
+    for step_key, call_index, result in steps:
+        entry = JournalEntry("pk", "app", 4, step_key, call_index, result)
+        redis_persister.save_journal_entry(entry)
+    replayed = redis_persister.load_journal("pk", "app", 4)
+    assert [item.call_index for item in replayed] == [0, 1]
+    assert [item.result for item in replayed] == ["result-a", "result-b"]
+
+
+# ---------------------------------------------------------------------------
+# Async Redis durable storage tests — skipped unless BURR_CI_INTEGRATION_TESTS=true
+# ---------------------------------------------------------------------------
+
+
+@pytest_asyncio.fixture
+async def async_redis_persister():
+    """Provide an ``AsyncRedisBasePersister`` on a database flushed before and after use."""
+    from burr.integrations.persisters.b_redis import AsyncRedisBasePersister
+
+    redis_host = os.environ.get("REDIS_HOST", "localhost")
+    redis_port = int(os.environ.get("REDIS_PORT", "6379"))
+    redis_db = int(os.environ.get("REDIS_DB", "15"))
+    persister = AsyncRedisBasePersister.from_values(host=redis_host, port=redis_port, db=redis_db)
+    await persister.connection.flushdb()
+    yield persister
+    await persister.connection.flushdb()
+    await persister.connection.aclose()
+
+
+@_pg_integration
+@pytest.mark.asyncio
+async def test_async_redis_supports_durable_storage(async_redis_persister):
+    """``supports_durable_storage`` is True for the async Redis persister."""
+    supported = supports_durable_storage(async_redis_persister)
+    assert supported is True
+
+
+@_pg_integration
+@pytest.mark.asyncio
+async def test_async_redis_suspension_round_trip(async_redis_persister):
+    """Store a suspension in Redis (async) and read each field back unchanged."""
+    await async_redis_persister.save_suspension(_record())
+    restored = await async_redis_persister.load_suspension("pk", "app", "approval")
+    expected_fields = {
+        "suspension_id": "sus-1",
+        "state": {"draft": "d"},
+        "inputs": {"x": 1},
+        "schema_json": {"type": "object"},
+    }
+    for field_name, expected in expected_fields.items():
+        assert getattr(restored, field_name) == expected
+    assert restored.resolved is False
+
+
+@_pg_integration
+@pytest.mark.asyncio
+async def test_async_redis_load_suspension_returns_resolved_record(async_redis_persister):
+    """A resolved suspension is still returned by ``load_suspension``.
+
+    The persister does not filter on ``resolved``; callers inspect
+    ``record.resolved`` themselves to enforce resume-once idempotency.
+    """
+    await async_redis_persister.save_suspension(_record())
+    await async_redis_persister.mark_suspension_resolved("sus-1")
+    restored = await async_redis_persister.load_suspension("pk", "app", "approval")
+    assert restored is not None
+    assert restored.resolved is True
+
+
+@_pg_integration
+@pytest.mark.asyncio
+async def test_async_redis_mark_resolved_is_conditional(async_redis_persister):
+    """Resolving twice: the first call resolves, the second is a no-op (resume-once)."""
+    await async_redis_persister.save_suspension(_record())
+    outcomes = [await async_redis_persister.mark_suspension_resolved("sus-1") for _ in range(2)]
+    assert outcomes[0] is True
+    assert outcomes[1] is False
+
+
+@_pg_integration
+@pytest.mark.asyncio
+async def test_async_redis_journal_round_trip(async_redis_persister):
+    """Save two journal entries and load them back with call indexes and results."""
+    steps = [("summarize", 0, "result-a"), ("translate", 1, "result-b")]
+    for step_key, call_index, result in steps:
+        entry = JournalEntry("pk", "app", 4, step_key, call_index, result)
+        await async_redis_persister.save_journal_entry(entry)
+    replayed = await async_redis_persister.load_journal("pk", "app", 4)
+    assert [item.call_index for item in replayed] == [0, 1]
+    assert [item.result for item in replayed] == ["result-a", "result-b"]
+
+
+# ---------------------------------------------------------------------------
+# MongoDB (pymongo) durable storage tests — skipped unless BURR_CI_INTEGRATION_TESTS=true
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def mongo_persister():
+    """Provide a ``MongoDBBasePersister`` on a database dropped before and after use."""
+    from burr.integrations.persisters.b_pymongo import MongoDBBasePersister
+    from pymongo import MongoClient
+
+    mongo_client = MongoClient(os.environ.get("MONGO_URI", "mongodb://localhost:27017"))
+    database = os.environ.get("MONGO_DB", "burr_durable_test")
+    # Dropping during setup as well absorbs state left behind by a prior run
+    # that died before teardown (OOM, Ctrl-C). Without it, a stale
+    # resolved=True row poisons mark_resolved_is_conditional.
+    mongo_client.drop_database(database)
+    persister = MongoDBBasePersister(
+        client=mongo_client,
+        db_name=database,
+        collection_name="burr_state_durable_test",
+    )
+    persister.initialize()
+    yield persister
+    mongo_client.drop_database(database)
+    mongo_client.close()
+
+
+@_pg_integration
+def test_pymongo_supports_durable_storage(mongo_persister):
+    """``supports_durable_storage`` is True for the pymongo persister."""
+    supported = supports_durable_storage(mongo_persister)
+    assert supported is True
+
+
+@_pg_integration
+def test_pymongo_suspension_round_trip(mongo_persister):
+    """Store a suspension in MongoDB and read each field back unchanged."""
+    mongo_persister.save_suspension(_record())
+    restored = mongo_persister.load_suspension("pk", "app", "approval")
+    expected_fields = {
+        "suspension_id": "sus-1",
+        "state": {"draft": "d"},
+        "inputs": {"x": 1},
+        "schema_json": {"type": "object"},
+    }
+    for field_name, expected in expected_fields.items():
+        assert getattr(restored, field_name) == expected
+    assert restored.resolved is False
+
+
+@_pg_integration
+def test_pymongo_load_suspension_returns_resolved_record(mongo_persister):
+    """A resolved suspension is still returned by ``load_suspension``.
+
+    The persister does not filter on ``resolved``; callers inspect
+    ``record.resolved`` themselves to enforce resume-once idempotency.
+    """
+    mongo_persister.save_suspension(_record())
+    mongo_persister.mark_suspension_resolved("sus-1")
+    restored = mongo_persister.load_suspension("pk", "app", "approval")
+    assert restored is not None
+    assert restored.resolved is True
+
+
+@_pg_integration
+def test_pymongo_mark_resolved_is_conditional(mongo_persister):
+    """Resolving twice: the first call resolves a row, the second resolves nothing."""
+    mongo_persister.save_suspension(_record())
+    outcomes = [mongo_persister.mark_suspension_resolved("sus-1") for _ in range(2)]
+    assert outcomes[0] is True
+    assert outcomes[1] is False
+
+
+@_pg_integration
+def test_pymongo_journal_round_trip(mongo_persister):
+    """Save two journal entries and load them back with call indexes and results."""
+    steps = [("summarize", 0, "result-a"), ("translate", 1, "result-b")]
+    for step_key, call_index, result in steps:
+        entry = JournalEntry("pk", "app", 4, step_key, call_index, result)
+        mongo_persister.save_journal_entry(entry)
+    replayed = mongo_persister.load_journal("pk", "app", 4)
+    assert [item.call_index for item in replayed] == [0, 1]
+    assert [item.result for item in replayed] == ["result-a", "result-b"]
+
+
+def test_deprecated_mongodb_shim_inherits_durable_storage():
+    """The deprecated ``b_mongodb.MongoDBBasePersister`` shim supports durable storage.
+
+    The shim subclasses the canonical pymongo persister, so the durable-storage
+    overrides must come through inheritance without being re-declared.
+    """
+    from unittest.mock import MagicMock
+
+    from burr.integrations.persisters.b_mongodb import (
+        MongoDBBasePersister as DeprecatedMongoShim,
+    )
+
+    # A mock client is enough: only the capability check is exercised.
+    shim = DeprecatedMongoShim(client=MagicMock(), db_name="x", collection_name="y")
+    assert supports_durable_storage(shim) is True
+
+
+def test_deprecated_postgresql_shim_inherits_durable_storage():
+    """The deprecated ``postgresql.PostgreSQLPersister`` shim supports durable storage.
+
+    The shim subclasses the canonical psycopg2 persister, so the durable-storage
+    overrides must come through inheritance without being re-declared. No real
+    database is contacted; the instance is built around a dummy connection and
+    only ``supports_durable_storage`` is checked.
+    """
+    from unittest.mock import MagicMock
+
+    from burr.integrations.persisters.postgresql import (
+        PostgreSQLPersister as DeprecatedShim,
+    )
+
+    dummy_connection = MagicMock()
+    shim = DeprecatedShim(connection=dummy_connection, table_name="burr_state_shim_test")
+    assert supports_durable_storage(shim) is True
diff --git a/tests/lifecycle/test_lifecycle.py b/tests/lifecycle/test_lifecycle.py
new file mode 100644
index 000000000..cbb39db7c
--- /dev/null
+++ b/tests/lifecycle/test_lifecycle.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+def test_suspend_resume_hooks_exist_and_are_exported():
+    """The four suspend/resume hook classes are exported and expose their hook method."""
+    from burr import lifecycle
+
+    hook_methods = {
+        "PostActionSuspendHook": "post_action_suspend",
+        "PostActionSuspendHookAsync": "post_action_suspend",
+        "PreActionResumeHook": "pre_action_resume",
+        "PreActionResumeHookAsync": "pre_action_resume",
+    }
+    for hook_class, method_name in hook_methods.items():
+        assert hasattr(getattr(lifecycle, hook_class), method_name)
diff --git a/tests/tracking/test_local_tracking_client.py b/tests/tracking/test_local_tracking_client.py
index 7a8196c0e..8e9b088fd 100644
--- a/tests/tracking/test_local_tracking_client.py
+++ b/tests/tracking/test_local_tracking_client.py
@@ -494,3 +494,47 @@ def test_local_tracking_client_copy():
     assert copy.project_id == tracking_client.project_id
     assert copy.serde_kwargs == tracking_client.serde_kwargs
     assert copy.storage_dir == tracking_client.storage_dir
+
+
+def test_application_tracks_suspended_run(tmpdir: str):
+    """A suspending action makes ``LocalTrackingClient`` write a ``suspend_entry`` log line."""
+    from burr.core.persistence import InMemoryPersister
+    from burr.tracking.common.models import SuspendEntryModel
+
+    project_name = "test_application_tracks_suspended_run"
+    app_id = str(uuid.uuid4())
+    log_dir = os.path.join(tmpdir, "tracking")
+
+    @action(reads=[], writes=[])
+    def suspending_action(state: State, __context) -> State:
+        __context.suspend("approval", metadata={"reason": "needs review"})
+        return state  # never reached
+
+    tracker = LocalTrackingClient(project=project_name, storage_dir=log_dir)
+    builder = (
+        ApplicationBuilder()
+        .with_actions(suspending_action)
+        .with_transitions(("suspending_action", "suspending_action", default))
+        .with_entrypoint("suspending_action")
+        .with_state(State({}))
+        .with_identifiers(app_id=app_id)
+        .with_tracker(tracker)
+        .with_state_persister(InMemoryPersister())
+    )
+    app = builder.build()
+    app.run(halt_after=["suspending_action"])
+
+    log_output = os.path.join(log_dir, project_name, app_id, LocalTrackingClient.LOG_FILENAME)
+    assert os.path.exists(log_output)
+
+    with open(log_output) as f:
+        log_records = [json.loads(raw_line) for raw_line in f]
+
+    suspend_entries = []
+    for log_record in log_records:
+        if log_record["type"] == "suspend_entry":
+            suspend_entries.append(SuspendEntryModel.model_validate(log_record))
+    assert len(suspend_entries) >= 1
+    first_entry = suspend_entries[0]
+    assert first_entry.channel == "approval"
+    assert first_entry.metadata == {"reason": "needs review"}