node_entities.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from enum import Enum
  2. from typing import Any, Optional
  3. from pydantic import BaseModel
  4. from core.model_runtime.entities.llm_entities import LLMUsage
  5. from models import WorkflowNodeExecutionStatus
  6. class NodeType(Enum):
  7. """
  8. Node Types.
  9. """
  10. START = "start"
  11. END = "end"
  12. ANSWER = "answer"
  13. LLM = "llm"
  14. KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
  15. IF_ELSE = "if-else"
  16. CODE = "code"
  17. TEMPLATE_TRANSFORM = "template-transform"
  18. QUESTION_CLASSIFIER = "question-classifier"
  19. HTTP_REQUEST = "http-request"
  20. TOOL = "tool"
  21. VARIABLE_AGGREGATOR = "variable-aggregator"
  22. # TODO: merge this into VARIABLE_AGGREGATOR
  23. VARIABLE_ASSIGNER = "variable-assigner"
  24. LOOP = "loop"
  25. ITERATION = "iteration"
  26. ITERATION_START = "iteration-start" # fake start node for iteration
  27. PARAMETER_EXTRACTOR = "parameter-extractor"
  28. CONVERSATION_VARIABLE_ASSIGNER = "assigner"
  29. @classmethod
  30. def value_of(cls, value: str) -> "NodeType":
  31. """
  32. Get value of given node type.
  33. :param value: node type value
  34. :return: node type
  35. """
  36. for node_type in cls:
  37. if node_type.value == value:
  38. return node_type
  39. raise ValueError(f"invalid node type value {value}")
class NodeRunMetadataKey(Enum):
    """
    Node Run Metadata Key.

    Used as the key type of ``NodeRunResult.metadata``
    (annotated there as ``dict[NodeRunMetadataKey, Any]``). Each member's value
    is its lower-case snake_case string name.
    """

    # NOTE(review): the group descriptions below are inferred from member names
    # only — confirm against the code that populates NodeRunResult.metadata.

    # presumably token / cost accounting for the run
    TOTAL_TOKENS = "total_tokens"
    TOTAL_PRICE = "total_price"
    CURRENCY = "currency"
    # presumably details of a tool invocation
    TOOL_INFO = "tool_info"
    # presumably the node's position inside an iteration
    ITERATION_ID = "iteration_id"
    ITERATION_INDEX = "iteration_index"
    # presumably identifiers of the (possibly nested) parallel branch
    PARALLEL_ID = "parallel_id"
    PARALLEL_START_NODE_ID = "parallel_start_node_id"
    PARENT_PARALLEL_ID = "parent_parallel_id"
    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
class NodeRunResult(BaseModel):
    """
    Node Run Result.

    Pydantic model describing the result of running a node. Every field has a
    default, so ``NodeRunResult()`` is valid: ``status`` defaults to
    ``WorkflowNodeExecutionStatus.RUNNING`` and all other fields to ``None``.
    """

    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
    inputs: Optional[dict[str, Any]] = None  # node inputs
    process_data: Optional[dict[str, Any]] = None  # process data
    outputs: Optional[dict[str, Any]] = None  # node outputs
    metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata, keyed by NodeRunMetadataKey
    llm_usage: Optional[LLMUsage] = None  # llm usage
    edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
    error: Optional[str] = None  # error message if status is failed
  66. class UserFrom(Enum):
  67. """
  68. User from
  69. """
  70. ACCOUNT = "account"
  71. END_USER = "end-user"
  72. @classmethod
  73. def value_of(cls, value: str) -> "UserFrom":
  74. """
  75. Value of
  76. :param value: value
  77. :return:
  78. """
  79. for item in cls:
  80. if item.value == value:
  81. return item
  82. raise ValueError(f"Invalid value: {value}")