variable_pool.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from collections import defaultdict
  2. from collections.abc import Mapping, Sequence
  3. from typing import Any, Union
  4. from pydantic import BaseModel, Field, model_validator
  5. from typing_extensions import deprecated
  6. from core.app.segments import Segment, Variable, factory
  7. from core.file.file_obj import FileVar
  8. from core.workflow.enums import SystemVariableKey
# Plain Python values accepted by ``VariablePool.add``; anything that is not already a
# ``Segment`` is converted with ``factory.build_segment`` before being stored.
VariableValue = Union[str, int, float, dict, list, FileVar]

# Reserved node ids used as the first selector element for the system, environment and
# conversation variables that ``VariablePool`` loads when it is constructed.
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
  13. class VariablePool(BaseModel):
  14. # Variable dictionary is a dictionary for looking up variables by their selector.
  15. # The first element of the selector is the node id, it's the first-level key in the dictionary.
  16. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
  17. # elements of the selector except the first one.
  18. variable_dictionary: dict[str, dict[int, Segment]] = Field(
  19. description='Variables mapping',
  20. default=defaultdict(dict)
  21. )
  22. # TODO: This user inputs is not used for pool.
  23. user_inputs: Mapping[str, Any] = Field(
  24. description='User inputs',
  25. )
  26. system_variables: Mapping[SystemVariableKey, Any] = Field(
  27. description='System variables',
  28. )
  29. environment_variables: Sequence[Variable] = Field(
  30. description="Environment variables.",
  31. default_factory=list
  32. )
  33. conversation_variables: Sequence[Variable] | None = None
  34. @model_validator(mode="after")
  35. def val_model_after(self):
  36. """
  37. Append system variables
  38. :return:
  39. """
  40. # Add system variables to the variable pool
  41. for key, value in self.system_variables.items():
  42. self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
  43. # Add environment variables to the variable pool
  44. for var in self.environment_variables or []:
  45. self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
  46. # Add conversation variables to the variable pool
  47. for var in self.conversation_variables or []:
  48. self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
  49. return self
  50. def add(self, selector: Sequence[str], value: Any, /) -> None:
  51. """
  52. Adds a variable to the variable pool.
  53. Args:
  54. selector (Sequence[str]): The selector for the variable.
  55. value (VariableValue): The value of the variable.
  56. Raises:
  57. ValueError: If the selector is invalid.
  58. Returns:
  59. None
  60. """
  61. if len(selector) < 2:
  62. raise ValueError("Invalid selector")
  63. if value is None:
  64. return
  65. if isinstance(value, Segment):
  66. v = value
  67. else:
  68. v = factory.build_segment(value)
  69. hash_key = hash(tuple(selector[1:]))
  70. self.variable_dictionary[selector[0]][hash_key] = v
  71. def get(self, selector: Sequence[str], /) -> Segment | None:
  72. """
  73. Retrieves the value from the variable pool based on the given selector.
  74. Args:
  75. selector (Sequence[str]): The selector used to identify the variable.
  76. Returns:
  77. Any: The value associated with the given selector.
  78. Raises:
  79. ValueError: If the selector is invalid.
  80. """
  81. if len(selector) < 2:
  82. raise ValueError("Invalid selector")
  83. hash_key = hash(tuple(selector[1:]))
  84. value = self.variable_dictionary[selector[0]].get(hash_key)
  85. return value
  86. @deprecated("This method is deprecated, use `get` instead.")
  87. def get_any(self, selector: Sequence[str], /) -> Any | None:
  88. """
  89. Retrieves the value from the variable pool based on the given selector.
  90. Args:
  91. selector (Sequence[str]): The selector used to identify the variable.
  92. Returns:
  93. Any: The value associated with the given selector.
  94. Raises:
  95. ValueError: If the selector is invalid.
  96. """
  97. if len(selector) < 2:
  98. raise ValueError("Invalid selector")
  99. hash_key = hash(tuple(selector[1:]))
  100. value = self.variable_dictionary[selector[0]].get(hash_key)
  101. return value.to_object() if value else None
  102. def remove(self, selector: Sequence[str], /):
  103. """
  104. Remove variables from the variable pool based on the given selector.
  105. Args:
  106. selector (Sequence[str]): A sequence of strings representing the selector.
  107. Returns:
  108. None
  109. """
  110. if not selector:
  111. return
  112. if len(selector) == 1:
  113. self.variable_dictionary[selector[0]] = {}
  114. return
  115. hash_key = hash(tuple(selector[1:]))
  116. self.variable_dictionary[selector[0]].pop(hash_key, None)
  117. def remove_node(self, node_id: str, /):
  118. """
  119. Remove all variables associated with a given node id.
  120. Args:
  121. node_id (str): The node id to remove.
  122. Returns:
  123. None
  124. """
  125. self.variable_dictionary.pop(node_id, None)