# entities.py

  1. from typing import Any, Literal, Union
  2. from pydantic import BaseModel, validator
  3. from core.workflow.entities.base_node_data_entities import BaseNodeData
  4. class ToolEntity(BaseModel):
  5. provider_id: str
  6. provider_type: Literal['builtin', 'api']
  7. provider_name: str # redundancy
  8. tool_name: str
  9. tool_label: str # redundancy
  10. tool_configurations: dict[str, Any]
  11. @validator('tool_configurations', pre=True, always=True)
  12. def validate_tool_configurations(cls, value, values):
  13. if not isinstance(value, dict):
  14. raise ValueError('tool_configurations must be a dictionary')
  15. for key in values.get('tool_configurations', {}).keys():
  16. value = values.get('tool_configurations', {}).get(key)
  17. if not isinstance(value, str | int | float | bool):
  18. raise ValueError(f'{key} must be a string')
  19. return value
  20. class ToolNodeData(BaseNodeData, ToolEntity):
  21. class ToolInput(BaseModel):
  22. value: Union[Any, list[str]]
  23. type: Literal['mixed', 'variable', 'constant']
  24. @validator('type', pre=True, always=True)
  25. def check_type(cls, value, values):
  26. typ = value
  27. value = values.get('value')
  28. if typ == 'mixed' and not isinstance(value, str):
  29. raise ValueError('value must be a string')
  30. elif typ == 'variable':
  31. if not isinstance(value, list):
  32. raise ValueError('value must be a list')
  33. for val in value:
  34. if not isinstance(val, str):
  35. raise ValueError('value must be a list of strings')
  36. elif typ == 'constant' and not isinstance(value, str | int | float | bool):
  37. raise ValueError('value must be a string, int, float, or bool')
  38. return typ
  39. """
  40. Tool Node Schema
  41. """
  42. tool_parameters: dict[str, ToolInput]