variable_pool.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from enum import Enum
  2. from typing import Any, Optional, Union
  3. from core.file.file_obj import FileVar
  4. from core.workflow.entities.node_entities import SystemVariable
# Union of the value types a VariablePool entry may hold; used as the
# annotation for stored values and for values returned from the pool.
# FileVar is imported from core.file.file_obj.
VariableValue = Union[str, int, float, dict, list, FileVar]
  6. class ValueType(Enum):
  7. """
  8. Value Type Enum
  9. """
  10. STRING = "string"
  11. NUMBER = "number"
  12. OBJECT = "object"
  13. ARRAY_STRING = "array[string]"
  14. ARRAY_NUMBER = "array[number]"
  15. ARRAY_OBJECT = "array[object]"
  16. ARRAY_FILE = "array[file]"
  17. FILE = "file"
  18. class VariablePool:
  19. def __init__(self, system_variables: dict[SystemVariable, Any],
  20. user_inputs: dict) -> None:
  21. # system variables
  22. # for example:
  23. # {
  24. # 'query': 'abc',
  25. # 'files': []
  26. # }
  27. self.variables_mapping = {}
  28. self.user_inputs = user_inputs
  29. self.system_variables = system_variables
  30. for system_variable, value in system_variables.items():
  31. self.append_variable('sys', [system_variable.value], value)
  32. def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
  33. """
  34. Append variable
  35. :param node_id: node id
  36. :param variable_key_list: variable key list, like: ['result', 'text']
  37. :param value: value
  38. :return:
  39. """
  40. if node_id not in self.variables_mapping:
  41. self.variables_mapping[node_id] = {}
  42. variable_key_list_hash = hash(tuple(variable_key_list))
  43. self.variables_mapping[node_id][variable_key_list_hash] = value
  44. def get_variable_value(self, variable_selector: list[str],
  45. target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]:
  46. """
  47. Get variable
  48. :param variable_selector: include node_id and variables
  49. :param target_value_type: target value type
  50. :return:
  51. """
  52. if len(variable_selector) < 2:
  53. raise ValueError('Invalid value selector')
  54. node_id = variable_selector[0]
  55. if node_id not in self.variables_mapping:
  56. return None
  57. # fetch variable keys, pop node_id
  58. variable_key_list = variable_selector[1:]
  59. variable_key_list_hash = hash(tuple(variable_key_list))
  60. value = self.variables_mapping[node_id].get(variable_key_list_hash)
  61. if target_value_type:
  62. if target_value_type == ValueType.STRING:
  63. return str(value)
  64. elif target_value_type == ValueType.NUMBER:
  65. return int(value)
  66. elif target_value_type == ValueType.OBJECT:
  67. if not isinstance(value, dict):
  68. raise ValueError('Invalid value type: object')
  69. elif target_value_type in [ValueType.ARRAY_STRING,
  70. ValueType.ARRAY_NUMBER,
  71. ValueType.ARRAY_OBJECT,
  72. ValueType.ARRAY_FILE]:
  73. if not isinstance(value, list):
  74. raise ValueError(f'Invalid value type: {target_value_type.value}')
  75. return value
  76. def clear_node_variables(self, node_id: str) -> None:
  77. """
  78. Clear node variables
  79. :param node_id: node id
  80. :return:
  81. """
  82. if node_id in self.variables_mapping:
  83. self.variables_mapping.pop(node_id)