@@ -1,12 +1,20 @@
 import logging
+import uuid
 from collections.abc import Generator, Mapping, Sequence
+from concurrent.futures import Future, wait
 from datetime import datetime, timezone
-from typing import Any, cast
+from queue import Empty, Queue
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+from flask import Flask, current_app

 from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.variables import IntegerSegment
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
+from core.workflow.entities.node_entities import (
+    NodeRunMetadataKey,
+    NodeRunResult,
+)
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     BaseGraphEvent,
     BaseNodeEvent,
@@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
     IterationRunNextEvent,
     IterationRunStartedEvent,
     IterationRunSucceededEvent,
+    NodeInIterationFailedEvent,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
 )
@@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
-from core.workflow.nodes.iteration.entities import IterationNodeData
+from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from models.workflow import WorkflowNodeExecutionStatus

+if TYPE_CHECKING:
+    from core.workflow.graph_engine.graph_engine import GraphEngine
 logger = logging.getLogger(__name__)


@@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
     _node_data_cls = IterationNodeData
     _node_type = NodeType.ITERATION

+    @classmethod
+    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+        return {
+            "type": "iteration",
+            "config": {
+                "is_parallel": False,
+                "parallel_nums": 10,
+                "error_handle_mode": ErrorHandleMode.TERMINATED.value,
+            },
+        }
+
     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
         """
         Run the node.
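For reference, the defaults added above can be exercised directly through the classmethod. A minimal sketch of a test (the test name, assertions, and the iteration_node module path are illustrative, not part of this change):

    from core.workflow.nodes.iteration.entities import ErrorHandleMode
    from core.workflow.nodes.iteration.iteration_node import IterationNode


    def test_iteration_default_config_exposes_parallel_options():
        # defaults introduced by this change: sequential execution, up to 10
        # parallel workers when enabled, and the TERMINATED error handle mode
        config = IterationNode.get_default_config()["config"]
        assert config["is_parallel"] is False
        assert config["parallel_nums"] == 10
        assert config["error_handle_mode"] == ErrorHandleMode.TERMINATED.value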
|
@@ -83,7 +107,7 @@
             variable_pool.add([self.node_id, "item"], iterator_list_value[0])

         # init graph engine
-        from core.workflow.graph_engine.graph_engine import GraphEngine
+        from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool

         graph_engine = GraphEngine(
             tenant_id=self.tenant_id,
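The GraphEngineThreadPool imported here is used below with max_workers, max_submit_count, and a task_done_callback registered on each future. As a rough mental model only (BoundedThreadPool is a made-up name, not the actual implementation in graph_engine.py), such a pool can be sketched on top of the standard ThreadPoolExecutor:

    from concurrent.futures import Future, ThreadPoolExecutor


    class BoundedThreadPool(ThreadPoolExecutor):
        """Hypothetical stand-in for GraphEngineThreadPool: it refuses new work
        once max_submit_count tasks have been submitted but not yet completed."""

        def __init__(self, max_workers: int, max_submit_count: int) -> None:
            super().__init__(max_workers=max_workers)
            self.max_submit_count = max_submit_count
            self.submit_count = 0

        def submit(self, fn, /, *args, **kwargs) -> Future:
            if self.submit_count >= self.max_submit_count:
                raise RuntimeError("max submit count reached")
            self.submit_count += 1
            return super().submit(fn, *args, **kwargs)

        def task_done_callback(self, future: Future) -> None:
            # meant to be registered on each future via future.add_done_callback(...)
            self.submit_count -= 1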
|
@@ -123,108 +147,64 @@
                 index=0,
                 pre_iteration_output=None,
             )
-
         outputs: list[Any] = []
         try:
-            for _ in range(len(iterator_list_value)):
-                # run workflow
-                rst = graph_engine.run()
-                for event in rst:
-                    if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
-                        event.in_iteration_id = self.node_id
-
-                    if (
-                        isinstance(event, BaseNodeEvent)
-                        and event.node_type == NodeType.ITERATION_START
-                        and not isinstance(event, NodeRunStreamChunkEvent)
-                    ):
-                        continue
-
-                    if isinstance(event, NodeRunSucceededEvent):
-                        if event.route_node_state.node_run_result:
-                            metadata = event.route_node_state.node_run_result.metadata
-                            if not metadata:
-                                metadata = {}
-
-                            if NodeRunMetadataKey.ITERATION_ID not in metadata:
-                                metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
-                                index_variable = variable_pool.get([self.node_id, "index"])
-                                if not isinstance(index_variable, IntegerSegment):
-                                    yield RunCompletedEvent(
-                                        run_result=NodeRunResult(
-                                            status=WorkflowNodeExecutionStatus.FAILED,
-                                            error=f"Invalid index variable type: {type(index_variable)}",
-                                        )
-                                    )
-                                    return
-                                metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
-                                event.route_node_state.node_run_result.metadata = metadata
-
-                        yield event
-                    elif isinstance(event, BaseGraphEvent):
-                        if isinstance(event, GraphRunFailedEvent):
-                            # iteration run failed
-                            yield IterationRunFailedEvent(
-                                iteration_id=self.id,
-                                iteration_node_id=self.node_id,
-                                iteration_node_type=self.node_type,
-                                iteration_node_data=self.node_data,
-                                start_at=start_at,
-                                inputs=inputs,
-                                outputs={"output": jsonable_encoder(outputs)},
-                                steps=len(iterator_list_value),
-                                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
-                                error=event.error,
-                            )
-
-                            yield RunCompletedEvent(
-                                run_result=NodeRunResult(
-                                    status=WorkflowNodeExecutionStatus.FAILED,
-                                    error=event.error,
-                                )
-                            )
-                            return
-                    else:
-                        event = cast(InNodeEvent, event)
+            if self.node_data.is_parallel:
+                futures: list[Future] = []
+                q = Queue()
+                thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
+                for index, item in enumerate(iterator_list_value):
+                    future: Future = thread_pool.submit(
+                        self._run_single_iter_parallel,
+                        current_app._get_current_object(),
+                        q,
+                        iterator_list_value,
+                        inputs,
+                        outputs,
+                        start_at,
+                        graph_engine,
+                        iteration_graph,
+                        index,
+                        item,
+                    )
+                    future.add_done_callback(thread_pool.task_done_callback)
+                    futures.append(future)
+                succeeded_count = 0
+                while True:
+                    try:
+                        event = q.get(timeout=1)
+                        if event is None:
+                            break
+                        if isinstance(event, IterationRunNextEvent):
+                            succeeded_count += 1
+                            if succeeded_count == len(futures):
+                                q.put(None)
                         yield event
+                        if isinstance(event, RunCompletedEvent):
+                            q.put(None)
+                            for f in futures:
+                                if not f.done():
+                                    f.cancel()
+                            yield event
+                        if isinstance(event, IterationRunFailedEvent):
+                            q.put(None)
+                            yield event
+                    except Empty:
+                        continue

-                # append to iteration output variable list
-                current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
-                if current_iteration_output_variable is None:
-                    yield RunCompletedEvent(
-                        run_result=NodeRunResult(
-                            status=WorkflowNodeExecutionStatus.FAILED,
-                            error=f"Iteration output variable {self.node_data.output_selector} not found",
-                        )
+                # wait for all threads to finish
+                wait(futures)
+            else:
+                for _ in range(len(iterator_list_value)):
+                    yield from self._run_single_iter(
+                        iterator_list_value,
+                        variable_pool,
+                        inputs,
+                        outputs,
+                        start_at,
+                        graph_engine,
+                        iteration_graph,
                     )
-                    return
-                current_iteration_output = current_iteration_output_variable.to_object()
-                outputs.append(current_iteration_output)
-
-                # remove all nodes outputs from variable pool
-                for node_id in iteration_graph.node_ids:
-                    variable_pool.remove([node_id])
-
-                # move to next iteration
-                current_index_variable = variable_pool.get([self.node_id, "index"])
-                if not isinstance(current_index_variable, IntegerSegment):
-                    raise ValueError(f"iteration {self.node_id} current index not found")
-
-                next_index = current_index_variable.value + 1
-                variable_pool.add([self.node_id, "index"], next_index)
-
-                if next_index < len(iterator_list_value):
-                    variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
-
-                yield IterationRunNextEvent(
-                    iteration_id=self.id,
-                    iteration_node_id=self.node_id,
-                    iteration_node_type=self.node_type,
-                    iteration_node_data=self.node_data,
-                    index=next_index,
-                    pre_iteration_output=jsonable_encoder(current_iteration_output),
-                )
-
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
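The parallel branch above follows a fan-out/drain pattern: each iteration is submitted to the thread pool, the workers push their events onto a shared Queue, the main generator drains the queue, a None sentinel ends the loop once every iteration has reported back (or a completion/failure event arrives), and wait(futures) joins the workers. A self-contained sketch of that pattern using only the standard library (all names here are illustrative, not Dify APIs):

    import queue
    from concurrent.futures import Future, ThreadPoolExecutor, wait


    def run_one(q: queue.Queue, index: int, item: int) -> None:
        # stand-in for _run_single_iter_parallel: do the work, then report back
        q.put(("next", index, item * item))


    def run_all(items: list[int]):
        q: queue.Queue = queue.Queue()
        with ThreadPoolExecutor(max_workers=4) as pool:
            futures: list[Future] = [pool.submit(run_one, q, i, item) for i, item in enumerate(items)]
            reported = 0
            while True:
                try:
                    event = q.get(timeout=1)
                except queue.Empty:
                    continue  # nothing yet, keep polling
                if event is None:
                    break  # sentinel: every iteration has reported back
                yield event
                reported += 1
                if reported == len(futures):
                    q.put(None)
            wait(futures)  # join the workers before finishing


    print(list(run_all([1, 2, 3])))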
|
@@ -330,3 +310,231 @@
         }

         return variable_mapping
+
+    def _handle_event_metadata(
+        self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
+    ) -> NodeRunStartedEvent | BaseNodeEvent:
+        """
+        Add iteration metadata to the event.
+        """
+        if not isinstance(event, BaseNodeEvent):
+            return event
+        if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
+            event.parallel_mode_run_id = parallel_mode_run_id
+            return event
+        if event.route_node_state.node_run_result:
+            metadata = event.route_node_state.node_run_result.metadata
+            if not metadata:
+                metadata = {}
+
+            if NodeRunMetadataKey.ITERATION_ID not in metadata:
+                metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
+                if self.node_data.is_parallel:
+                    metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
+                else:
+                    metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
+                event.route_node_state.node_run_result.metadata = metadata
+        return event
+
+    def _run_single_iter(
+        self,
+        iterator_list_value: list[str],
+        variable_pool: VariablePool,
+        inputs: dict[str, list],
+        outputs: list,
+        start_at: datetime,
+        graph_engine: "GraphEngine",
+        iteration_graph: Graph,
+        parallel_mode_run_id: Optional[str] = None,
+    ) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        """
+        Run a single iteration.
+        """
+        try:
+            rst = graph_engine.run()
+            # get current iteration index
+            index_variable = variable_pool.get([self.node_id, "index"])
+            if index_variable is None:
+                raise ValueError(f"iteration {self.node_id} current index not found")
+            current_index = index_variable.value
+            next_index = int(current_index) + 1
+            for event in rst:
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
+                    event.in_iteration_id = self.node_id
+
+                if (
+                    isinstance(event, BaseNodeEvent)
+                    and event.node_type == NodeType.ITERATION_START
+                    and not isinstance(event, NodeRunStreamChunkEvent)
+                ):
+                    continue
+
+                if isinstance(event, NodeRunSucceededEvent):
+                    yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
+                elif isinstance(event, BaseGraphEvent):
+                    if isinstance(event, GraphRunFailedEvent):
+                        # iteration run failed
+                        if self.node_data.is_parallel:
+                            yield IterationRunFailedEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                parallel_mode_run_id=parallel_mode_run_id,
+                                start_at=start_at,
+                                inputs=inputs,
+                                outputs={"output": jsonable_encoder(outputs)},
+                                steps=len(iterator_list_value),
+                                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+                                error=event.error,
+                            )
+                        else:
+                            yield IterationRunFailedEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                start_at=start_at,
+                                inputs=inputs,
+                                outputs={"output": jsonable_encoder(outputs)},
+                                steps=len(iterator_list_value),
+                                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+                                error=event.error,
+                            )
+                        yield RunCompletedEvent(
+                            run_result=NodeRunResult(
+                                status=WorkflowNodeExecutionStatus.FAILED,
+                                error=event.error,
+                            )
+                        )
+                        return
+                else:
+                    event = cast(InNodeEvent, event)
+                    metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
+                    if isinstance(event, NodeRunFailedEvent):
+                        if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
+                            yield NodeInIterationFailedEvent(
+                                **metadata_event.model_dump(),
+                            )
+                            outputs.insert(current_index, None)
+                            variable_pool.add([self.node_id, "index"], next_index)
+                            if next_index < len(iterator_list_value):
+                                variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
+                            yield IterationRunNextEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                index=next_index,
+                                parallel_mode_run_id=parallel_mode_run_id,
+                                pre_iteration_output=None,
+                            )
+                            return
+                        elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+                            yield NodeInIterationFailedEvent(
+                                **metadata_event.model_dump(),
+                            )
+                            variable_pool.add([self.node_id, "index"], next_index)
+
+                            if next_index < len(iterator_list_value):
+                                variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
+                            yield IterationRunNextEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                index=next_index,
+                                parallel_mode_run_id=parallel_mode_run_id,
+                                pre_iteration_output=None,
+                            )
+                            return
+                        elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
+                            yield IterationRunFailedEvent(
+                                iteration_id=self.id,
+                                iteration_node_id=self.node_id,
+                                iteration_node_type=self.node_type,
+                                iteration_node_data=self.node_data,
+                                start_at=start_at,
+                                inputs=inputs,
+                                outputs={"output": None},
+                                steps=len(iterator_list_value),
+                                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+                                error=event.error,
+                            )
+                    yield metadata_event
+
+            current_iteration_output = variable_pool.get(self.node_data.output_selector).value
+            outputs.insert(current_index, current_iteration_output)
+            # remove all nodes outputs from variable pool
+            for node_id in iteration_graph.node_ids:
+                variable_pool.remove([node_id])
+
+            # move to next iteration
+            variable_pool.add([self.node_id, "index"], next_index)
+
+            if next_index < len(iterator_list_value):
+                variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
+            yield IterationRunNextEvent(
+                iteration_id=self.id,
+                iteration_node_id=self.node_id,
+                iteration_node_type=self.node_type,
+                iteration_node_data=self.node_data,
+                index=next_index,
+                parallel_mode_run_id=parallel_mode_run_id,
+                pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
+            )
+
+        except Exception as e:
+            logger.exception(f"Iteration run failed: {str(e)}")
+            yield IterationRunFailedEvent(
+                iteration_id=self.id,
+                iteration_node_id=self.node_id,
+                iteration_node_type=self.node_type,
+                iteration_node_data=self.node_data,
+                start_at=start_at,
+                inputs=inputs,
+                outputs={"output": None},
+                steps=len(iterator_list_value),
+                metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+                error=str(e),
+            )
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                )
+            )
+
+    def _run_single_iter_parallel(
+        self,
+        flask_app: Flask,
+        q: Queue,
+        iterator_list_value: list[str],
+        inputs: dict[str, list],
+        outputs: list,
+        start_at: datetime,
+        graph_engine: "GraphEngine",
+        iteration_graph: Graph,
+        index: int,
+        item: Any,
+    ) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        """
+        Run a single iteration in parallel mode.
+        """
+        with flask_app.app_context():
+            parallel_mode_run_id = uuid.uuid4().hex
+            graph_engine_copy = graph_engine.create_copy()
+            variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
+            variable_pool_copy.add([self.node_id, "index"], index)
+            variable_pool_copy.add([self.node_id, "item"], item)
+            for event in self._run_single_iter(
+                iterator_list_value=iterator_list_value,
+                variable_pool=variable_pool_copy,
+                inputs=inputs,
+                outputs=outputs,
+                start_at=start_at,
+                graph_engine=graph_engine_copy,
+                iteration_graph=iteration_graph,
+                parallel_mode_run_id=parallel_mode_run_id,
+            ):
+                q.put(event)
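The three ErrorHandleMode branches in _run_single_iter shape the iteration's outputs differently: CONTINUE_ON_ERROR records a None placeholder for the failed step and moves on, REMOVE_ABNORMAL_OUTPUT moves on without recording anything, and TERMINATED fails the whole iteration run. A small sketch of that behavior (the enum values and the helper below are illustrative; the real enum lives in core.workflow.nodes.iteration.entities):

    from enum import Enum


    class ErrorHandleMode(str, Enum):
        # illustrative values; the real enum is defined in core.workflow.nodes.iteration.entities
        TERMINATED = "terminated"
        CONTINUE_ON_ERROR = "continue-on-error"
        REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"


    def assemble_outputs(results: list, mode: ErrorHandleMode) -> list:
        """Mimic how the three modes shape the iteration's outputs list;
        a result of None stands for a failed iteration."""
        outputs: list = []
        for index, result in enumerate(results):
            if result is not None:
                outputs.append(result)
            elif mode == ErrorHandleMode.CONTINUE_ON_ERROR:
                outputs.append(None)  # keep a placeholder slot for the failed step
            elif mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                continue  # drop the failed step's output entirely
            else:  # ErrorHandleMode.TERMINATED
                raise RuntimeError(f"iteration {index} failed, aborting the run")
        return outputs


    assert assemble_outputs(["a", None, "c"], ErrorHandleMode.CONTINUE_ON_ERROR) == ["a", None, "c"]
    assert assemble_outputs(["a", None, "c"], ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT) == ["a", "c"]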