variable_factory.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from collections.abc import Mapping
  2. from typing import Any
  3. from configs import dify_config
  4. from core.file import File
  5. from core.variables import (
  6. ArrayAnySegment,
  7. ArrayFileSegment,
  8. ArrayNumberSegment,
  9. ArrayNumberVariable,
  10. ArrayObjectSegment,
  11. ArrayObjectVariable,
  12. ArrayStringSegment,
  13. ArrayStringVariable,
  14. FileSegment,
  15. FloatSegment,
  16. FloatVariable,
  17. IntegerSegment,
  18. IntegerVariable,
  19. NoneSegment,
  20. ObjectSegment,
  21. ObjectVariable,
  22. SecretVariable,
  23. Segment,
  24. SegmentType,
  25. StringSegment,
  26. StringVariable,
  27. Variable,
  28. )
  29. from core.variables.exc import VariableError
  30. def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  31. if (value_type := mapping.get("value_type")) is None:
  32. raise VariableError("missing value type")
  33. if not mapping.get("name"):
  34. raise VariableError("missing name")
  35. if (value := mapping.get("value")) is None:
  36. raise VariableError("missing value")
  37. match value_type:
  38. case SegmentType.STRING:
  39. result = StringVariable.model_validate(mapping)
  40. case SegmentType.SECRET:
  41. result = SecretVariable.model_validate(mapping)
  42. case SegmentType.NUMBER if isinstance(value, int):
  43. result = IntegerVariable.model_validate(mapping)
  44. case SegmentType.NUMBER if isinstance(value, float):
  45. result = FloatVariable.model_validate(mapping)
  46. case SegmentType.NUMBER if not isinstance(value, float | int):
  47. raise VariableError(f"invalid number value {value}")
  48. case SegmentType.OBJECT if isinstance(value, dict):
  49. result = ObjectVariable.model_validate(mapping)
  50. case SegmentType.ARRAY_STRING if isinstance(value, list):
  51. result = ArrayStringVariable.model_validate(mapping)
  52. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  53. result = ArrayNumberVariable.model_validate(mapping)
  54. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  55. result = ArrayObjectVariable.model_validate(mapping)
  56. case _:
  57. raise VariableError(f"not supported value type {value_type}")
  58. if result.size > dify_config.MAX_VARIABLE_SIZE:
  59. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  60. return result
  61. def build_segment(value: Any, /) -> Segment:
  62. if value is None:
  63. return NoneSegment()
  64. if isinstance(value, str):
  65. return StringSegment(value=value)
  66. if isinstance(value, int):
  67. return IntegerSegment(value=value)
  68. if isinstance(value, float):
  69. return FloatSegment(value=value)
  70. if isinstance(value, dict):
  71. return ObjectSegment(value=value)
  72. if isinstance(value, File):
  73. return FileSegment(value=value)
  74. if isinstance(value, list):
  75. items = [build_segment(item) for item in value]
  76. types = {item.value_type for item in items}
  77. if len(types) != 1:
  78. return ArrayAnySegment(value=value)
  79. match types.pop():
  80. case SegmentType.STRING:
  81. return ArrayStringSegment(value=value)
  82. case SegmentType.NUMBER:
  83. return ArrayNumberSegment(value=value)
  84. case SegmentType.OBJECT:
  85. return ArrayObjectSegment(value=value)
  86. case SegmentType.FILE:
  87. return ArrayFileSegment(value=value)
  88. case _:
  89. raise ValueError(f"not supported value {value}")
  90. raise ValueError(f"not supported value {value}")