base_node.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from abc import ABC, abstractmethod
  2. from enum import Enum
  3. from typing import Optional
  4. from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
  5. from core.workflow.entities.base_node_data_entities import BaseNodeData
  6. from core.workflow.entities.node_entities import NodeRunResult, NodeType
  7. from core.workflow.entities.variable_pool import VariablePool
  8. class UserFrom(Enum):
  9. """
  10. User from
  11. """
  12. ACCOUNT = "account"
  13. END_USER = "end-user"
  14. @classmethod
  15. def value_of(cls, value: str) -> "UserFrom":
  16. """
  17. Value of
  18. :param value: value
  19. :return:
  20. """
  21. for item in cls:
  22. if item.value == value:
  23. return item
  24. raise ValueError(f"Invalid value: {value}")
  25. class BaseNode(ABC):
  26. _node_data_cls: type[BaseNodeData]
  27. _node_type: NodeType
  28. tenant_id: str
  29. app_id: str
  30. workflow_id: str
  31. user_id: str
  32. user_from: UserFrom
  33. node_id: str
  34. node_data: BaseNodeData
  35. node_run_result: Optional[NodeRunResult] = None
  36. callbacks: list[BaseWorkflowCallback]
  37. def __init__(self, tenant_id: str,
  38. app_id: str,
  39. workflow_id: str,
  40. user_id: str,
  41. user_from: UserFrom,
  42. config: dict,
  43. callbacks: list[BaseWorkflowCallback] = None) -> None:
  44. self.tenant_id = tenant_id
  45. self.app_id = app_id
  46. self.workflow_id = workflow_id
  47. self.user_id = user_id
  48. self.user_from = user_from
  49. self.node_id = config.get("id")
  50. if not self.node_id:
  51. raise ValueError("Node ID is required.")
  52. self.node_data = self._node_data_cls(**config.get("data", {}))
  53. self.callbacks = callbacks or []
  54. @abstractmethod
  55. def _run(self, variable_pool: VariablePool) -> NodeRunResult:
  56. """
  57. Run node
  58. :param variable_pool: variable pool
  59. :return:
  60. """
  61. raise NotImplementedError
  62. def run(self, variable_pool: VariablePool) -> NodeRunResult:
  63. """
  64. Run node entry
  65. :param variable_pool: variable pool
  66. :return:
  67. """
  68. result = self._run(
  69. variable_pool=variable_pool
  70. )
  71. self.node_run_result = result
  72. return result
  73. def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
  74. """
  75. Publish text chunk
  76. :param text: chunk text
  77. :param value_selector: value selector
  78. :return:
  79. """
  80. if self.callbacks:
  81. for callback in self.callbacks:
  82. callback.on_node_text_chunk(
  83. node_id=self.node_id,
  84. text=text,
  85. metadata={
  86. "node_type": self.node_type,
  87. "value_selector": value_selector
  88. }
  89. )
  90. @classmethod
  91. def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
  92. """
  93. Extract variable selector to variable mapping
  94. :param config: node config
  95. :return:
  96. """
  97. node_data = cls._node_data_cls(**config.get("data", {}))
  98. return cls._extract_variable_selector_to_variable_mapping(node_data)
  99. @classmethod
  100. @abstractmethod
  101. def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
  102. """
  103. Extract variable selector to variable mapping
  104. :param node_data: node data
  105. :return:
  106. """
  107. raise NotImplementedError
  108. @classmethod
  109. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  110. """
  111. Get default config of node.
  112. :param filters: filter by node config parameters.
  113. :return:
  114. """
  115. return {}
  116. @property
  117. def node_type(self) -> NodeType:
  118. """
  119. Get node type
  120. :return:
  121. """
  122. return self._node_type