variable_pool.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import re
  2. from collections import defaultdict
  3. from collections.abc import Mapping, Sequence
  4. from typing import Any, Union
  5. from pydantic import BaseModel, Field
  6. from typing_extensions import deprecated
  7. from core.file import File, FileAttribute, file_manager
  8. from core.variables import Segment, SegmentGroup, Variable
  9. from core.variables.segments import FileSegment
  10. from factories import variable_factory
  11. from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
  12. from ..enums import SystemVariableKey
# Plain Python value types named as the `value` type in VariablePool.add's docstring.
# Non-Segment values passed to add() are converted with variable_factory.build_segment.
VariableValue = Union[str, int, float, dict, list, File]
  14. VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
  15. class VariablePool(BaseModel):
  16. # Variable dictionary is a dictionary for looking up variables by their selector.
  17. # The first element of the selector is the node id, it's the first-level key in the dictionary.
  18. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
  19. # elements of the selector except the first one.
  20. variable_dictionary: dict[str, dict[int, Segment]] = Field(
  21. description="Variables mapping",
  22. default=defaultdict(dict),
  23. )
  24. # TODO: This user inputs is not used for pool.
  25. user_inputs: Mapping[str, Any] = Field(
  26. description="User inputs",
  27. )
  28. system_variables: Mapping[SystemVariableKey, Any] = Field(
  29. description="System variables",
  30. )
  31. environment_variables: Sequence[Variable] = Field(
  32. description="Environment variables.",
  33. default_factory=list,
  34. )
  35. conversation_variables: Sequence[Variable] = Field(
  36. description="Conversation variables.",
  37. default_factory=list,
  38. )
  39. def __init__(
  40. self,
  41. *,
  42. system_variables: Mapping[SystemVariableKey, Any] | None = None,
  43. user_inputs: Mapping[str, Any] | None = None,
  44. environment_variables: Sequence[Variable] | None = None,
  45. conversation_variables: Sequence[Variable] | None = None,
  46. **kwargs,
  47. ):
  48. environment_variables = environment_variables or []
  49. conversation_variables = conversation_variables or []
  50. user_inputs = user_inputs or {}
  51. system_variables = system_variables or {}
  52. super().__init__(
  53. system_variables=system_variables,
  54. user_inputs=user_inputs,
  55. environment_variables=environment_variables,
  56. conversation_variables=conversation_variables,
  57. **kwargs,
  58. )
  59. for key, value in self.system_variables.items():
  60. self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
  61. # Add environment variables to the variable pool
  62. for var in self.environment_variables:
  63. self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
  64. # Add conversation variables to the variable pool
  65. for var in self.conversation_variables:
  66. self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
  67. def add(self, selector: Sequence[str], value: Any, /) -> None:
  68. """
  69. Adds a variable to the variable pool.
  70. NOTE: You should not add a non-Segment value to the variable pool
  71. even if it is allowed now.
  72. Args:
  73. selector (Sequence[str]): The selector for the variable.
  74. value (VariableValue): The value of the variable.
  75. Raises:
  76. ValueError: If the selector is invalid.
  77. Returns:
  78. None
  79. """
  80. if len(selector) < 2:
  81. raise ValueError("Invalid selector")
  82. if value is None:
  83. return
  84. if isinstance(value, Segment):
  85. v = value
  86. else:
  87. v = variable_factory.build_segment(value)
  88. hash_key = hash(tuple(selector[1:]))
  89. self.variable_dictionary[selector[0]][hash_key] = v
  90. def get(self, selector: Sequence[str], /) -> Segment | None:
  91. """
  92. Retrieves the value from the variable pool based on the given selector.
  93. Args:
  94. selector (Sequence[str]): The selector used to identify the variable.
  95. Returns:
  96. Any: The value associated with the given selector.
  97. Raises:
  98. ValueError: If the selector is invalid.
  99. """
  100. if len(selector) < 2:
  101. return None
  102. hash_key = hash(tuple(selector[1:]))
  103. value = self.variable_dictionary[selector[0]].get(hash_key)
  104. if value is None:
  105. selector, attr = selector[:-1], selector[-1]
  106. value = self.get(selector)
  107. if isinstance(value, FileSegment):
  108. attr = FileAttribute(attr)
  109. attr_value = file_manager.get_attr(file=value.value, attr=attr)
  110. return variable_factory.build_segment(attr_value)
  111. return value
  112. @deprecated("This method is deprecated, use `get` instead.")
  113. def get_any(self, selector: Sequence[str], /) -> Any | None:
  114. """
  115. Retrieves the value from the variable pool based on the given selector.
  116. Args:
  117. selector (Sequence[str]): The selector used to identify the variable.
  118. Returns:
  119. Any: The value associated with the given selector.
  120. Raises:
  121. ValueError: If the selector is invalid.
  122. """
  123. if len(selector) < 2:
  124. raise ValueError("Invalid selector")
  125. hash_key = hash(tuple(selector[1:]))
  126. value = self.variable_dictionary[selector[0]].get(hash_key)
  127. return value.to_object() if value else None
  128. def remove(self, selector: Sequence[str], /):
  129. """
  130. Remove variables from the variable pool based on the given selector.
  131. Args:
  132. selector (Sequence[str]): A sequence of strings representing the selector.
  133. Returns:
  134. None
  135. """
  136. if not selector:
  137. return
  138. if len(selector) == 1:
  139. self.variable_dictionary[selector[0]] = {}
  140. return
  141. hash_key = hash(tuple(selector[1:]))
  142. self.variable_dictionary[selector[0]].pop(hash_key, None)
  143. def convert_template(self, template: str, /):
  144. parts = VARIABLE_PATTERN.split(template)
  145. segments = []
  146. for part in filter(lambda x: x, parts):
  147. if "." in part and (variable := self.get(part.split("."))):
  148. segments.append(variable)
  149. else:
  150. segments.append(variable_factory.build_segment(part))
  151. return SegmentGroup(value=segments)
  152. def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
  153. segment = self.get(selector)
  154. if isinstance(segment, FileSegment):
  155. return segment
  156. return None