base_node.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. from abc import ABC, abstractmethod
  2. from enum import Enum
  3. from typing import Optional
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
  6. from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
  7. from core.workflow.entities.node_entities import NodeRunResult, NodeType
  8. from core.workflow.entities.variable_pool import VariablePool
  9. class UserFrom(Enum):
  10. """
  11. User from
  12. """
  13. ACCOUNT = "account"
  14. END_USER = "end-user"
  15. @classmethod
  16. def value_of(cls, value: str) -> "UserFrom":
  17. """
  18. Value of
  19. :param value: value
  20. :return:
  21. """
  22. for item in cls:
  23. if item.value == value:
  24. return item
  25. raise ValueError(f"Invalid value: {value}")
  26. class BaseNode(ABC):
  27. _node_data_cls: type[BaseNodeData]
  28. _node_type: NodeType
  29. tenant_id: str
  30. app_id: str
  31. workflow_id: str
  32. user_id: str
  33. user_from: UserFrom
  34. invoke_from: InvokeFrom
  35. workflow_call_depth: int
  36. node_id: str
  37. node_data: BaseNodeData
  38. node_run_result: Optional[NodeRunResult] = None
  39. callbacks: list[BaseWorkflowCallback]
  40. def __init__(self, tenant_id: str,
  41. app_id: str,
  42. workflow_id: str,
  43. user_id: str,
  44. user_from: UserFrom,
  45. invoke_from: InvokeFrom,
  46. config: dict,
  47. callbacks: list[BaseWorkflowCallback] = None,
  48. workflow_call_depth: int = 0) -> None:
  49. self.tenant_id = tenant_id
  50. self.app_id = app_id
  51. self.workflow_id = workflow_id
  52. self.user_id = user_id
  53. self.user_from = user_from
  54. self.invoke_from = invoke_from
  55. self.workflow_call_depth = workflow_call_depth
  56. self.node_id = config.get("id")
  57. if not self.node_id:
  58. raise ValueError("Node ID is required.")
  59. self.node_data = self._node_data_cls(**config.get("data", {}))
  60. self.callbacks = callbacks or []
  61. @abstractmethod
  62. def _run(self, variable_pool: VariablePool) -> NodeRunResult:
  63. """
  64. Run node
  65. :param variable_pool: variable pool
  66. :return:
  67. """
  68. raise NotImplementedError
  69. def run(self, variable_pool: VariablePool) -> NodeRunResult:
  70. """
  71. Run node entry
  72. :param variable_pool: variable pool
  73. :return:
  74. """
  75. result = self._run(
  76. variable_pool=variable_pool
  77. )
  78. self.node_run_result = result
  79. return result
  80. def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
  81. """
  82. Publish text chunk
  83. :param text: chunk text
  84. :param value_selector: value selector
  85. :return:
  86. """
  87. if self.callbacks:
  88. for callback in self.callbacks:
  89. callback.on_node_text_chunk(
  90. node_id=self.node_id,
  91. text=text,
  92. metadata={
  93. "node_type": self.node_type,
  94. "value_selector": value_selector
  95. }
  96. )
  97. @classmethod
  98. def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
  99. """
  100. Extract variable selector to variable mapping
  101. :param config: node config
  102. :return:
  103. """
  104. node_data = cls._node_data_cls(**config.get("data", {}))
  105. return cls._extract_variable_selector_to_variable_mapping(node_data)
  106. @classmethod
  107. @abstractmethod
  108. def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
  109. """
  110. Extract variable selector to variable mapping
  111. :param node_data: node data
  112. :return:
  113. """
  114. raise NotImplementedError
  115. @classmethod
  116. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  117. """
  118. Get default config of node.
  119. :param filters: filter by node config parameters.
  120. :return:
  121. """
  122. return {}
  123. @property
  124. def node_type(self) -> NodeType:
  125. """
  126. Get node type
  127. :return:
  128. """
  129. return self._node_type
  130. class BaseIterationNode(BaseNode):
  131. @abstractmethod
  132. def _run(self, variable_pool: VariablePool) -> BaseIterationState:
  133. """
  134. Run node
  135. :param variable_pool: variable pool
  136. :return:
  137. """
  138. raise NotImplementedError
  139. def run(self, variable_pool: VariablePool) -> BaseIterationState:
  140. """
  141. Run node entry
  142. :param variable_pool: variable pool
  143. :return:
  144. """
  145. return self._run(variable_pool=variable_pool)
  146. def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  147. """
  148. Get next iteration start node id based on the graph.
  149. :param graph: graph
  150. :return: next node id
  151. """
  152. return self._get_next_iteration(variable_pool, state)
  153. @abstractmethod
  154. def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  155. """
  156. Get next iteration start node id based on the graph.
  157. :param graph: graph
  158. :return: next node id
  159. """
  160. raise NotImplementedError