variable_pool.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from enum import Enum
  2. from typing import Any, Optional, Union
  3. from core.file.file_obj import FileVar
  4. from core.workflow.entities.node_entities import SystemVariable
# Union of the Python types accepted as a variable value by VariablePool.
VariableValue = Union[str, int, float, dict, list, FileVar]
  6. class ValueType(Enum):
  7. """
  8. Value Type Enum
  9. """
  10. STRING = "string"
  11. NUMBER = "number"
  12. OBJECT = "object"
  13. ARRAY_STRING = "array[string]"
  14. ARRAY_NUMBER = "array[number]"
  15. ARRAY_OBJECT = "array[object]"
  16. ARRAY_FILE = "array[file]"
  17. FILE = "file"
  18. class VariablePool:
  19. variables_mapping = {}
  20. user_inputs: dict
  21. system_variables: dict[SystemVariable, Any]
  22. def __init__(self, system_variables: dict[SystemVariable, Any],
  23. user_inputs: dict) -> None:
  24. # system variables
  25. # for example:
  26. # {
  27. # 'query': 'abc',
  28. # 'files': []
  29. # }
  30. self.user_inputs = user_inputs
  31. self.system_variables = system_variables
  32. for system_variable, value in system_variables.items():
  33. self.append_variable('sys', [system_variable.value], value)
  34. def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
  35. """
  36. Append variable
  37. :param node_id: node id
  38. :param variable_key_list: variable key list, like: ['result', 'text']
  39. :param value: value
  40. :return:
  41. """
  42. if node_id not in self.variables_mapping:
  43. self.variables_mapping[node_id] = {}
  44. variable_key_list_hash = hash(tuple(variable_key_list))
  45. self.variables_mapping[node_id][variable_key_list_hash] = value
  46. def get_variable_value(self, variable_selector: list[str],
  47. target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
  48. """
  49. Get variable
  50. :param variable_selector: include node_id and variables
  51. :param target_value_type: target value type
  52. :return:
  53. """
  54. if len(variable_selector) < 2:
  55. raise ValueError('Invalid value selector')
  56. node_id = variable_selector[0]
  57. if node_id not in self.variables_mapping:
  58. return None
  59. # fetch variable keys, pop node_id
  60. variable_key_list = variable_selector[1:]
  61. variable_key_list_hash = hash(tuple(variable_key_list))
  62. value = self.variables_mapping[node_id].get(variable_key_list_hash)
  63. if target_value_type:
  64. if target_value_type == ValueType.STRING:
  65. return str(value)
  66. elif target_value_type == ValueType.NUMBER:
  67. return int(value)
  68. elif target_value_type == ValueType.OBJECT:
  69. if not isinstance(value, dict):
  70. raise ValueError('Invalid value type: object')
  71. elif target_value_type in [ValueType.ARRAY_STRING,
  72. ValueType.ARRAY_NUMBER,
  73. ValueType.ARRAY_OBJECT,
  74. ValueType.ARRAY_FILE]:
  75. if not isinstance(value, list):
  76. raise ValueError(f'Invalid value type: {target_value_type.value}')
  77. return value