base_node.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Mapping, Sequence
  3. from enum import Enum
  4. from typing import Any, Optional
  5. from core.app.entities.app_invoke_entities import InvokeFrom
  6. from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
  7. from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
  8. from core.workflow.entities.node_entities import NodeRunResult, NodeType
  9. from core.workflow.entities.variable_pool import VariablePool
  10. from models import WorkflowNodeExecutionStatus
  11. class UserFrom(Enum):
  12. """
  13. User from
  14. """
  15. ACCOUNT = "account"
  16. END_USER = "end-user"
  17. @classmethod
  18. def value_of(cls, value: str) -> "UserFrom":
  19. """
  20. Value of
  21. :param value: value
  22. :return:
  23. """
  24. for item in cls:
  25. if item.value == value:
  26. return item
  27. raise ValueError(f"Invalid value: {value}")
  28. class BaseNode(ABC):
  29. _node_data_cls: type[BaseNodeData]
  30. _node_type: NodeType
  31. tenant_id: str
  32. app_id: str
  33. workflow_id: str
  34. user_id: str
  35. user_from: UserFrom
  36. invoke_from: InvokeFrom
  37. workflow_call_depth: int
  38. node_id: str
  39. node_data: BaseNodeData
  40. node_run_result: Optional[NodeRunResult] = None
  41. callbacks: Sequence[WorkflowCallback]
  42. is_answer_previous_node: bool = False
  43. def __init__(self, tenant_id: str,
  44. app_id: str,
  45. workflow_id: str,
  46. user_id: str,
  47. user_from: UserFrom,
  48. invoke_from: InvokeFrom,
  49. config: Mapping[str, Any],
  50. callbacks: Sequence[WorkflowCallback] | None = None,
  51. workflow_call_depth: int = 0) -> None:
  52. self.tenant_id = tenant_id
  53. self.app_id = app_id
  54. self.workflow_id = workflow_id
  55. self.user_id = user_id
  56. self.user_from = user_from
  57. self.invoke_from = invoke_from
  58. self.workflow_call_depth = workflow_call_depth
  59. # TODO: May need to check if key exists.
  60. self.node_id = config["id"]
  61. if not self.node_id:
  62. raise ValueError("Node ID is required.")
  63. self.node_data = self._node_data_cls(**config.get("data", {}))
  64. self.callbacks = callbacks or []
  65. @abstractmethod
  66. def _run(self, variable_pool: VariablePool) -> NodeRunResult:
  67. """
  68. Run node
  69. :param variable_pool: variable pool
  70. :return:
  71. """
  72. raise NotImplementedError
  73. def run(self, variable_pool: VariablePool) -> NodeRunResult:
  74. """
  75. Run node entry
  76. :param variable_pool: variable pool
  77. :return:
  78. """
  79. try:
  80. result = self._run(
  81. variable_pool=variable_pool
  82. )
  83. self.node_run_result = result
  84. return result
  85. except Exception as e:
  86. return NodeRunResult(
  87. status=WorkflowNodeExecutionStatus.FAILED,
  88. error=str(e),
  89. )
  90. def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
  91. """
  92. Publish text chunk
  93. :param text: chunk text
  94. :param value_selector: value selector
  95. :return:
  96. """
  97. if self.callbacks:
  98. for callback in self.callbacks:
  99. callback.on_node_text_chunk(
  100. node_id=self.node_id,
  101. text=text,
  102. metadata={
  103. "node_type": self.node_type,
  104. "is_answer_previous_node": self.is_answer_previous_node,
  105. "value_selector": value_selector
  106. }
  107. )
  108. @classmethod
  109. def extract_variable_selector_to_variable_mapping(cls, config: dict):
  110. """
  111. Extract variable selector to variable mapping
  112. :param config: node config
  113. :return:
  114. """
  115. node_data = cls._node_data_cls(**config.get("data", {}))
  116. return cls._extract_variable_selector_to_variable_mapping(node_data)
  117. @classmethod
  118. def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
  119. """
  120. Extract variable selector to variable mapping
  121. :param node_data: node data
  122. :return:
  123. """
  124. return {}
  125. @classmethod
  126. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  127. """
  128. Get default config of node.
  129. :param filters: filter by node config parameters.
  130. :return:
  131. """
  132. return {}
  133. @property
  134. def node_type(self) -> NodeType:
  135. """
  136. Get node type
  137. :return:
  138. """
  139. return self._node_type
  140. class BaseIterationNode(BaseNode):
  141. @abstractmethod
  142. def _run(self, variable_pool: VariablePool) -> BaseIterationState:
  143. """
  144. Run node
  145. :param variable_pool: variable pool
  146. :return:
  147. """
  148. raise NotImplementedError
  149. def run(self, variable_pool: VariablePool) -> BaseIterationState:
  150. """
  151. Run node entry
  152. :param variable_pool: variable pool
  153. :return:
  154. """
  155. return self._run(variable_pool=variable_pool)
  156. def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  157. """
  158. Get next iteration start node id based on the graph.
  159. :param graph: graph
  160. :return: next node id
  161. """
  162. return self._get_next_iteration(variable_pool, state)
  163. @abstractmethod
  164. def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  165. """
  166. Get next iteration start node id based on the graph.
  167. :param graph: graph
  168. :return: next node id
  169. """
  170. raise NotImplementedError