node_entities.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from collections.abc import Mapping
  2. from enum import Enum
  3. from typing import Any, Optional
  4. from pydantic import BaseModel
  5. from models.workflow import WorkflowNodeExecutionStatus
  6. class NodeType(Enum):
  7. """
  8. Node Types.
  9. """
  10. START = 'start'
  11. END = 'end'
  12. ANSWER = 'answer'
  13. LLM = 'llm'
  14. KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
  15. IF_ELSE = 'if-else'
  16. CODE = 'code'
  17. TEMPLATE_TRANSFORM = 'template-transform'
  18. QUESTION_CLASSIFIER = 'question-classifier'
  19. HTTP_REQUEST = 'http-request'
  20. TOOL = 'tool'
  21. VARIABLE_AGGREGATOR = 'variable-aggregator'
  22. # TODO: merge this into VARIABLE_AGGREGATOR
  23. VARIABLE_ASSIGNER = 'variable-assigner'
  24. LOOP = 'loop'
  25. ITERATION = 'iteration'
  26. PARAMETER_EXTRACTOR = 'parameter-extractor'
  27. CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
  28. @classmethod
  29. def value_of(cls, value: str) -> 'NodeType':
  30. """
  31. Get value of given node type.
  32. :param value: node type value
  33. :return: node type
  34. """
  35. for node_type in cls:
  36. if node_type.value == value:
  37. return node_type
  38. raise ValueError(f'invalid node type value {value}')
  39. class SystemVariable(Enum):
  40. """
  41. System Variables.
  42. """
  43. QUERY = 'query'
  44. FILES = 'files'
  45. CONVERSATION_ID = 'conversation_id'
  46. USER_ID = 'user_id'
  47. @classmethod
  48. def value_of(cls, value: str) -> 'SystemVariable':
  49. """
  50. Get value of given system variable.
  51. :param value: system variable value
  52. :return: system variable
  53. """
  54. for system_variable in cls:
  55. if system_variable.value == value:
  56. return system_variable
  57. raise ValueError(f'invalid system variable value {value}')
  58. class NodeRunMetadataKey(Enum):
  59. """
  60. Node Run Metadata Key.
  61. """
  62. TOTAL_TOKENS = 'total_tokens'
  63. TOTAL_PRICE = 'total_price'
  64. CURRENCY = 'currency'
  65. TOOL_INFO = 'tool_info'
  66. ITERATION_ID = 'iteration_id'
  67. ITERATION_INDEX = 'iteration_index'
  68. class NodeRunResult(BaseModel):
  69. """
  70. Node Run Result.
  71. """
  72. status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
  73. inputs: Optional[Mapping[str, Any]] = None # node inputs
  74. process_data: Optional[dict] = None # process data
  75. outputs: Optional[Mapping[str, Any]] = None # node outputs
  76. metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
  77. edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
  78. error: Optional[str] = None # error message if status is failed