|
@@ -28,11 +28,13 @@ from core.app.entities.task_entities import (
|
|
|
WorkflowAppBlockingResponse,
|
|
|
WorkflowAppStreamResponse,
|
|
|
WorkflowFinishStreamResponse,
|
|
|
+ WorkflowStreamGenerateNodes,
|
|
|
WorkflowTaskState,
|
|
|
)
|
|
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
|
|
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
|
|
-from core.workflow.entities.node_entities import SystemVariable
|
|
|
+from core.workflow.entities.node_entities import NodeType, SystemVariable
|
|
|
+from core.workflow.nodes.end.end_node import EndNode
|
|
|
from extensions.ext_database import db
|
|
|
from models.account import Account
|
|
|
from models.model import EndUser
|
|
@@ -40,6 +42,7 @@ from models.workflow import (
|
|
|
Workflow,
|
|
|
WorkflowAppLog,
|
|
|
WorkflowAppLogCreatedFrom,
|
|
|
+ WorkflowNodeExecution,
|
|
|
WorkflowRun,
|
|
|
)
|
|
|
|
|
@@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|
|
}
|
|
|
|
|
|
self._task_state = WorkflowTaskState()
|
|
|
+ self._stream_generate_nodes = self._get_stream_generate_nodes()
|
|
|
|
|
|
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
|
|
"""
|
|
@@ -167,6 +171,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|
|
)
|
|
|
elif isinstance(event, QueueNodeStartedEvent):
|
|
|
workflow_node_execution = self._handle_node_start(event)
|
|
|
+
|
|
|
+ # search stream_generate_routes if node id is answer start at node
|
|
|
+ if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
|
|
|
+ self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
|
|
|
+
|
|
|
+ # generate stream outputs when node started
|
|
|
+ yield from self._generate_stream_outputs_when_node_started()
|
|
|
+
|
|
|
yield self._workflow_node_start_to_stream_response(
|
|
|
event=event,
|
|
|
task_id=self._application_generate_entity.task_id,
|
|
@@ -174,6 +186,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|
|
)
|
|
|
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
|
|
workflow_node_execution = self._handle_node_finished(event)
|
|
|
+
|
|
|
yield self._workflow_node_finish_to_stream_response(
|
|
|
task_id=self._application_generate_entity.task_id,
|
|
|
workflow_node_execution=workflow_node_execution
|
|
@@ -193,6 +206,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|
|
if delta_text is None:
|
|
|
continue
|
|
|
|
|
|
+ if not self._is_stream_out_support(
|
|
|
+ event=event
|
|
|
+ ):
|
|
|
+ continue
|
|
|
+
|
|
|
self._task_state.answer += delta_text
|
|
|
yield self._text_chunk_to_stream_response(delta_text)
|
|
|
elif isinstance(event, QueueMessageReplaceEvent):
|
|
@@ -254,3 +272,140 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|
|
task_id=self._application_generate_entity.task_id,
|
|
|
text=TextReplaceStreamResponse.Data(text=text)
|
|
|
)
|
|
|
+
|
|
|
+ def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
|
|
|
+ """
|
|
|
+ Get stream generate nodes.
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # find all answer nodes
|
|
|
+ graph = self._workflow.graph_dict
|
|
|
+ end_node_configs = [
|
|
|
+ node for node in graph['nodes']
|
|
|
+ if node.get('data', {}).get('type') == NodeType.END.value
|
|
|
+ ]
|
|
|
+
|
|
|
+ # parse stream output node value selectors of end nodes
|
|
|
+ stream_generate_routes = {}
|
|
|
+ for node_config in end_node_configs:
|
|
|
+ # get generate route for stream output
|
|
|
+ end_node_id = node_config['id']
|
|
|
+ generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
|
|
|
+ start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
|
|
|
+ if not start_node_ids:
|
|
|
+ continue
|
|
|
+
|
|
|
+ for start_node_id in start_node_ids:
|
|
|
+ stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
|
|
|
+ end_node_id=end_node_id,
|
|
|
+ stream_node_ids=generate_nodes
|
|
|
+ )
|
|
|
+
|
|
|
+ return stream_generate_routes
|
|
|
+
|
|
|
+ def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
|
|
+ -> list[str]:
|
|
|
+ """
|
|
|
+ Get end start at node id.
|
|
|
+ :param graph: graph
|
|
|
+ :param target_node_id: target node ID
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ nodes = graph.get('nodes')
|
|
|
+ edges = graph.get('edges')
|
|
|
+
|
|
|
+ # fetch all ingoing edges from source node
|
|
|
+ ingoing_edges = []
|
|
|
+ for edge in edges:
|
|
|
+ if edge.get('target') == target_node_id:
|
|
|
+ ingoing_edges.append(edge)
|
|
|
+
|
|
|
+ if not ingoing_edges:
|
|
|
+ return []
|
|
|
+
|
|
|
+ start_node_ids = []
|
|
|
+ for ingoing_edge in ingoing_edges:
|
|
|
+ source_node_id = ingoing_edge.get('source')
|
|
|
+ source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
|
|
|
+ if not source_node:
|
|
|
+ continue
|
|
|
+
|
|
|
+ node_type = source_node.get('data', {}).get('type')
|
|
|
+ if node_type in [
|
|
|
+ NodeType.IF_ELSE.value,
|
|
|
+ NodeType.QUESTION_CLASSIFIER.value
|
|
|
+ ]:
|
|
|
+ start_node_id = target_node_id
|
|
|
+ start_node_ids.append(start_node_id)
|
|
|
+ elif node_type == NodeType.START.value:
|
|
|
+ start_node_id = source_node_id
|
|
|
+ start_node_ids.append(start_node_id)
|
|
|
+ else:
|
|
|
+ sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
|
|
|
+ if sub_start_node_ids:
|
|
|
+ start_node_ids.extend(sub_start_node_ids)
|
|
|
+
|
|
|
+ return start_node_ids
|
|
|
+
|
|
|
+ def _generate_stream_outputs_when_node_started(self) -> Generator:
|
|
|
+ """
|
|
|
+ Generate stream outputs.
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ if self._task_state.current_stream_generate_state:
|
|
|
+ stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
|
|
|
+
|
|
|
+ for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
|
|
|
+ if node_id not in stream_node_ids:
|
|
|
+ continue
|
|
|
+
|
|
|
+ node_execution_info = self._task_state.ran_node_execution_infos[node_id]
|
|
|
+
|
|
|
+ # get chunk node execution
|
|
|
+ route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
|
|
+ WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
|
|
|
+
|
|
|
+ if not route_chunk_node_execution:
|
|
|
+ continue
|
|
|
+
|
|
|
+ outputs = route_chunk_node_execution.outputs_dict
|
|
|
+
|
|
|
+ if not outputs:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # get value from outputs
|
|
|
+ text = outputs.get('text')
|
|
|
+
|
|
|
+ if text:
|
|
|
+ self._task_state.answer += text
|
|
|
+ yield self._text_chunk_to_stream_response(text)
|
|
|
+
|
|
|
+ def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
|
|
|
+ """
|
|
|
+ Is stream out support
|
|
|
+ :param event: queue text chunk event
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ if not event.metadata:
|
|
|
+ return False
|
|
|
+
|
|
|
+ if 'node_id' not in event.metadata:
|
|
|
+ return False
|
|
|
+
|
|
|
+ node_id = event.metadata.get('node_id')
|
|
|
+ node_type = event.metadata.get('node_type')
|
|
|
+ stream_output_value_selector = event.metadata.get('value_selector')
|
|
|
+ if not stream_output_value_selector:
|
|
|
+ return False
|
|
|
+
|
|
|
+ if not self._task_state.current_stream_generate_state:
|
|
|
+ return False
|
|
|
+
|
|
|
+ if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
|
|
|
+ return False
|
|
|
+
|
|
|
+ if node_type != NodeType.LLM:
|
|
|
+ # only LLM support chunk stream output
|
|
|
+ return False
|
|
|
+
|
|
|
+ return True
|