base_node.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Mapping, Sequence
  3. from enum import Enum
  4. from typing import Any, Optional
  5. from core.app.entities.app_invoke_entities import InvokeFrom
  6. from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
  7. from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
  8. from core.workflow.entities.node_entities import NodeRunResult, NodeType
  9. from core.workflow.entities.variable_pool import VariablePool
  10. class UserFrom(Enum):
  11. """
  12. User from
  13. """
  14. ACCOUNT = "account"
  15. END_USER = "end-user"
  16. @classmethod
  17. def value_of(cls, value: str) -> "UserFrom":
  18. """
  19. Value of
  20. :param value: value
  21. :return:
  22. """
  23. for item in cls:
  24. if item.value == value:
  25. return item
  26. raise ValueError(f"Invalid value: {value}")
  27. class BaseNode(ABC):
  28. _node_data_cls: type[BaseNodeData]
  29. _node_type: NodeType
  30. tenant_id: str
  31. app_id: str
  32. workflow_id: str
  33. user_id: str
  34. user_from: UserFrom
  35. invoke_from: InvokeFrom
  36. workflow_call_depth: int
  37. node_id: str
  38. node_data: BaseNodeData
  39. node_run_result: Optional[NodeRunResult] = None
  40. callbacks: Sequence[WorkflowCallback]
  41. is_answer_previous_node: bool = False
  42. def __init__(self, tenant_id: str,
  43. app_id: str,
  44. workflow_id: str,
  45. user_id: str,
  46. user_from: UserFrom,
  47. invoke_from: InvokeFrom,
  48. config: Mapping[str, Any],
  49. callbacks: Sequence[WorkflowCallback] | None = None,
  50. workflow_call_depth: int = 0) -> None:
  51. self.tenant_id = tenant_id
  52. self.app_id = app_id
  53. self.workflow_id = workflow_id
  54. self.user_id = user_id
  55. self.user_from = user_from
  56. self.invoke_from = invoke_from
  57. self.workflow_call_depth = workflow_call_depth
  58. # TODO: May need to check if key exists.
  59. self.node_id = config["id"]
  60. if not self.node_id:
  61. raise ValueError("Node ID is required.")
  62. self.node_data = self._node_data_cls(**config.get("data", {}))
  63. self.callbacks = callbacks or []
  64. @abstractmethod
  65. def _run(self, variable_pool: VariablePool) -> NodeRunResult:
  66. """
  67. Run node
  68. :param variable_pool: variable pool
  69. :return:
  70. """
  71. raise NotImplementedError
  72. def run(self, variable_pool: VariablePool) -> NodeRunResult:
  73. """
  74. Run node entry
  75. :param variable_pool: variable pool
  76. :return:
  77. """
  78. result = self._run(
  79. variable_pool=variable_pool
  80. )
  81. self.node_run_result = result
  82. return result
  83. def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
  84. """
  85. Publish text chunk
  86. :param text: chunk text
  87. :param value_selector: value selector
  88. :return:
  89. """
  90. if self.callbacks:
  91. for callback in self.callbacks:
  92. callback.on_node_text_chunk(
  93. node_id=self.node_id,
  94. text=text,
  95. metadata={
  96. "node_type": self.node_type,
  97. "is_answer_previous_node": self.is_answer_previous_node,
  98. "value_selector": value_selector
  99. }
  100. )
  101. @classmethod
  102. def extract_variable_selector_to_variable_mapping(cls, config: dict):
  103. """
  104. Extract variable selector to variable mapping
  105. :param config: node config
  106. :return:
  107. """
  108. node_data = cls._node_data_cls(**config.get("data", {}))
  109. return cls._extract_variable_selector_to_variable_mapping(node_data)
  110. @classmethod
  111. def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
  112. """
  113. Extract variable selector to variable mapping
  114. :param node_data: node data
  115. :return:
  116. """
  117. return {}
  118. @classmethod
  119. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  120. """
  121. Get default config of node.
  122. :param filters: filter by node config parameters.
  123. :return:
  124. """
  125. return {}
  126. @property
  127. def node_type(self) -> NodeType:
  128. """
  129. Get node type
  130. :return:
  131. """
  132. return self._node_type
  133. class BaseIterationNode(BaseNode):
  134. @abstractmethod
  135. def _run(self, variable_pool: VariablePool) -> BaseIterationState:
  136. """
  137. Run node
  138. :param variable_pool: variable pool
  139. :return:
  140. """
  141. raise NotImplementedError
  142. def run(self, variable_pool: VariablePool) -> BaseIterationState:
  143. """
  144. Run node entry
  145. :param variable_pool: variable pool
  146. :return:
  147. """
  148. return self._run(variable_pool=variable_pool)
  149. def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  150. """
  151. Get next iteration start node id based on the graph.
  152. :param graph: graph
  153. :return: next node id
  154. """
  155. return self._get_next_iteration(variable_pool, state)
  156. @abstractmethod
  157. def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  158. """
  159. Get next iteration start node id based on the graph.
  160. :param graph: graph
  161. :return: next node id
  162. """
  163. raise NotImplementedError