variable_factory.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from collections.abc import Mapping
  2. from typing import Any
  3. from configs import dify_config
  4. from core.file import File
  5. from core.variables import (
  6. ArrayAnySegment,
  7. ArrayFileSegment,
  8. ArrayNumberSegment,
  9. ArrayNumberVariable,
  10. ArrayObjectSegment,
  11. ArrayObjectVariable,
  12. ArraySegment,
  13. ArrayStringSegment,
  14. ArrayStringVariable,
  15. FileSegment,
  16. FloatSegment,
  17. FloatVariable,
  18. IntegerSegment,
  19. IntegerVariable,
  20. NoneSegment,
  21. ObjectSegment,
  22. ObjectVariable,
  23. SecretVariable,
  24. Segment,
  25. SegmentType,
  26. StringSegment,
  27. StringVariable,
  28. Variable,
  29. )
  30. from core.variables.exc import VariableError
  31. def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  32. if (value_type := mapping.get("value_type")) is None:
  33. raise VariableError("missing value type")
  34. if not mapping.get("name"):
  35. raise VariableError("missing name")
  36. if (value := mapping.get("value")) is None:
  37. raise VariableError("missing value")
  38. match value_type:
  39. case SegmentType.STRING:
  40. result = StringVariable.model_validate(mapping)
  41. case SegmentType.SECRET:
  42. result = SecretVariable.model_validate(mapping)
  43. case SegmentType.NUMBER if isinstance(value, int):
  44. result = IntegerVariable.model_validate(mapping)
  45. case SegmentType.NUMBER if isinstance(value, float):
  46. result = FloatVariable.model_validate(mapping)
  47. case SegmentType.NUMBER if not isinstance(value, float | int):
  48. raise VariableError(f"invalid number value {value}")
  49. case SegmentType.OBJECT if isinstance(value, dict):
  50. result = ObjectVariable.model_validate(mapping)
  51. case SegmentType.ARRAY_STRING if isinstance(value, list):
  52. result = ArrayStringVariable.model_validate(mapping)
  53. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  54. result = ArrayNumberVariable.model_validate(mapping)
  55. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  56. result = ArrayObjectVariable.model_validate(mapping)
  57. case _:
  58. raise VariableError(f"not supported value type {value_type}")
  59. if result.size > dify_config.MAX_VARIABLE_SIZE:
  60. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  61. return result
  62. def build_segment(value: Any, /) -> Segment:
  63. if value is None:
  64. return NoneSegment()
  65. if isinstance(value, str):
  66. return StringSegment(value=value)
  67. if isinstance(value, int):
  68. return IntegerSegment(value=value)
  69. if isinstance(value, float):
  70. return FloatSegment(value=value)
  71. if isinstance(value, dict):
  72. return ObjectSegment(value=value)
  73. if isinstance(value, File):
  74. return FileSegment(value=value)
  75. if isinstance(value, list):
  76. items = [build_segment(item) for item in value]
  77. types = {item.value_type for item in items}
  78. if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
  79. return ArrayAnySegment(value=value)
  80. match types.pop():
  81. case SegmentType.STRING:
  82. return ArrayStringSegment(value=value)
  83. case SegmentType.NUMBER:
  84. return ArrayNumberSegment(value=value)
  85. case SegmentType.OBJECT:
  86. return ArrayObjectSegment(value=value)
  87. case SegmentType.FILE:
  88. return ArrayFileSegment(value=value)
  89. case SegmentType.NONE:
  90. return ArrayAnySegment(value=value)
  91. case _:
  92. raise ValueError(f"not supported value {value}")
  93. raise ValueError(f"not supported value {value}")