Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions core/workflow/engine/callbacks/callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,16 @@ class ChatCallBackStreamResult:
node_answer_content: LLMGenerate
"""Generated content from the node execution."""

ordered_stream_key: str = ""
"""Internal queue key used to isolate ordered streaming sessions."""

finish_reason: str = ""
"""Reason for node completion. 'stop' indicates normal completion, empty string otherwise."""

def __post_init__(self) -> None:
    """Fall back to the node id as the ordered-stream key when none was provided."""
    # `or` keeps any truthy key the caller supplied and substitutes node_id otherwise.
    self.ordered_stream_key = self.ordered_stream_key or self.node_id


class ChatCallBacks:
"""
Expand All @@ -74,6 +81,7 @@ def __init__(
chains: Chains,
event_id: str,
flow_id: str,
ordered_stream_namespace: str = "",
):
"""
Initialize the chat callback handler.
Expand All @@ -97,6 +105,7 @@ def __init__(
self.chains = chains
self.event_id = event_id
self.flow_id = flow_id
self.ordered_stream_namespace = ordered_stream_namespace

if chains:
self.all_simple_paths_node_cnt = chains.get_all_simple_paths_node_cnt()
Expand Down Expand Up @@ -411,13 +420,22 @@ async def _put_frame_into_queue(
await self.order_stream_result_q.put(
ChatCallBackStreamResult(
node_id=node_id,
ordered_stream_key=self._build_ordered_stream_key(node_id),
node_answer_content=resp,
finish_reason=finish_reason,
)
)
else:
await self.stream_queue.put(resp)

def _build_ordered_stream_key(self, node_id: str) -> str:
"""
Build the internal ordered-stream queue key for one node execution session.
"""
if not self.ordered_stream_namespace:
return node_id
return f"{self.ordered_stream_namespace}::{node_id}"


class ChatCallBackConsumer:
"""
Expand Down Expand Up @@ -464,11 +482,12 @@ async def consume(self) -> None:
result: ChatCallBackStreamResult = (
await self.need_order_stream_result_q.get()
)
if result.node_id not in self.support_stream_node_id_set:
await self._add_node_in_q(result.node_id)
if result.node_id not in self.structured_data:
self.structured_data[result.node_id] = Queue()
await self.structured_data[result.node_id].put(result)
stream_key = result.ordered_stream_key
if stream_key not in self.support_stream_node_id_set:
await self._add_node_in_q(stream_key)
if stream_key not in self.structured_data:
self.structured_data[stream_key] = Queue()
await self.structured_data[stream_key].put(result)
# Workflow execution completed
if (
result.node_id.split("::")[0] == NodeType.END.value
Expand Down Expand Up @@ -577,7 +596,7 @@ async def order_stream_output(self, node_id: str) -> None:
if isinstance(result, ChatCallBackStreamResult):
await self.stream_queue.put(result.node_answer_content)
if result.finish_reason == ChatStatus.FINISH_REASON.value:
self.support_stream_node_id_set.remove(node_id)
self.support_stream_node_id_set.discard(node_id)
self.structured_data.pop(node_id)
break
else:
Expand Down
133 changes: 114 additions & 19 deletions core/workflow/engine/dsl_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,13 +931,14 @@ async def _execute_single_node(
# Node start callback
await self._handle_node_start_callback(node)

# Execute node
run_result, fail_branch = await self._execute_node_with_retry(
node, span_context
)

# Mark node as complete
self.engine_ctx.node_run_status[node.node_id].complete.set()
try:
# Execute node
run_result, fail_branch = await self._execute_node_with_retry(
node, span_context
)
finally:
# Message/end nodes may be waiting on this event even when execution fails.
self.engine_ctx.node_run_status[node.node_id].complete.set()

# Get next batch of active nodes
next_active_nodes, next_inactive_nodes = await self._get_next_nodes(
Expand Down Expand Up @@ -1427,17 +1428,52 @@ async def _wait_all_tasks_completion(self, span: Span) -> None:
if not self.engine_ctx.dfs_tasks:
return

done, pending = await asyncio.wait(
self.engine_ctx.dfs_tasks, return_when=asyncio.FIRST_EXCEPTION
)
exceptions: List[Exception] = []
handled_tasks: Set[Task] = set()

# Cancel all pending tasks and ensure they are awaited
await self._cancel_pending_task(pending)
await self._process_dfs_tasks(exceptions, handled_tasks)
self._validate_responses(exceptions)
self._raise_exceptions_if_any(exceptions, span)

exceptions: List[Exception] = []
async def _process_dfs_tasks(
self, exceptions: List[Exception], handled_tasks: Set[Task]
) -> None:
"""
Process all DFS tasks until completion or first exception.
"""
while True:
current_tasks = {
task for task in self.engine_ctx.dfs_tasks if task not in handled_tasks
}
if not current_tasks:
break

# Check if completed tasks have exceptions
for task in done:
done, pending = await asyncio.wait(
current_tasks, return_when=asyncio.FIRST_EXCEPTION
)

handled_tasks.update(done)
await self._cancel_pending_task(pending)
handled_tasks.update(pending)

await self._collect_task_results(done, exceptions)

if exceptions:
await self._cleanup_remaining_tasks(handled_tasks)
handled_tasks.update(
task
for task in self.engine_ctx.dfs_tasks
if task not in handled_tasks and not task.done()
)
break

async def _collect_task_results(
self, done_tasks: Set[Task], exceptions: List[Exception]
) -> None:
"""
Collect results from completed tasks and handle end nodes.
"""
for task in done_tasks:
try:
if task.cancelled():
continue
Expand All @@ -1446,18 +1482,36 @@ async def _wait_all_tasks_completion(self, span: Span) -> None:
except Exception as e:
exceptions.append(e)

async def _cleanup_remaining_tasks(self, handled_tasks: Set[Task]) -> None:
    """
    Cancel every DFS task that is neither handled yet nor already finished,
    then await the cancellations.

    :param handled_tasks: Tasks whose results/cancellation were already processed
    """
    leftovers = set()
    for task in self.engine_ctx.dfs_tasks:
        # Skip tasks we already accounted for and tasks that finished on their own.
        if task not in handled_tasks and not task.done():
            leftovers.add(task)
    await self._cancel_pending_task(leftovers)

def _validate_responses(self, exceptions: List[Exception]) -> None:
    """
    Ensure at least one end-node response was collected.

    When no response exists, an engine-run error is appended to ``exceptions``
    for the caller to raise later; nothing is raised here directly.

    :param exceptions: Mutable list the validation error is appended to
    """
    if self.engine_ctx.responses:
        return
    exceptions.append(
        CustomException(
            CodeEnum.ENG_RUN_ERROR, err_msg="End node did not return result"
        )
    )

def _raise_exceptions_if_any(self, exceptions: List[Exception], span: Span) -> None:
"""
Record and raise the first exception if any occurred.
"""
if exceptions:
for exception in exceptions:
span.record_exception(exception)
raise exceptions[0]
return None

def _validate_start_node(self) -> None:
"""
Expand Down Expand Up @@ -1630,6 +1684,23 @@ async def _execute_message_node(
node = self.engine_ctx.built_nodes[msg_node_id]
strategy = self.strategy_manager.get_strategy(node.node_id.split("::")[0])
return await strategy.execute_node(node, self.engine_ctx, span_context)
except asyncio.CancelledError:
node = self.engine_ctx.built_nodes[msg_node_id]
result = NodeRunResult(
node_id=node.node_id,
alias_name=node.node_alias_name,
node_type=node.node_id.split("::")[0],
status=WorkflowNodeExecutionStatus.CANCELLED,
inputs={},
outputs={},
node_answer_content="",
)
await self.engine_ctx.callback.on_node_end(
node_id=node.node_id,
alias_name=node.node_alias_name,
message=result,
)
return result
finally:
self.engine_ctx.node_run_status[msg_node_id].complete.set()

Expand Down Expand Up @@ -1838,6 +1909,25 @@ def create_engine(

return builder.build()

@staticmethod
def create_iteration_engine(
    sparkflow_dsl: WorkflowDSL,
    iteration_node_id: str,
    span: Span,
    sub_dsl: WorkflowDSL | None = None,
) -> WorkflowEngine:
    """
    Create an isolated workflow engine for a single iteration subgraph.

    :param sparkflow_dsl: Full workflow DSL definition
    :param iteration_node_id: Iteration container node ID
    :param span: Tracing span for observability
    :param sub_dsl: Pre-extracted iteration sub-DSL; when ``None`` it is
        derived from ``sparkflow_dsl`` via ``extract_iteration_sub_dsl``
    :return: WorkflowEngine instance for the iteration subgraph
    """
    if sub_dsl is None:
        sub_dsl = sparkflow_dsl.extract_iteration_sub_dsl(iteration_node_id)
    return WorkflowEngineFactory.create_engine(sub_dsl, span)

@staticmethod
def create_debug_node(
sparkflow_dsl: WorkflowDSL,
Expand Down Expand Up @@ -2099,6 +2189,8 @@ def _handle_special_node_types(

if node_type == NodeType.START.value:
self.start_node_id = node.id
elif node_type == NodeType.ITERATION_START.value and not self.start_node_id:
self.start_node_id = node.id
elif node_type == NodeType.DECISION_MAKING.value:
self._handle_decision_making_node(node.id, node)
elif node_type == NodeType.LLM.value:
Expand Down Expand Up @@ -2229,17 +2321,20 @@ def _build_data_dependencies(self) -> None:

inputs = node.data.inputs
for input_item in inputs:
ref_node_id = None
var_type = input_item.input_schema.value.type
if var_type == ValueType.LITERAL.value:
continue

content = input_item.input_schema.value.content
if isinstance(content, NodeRef):
ref_node_id = content.nodeId

if ref_node_id:
if (
ref_node_id
and ref_node_id != node.id
and node.id in self.msg_or_end_node_deps
):
self.msg_or_end_node_deps[node.id].data_dep.add(ref_node_id)

# Check if normal path exists
if self._has_normal_path(ref_node_id, node.id):
self.msg_or_end_node_deps[node.id].data_dep_path_info[
Expand Down
Loading
Loading