entities.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from typing import Any, Literal, Union
  2. from pydantic import BaseModel, field_validator
  3. from pydantic_core.core_schema import ValidationInfo
  4. from core.workflow.entities.base_node_data_entities import BaseNodeData
  5. class ToolEntity(BaseModel):
  6. provider_id: str
  7. provider_type: Literal['builtin', 'api', 'workflow']
  8. provider_name: str # redundancy
  9. tool_name: str
  10. tool_label: str # redundancy
  11. tool_configurations: dict[str, Any]
  12. @field_validator('tool_configurations', mode='before')
  13. @classmethod
  14. def validate_tool_configurations(cls, value, values: ValidationInfo):
  15. if not isinstance(value, dict):
  16. raise ValueError('tool_configurations must be a dictionary')
  17. for key in values.data.get('tool_configurations', {}).keys():
  18. value = values.data.get('tool_configurations', {}).get(key)
  19. if not isinstance(value, str | int | float | bool):
  20. raise ValueError(f'{key} must be a string')
  21. return value
  22. class ToolNodeData(BaseNodeData, ToolEntity):
  23. class ToolInput(BaseModel):
  24. value: Union[Any, list[str]]
  25. type: Literal['mixed', 'variable', 'constant']
  26. @field_validator('type', mode='before')
  27. @classmethod
  28. def check_type(cls, value, validation_info: ValidationInfo):
  29. typ = value
  30. value = validation_info.data.get('value')
  31. if typ == 'mixed' and not isinstance(value, str):
  32. raise ValueError('value must be a string')
  33. elif typ == 'variable':
  34. if not isinstance(value, list):
  35. raise ValueError('value must be a list')
  36. for val in value:
  37. if not isinstance(val, str):
  38. raise ValueError('value must be a list of strings')
  39. elif typ == 'constant' and not isinstance(value, str | int | float | bool):
  40. raise ValueError('value must be a string, int, float, or bool')
  41. return typ
  42. """
  43. Tool Node Schema
  44. """
  45. tool_parameters: dict[str, ToolInput]