Browse Source

refactor(variables): replace deprecated 'get_any' with 'get' method (#9584)

-LAN- 5 months ago
parent
commit
8f670f31b8

+ 0 - 21
api/core/workflow/entities/variable_pool.py

@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
 from typing import Any, Union
 
 from pydantic import BaseModel, Field
-from typing_extensions import deprecated
 
 from core.file import File, FileAttribute, file_manager
 from core.variables import Segment, SegmentGroup, Variable
@@ -133,26 +132,6 @@ class VariablePool(BaseModel):
 
         return value
 
-    @deprecated("This method is deprecated, use `get` instead.")
-    def get_any(self, selector: Sequence[str], /) -> Any | None:
-        """
-        Retrieves the value from the variable pool based on the given selector.
-
-        Args:
-            selector (Sequence[str]): The selector used to identify the variable.
-
-        Returns:
-            Any: The value associated with the given selector.
-
-        Raises:
-            ValueError: If the selector is invalid.
-        """
-        if len(selector) < 2:
-            raise ValueError("Invalid selector")
-        hash_key = hash(tuple(selector[1:]))
-        value = self.variable_dictionary[selector[0]].get(hash_key)
-        return value.to_object() if value else None
-
     def remove(self, selector: Sequence[str], /):
         """
         Remove variables from the variable pool based on the given selector.

+ 9 - 4
api/core/workflow/nodes/code/code_node.py

@@ -41,10 +41,15 @@ class CodeNode(BaseNode[CodeNodeData]):
         # Get variables
         variables = {}
         for variable_selector in self.node_data.variables:
-            variable = variable_selector.variable
-            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
-
-            variables[variable] = value
+            variable_name = variable_selector.variable
+            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
+            if variable is None:
+                return NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=variables,
+                    error=f"Variable `{variable_selector.value_selector}` not found",
+                )
+            variables[variable_name] = variable.to_object()
         # Run code
         try:
             result = CodeExecutor.execute_workflow_code_template(

+ 25 - 10
api/core/workflow/nodes/iteration/iteration_node.py

@@ -5,6 +5,7 @@ from typing import Any, cast
 
 from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.variables import IntegerSegment
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.graph_engine.entities.event import (
     BaseGraphEvent,
@@ -147,9 +148,16 @@ class IterationNode(BaseNode[IterationNodeData]):
 
                             if NodeRunMetadataKey.ITERATION_ID not in metadata:
                                 metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
-                                metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any(
-                                    [self.node_id, "index"]
-                                )
+                                index_variable = variable_pool.get([self.node_id, "index"])
+                                if not isinstance(index_variable, IntegerSegment):
+                                    yield RunCompletedEvent(
+                                        run_result=NodeRunResult(
+                                            status=WorkflowNodeExecutionStatus.FAILED,
+                                            error=f"Invalid index variable type: {type(index_variable)}",
+                                        )
+                                    )
+                                    return
+                                metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
                                 event.route_node_state.node_run_result.metadata = metadata
 
                         yield event
@@ -181,7 +189,16 @@ class IterationNode(BaseNode[IterationNodeData]):
                         yield event
 
                 # append to iteration output variable list
-                current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
+                current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
+                if current_iteration_output_variable is None:
+                    yield RunCompletedEvent(
+                        run_result=NodeRunResult(
+                            status=WorkflowNodeExecutionStatus.FAILED,
+                            error=f"Iteration output variable {self.node_data.output_selector} not found",
+                        )
+                    )
+                    return
+                current_iteration_output = current_iteration_output_variable.to_object()
                 outputs.append(current_iteration_output)
 
                 # remove all nodes outputs from variable pool
@@ -189,11 +206,11 @@ class IterationNode(BaseNode[IterationNodeData]):
                     variable_pool.remove([node_id])
 
                 # move to next iteration
-                current_index = variable_pool.get([self.node_id, "index"])
-                if current_index is None:
+                current_index_variable = variable_pool.get([self.node_id, "index"])
+                if not isinstance(current_index_variable, IntegerSegment):
                     raise ValueError(f"iteration {self.node_id} current index not found")
 
-                next_index = int(current_index.to_object()) + 1
+                next_index = current_index_variable.value + 1
                 variable_pool.add([self.node_id, "index"], next_index)
 
                 if next_index < len(iterator_list_value):
@@ -205,9 +222,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                     iteration_node_type=self.node_type,
                     iteration_node_data=self.node_data,
                     index=next_index,
-                    pre_iteration_output=jsonable_encoder(current_iteration_output)
-                    if current_iteration_output
-                    else None,
+                    pre_iteration_output=jsonable_encoder(current_iteration_output),
                 )
 
             yield IterationRunSucceededEvent(

+ 9 - 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -14,6 +14,7 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.variables import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
@@ -39,8 +40,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
 
     def _run(self) -> NodeRunResult:
         # extract variables
-        variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector)
-        query = variable
+        variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
+        if not isinstance(variable, StringSegment):
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs={},
+                error="Query variable is not string type.",
+            )
+        query = variable.value
         variables = {"query": query}
         if not query:
             return NodeRunResult(

+ 34 - 31
api/core/workflow/nodes/llm/node.py

@@ -22,7 +22,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
-from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
+from core.variables import (
+    ArrayAnySegment,
+    ArrayFileSegment,
+    ArraySegment,
+    FileSegment,
+    NoneSegment,
+    ObjectSegment,
+    StringSegment,
+)
 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.enums import SystemVariableKey
@@ -263,50 +271,44 @@ class LLMNode(BaseNode[LLMNodeData]):
             return variables
 
         for variable_selector in node_data.prompt_config.jinja2_variables or []:
-            variable = variable_selector.variable
-            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
+            variable_name = variable_selector.variable
+            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
+            if variable is None:
+                raise ValueError(f"Variable {variable_selector.variable} not found")
 
-            def parse_dict(d: dict) -> str:
+            def parse_dict(input_dict: Mapping[str, Any]) -> str:
                 """
                 Parse dict into string
                 """
                 # check if it's a context structure
-                if "metadata" in d and "_source" in d["metadata"] and "content" in d:
-                    return d["content"]
+                if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
+                    return input_dict["content"]
 
                 # else, parse the dict
                 try:
-                    return json.dumps(d, ensure_ascii=False)
+                    return json.dumps(input_dict, ensure_ascii=False)
                 except Exception:
-                    return str(d)
+                    return str(input_dict)
 
-            if isinstance(value, str):
-                value = value
-            elif isinstance(value, list):
+            if isinstance(variable, ArraySegment):
                 result = ""
-                for item in value:
+                for item in variable.value:
                     if isinstance(item, dict):
                         result += parse_dict(item)
-                    elif isinstance(item, str):
-                        result += item
-                    elif isinstance(item, int | float):
-                        result += str(item)
                     else:
                         result += str(item)
                     result += "\n"
                 value = result.strip()
-            elif isinstance(value, dict):
-                value = parse_dict(value)
-            elif isinstance(value, int | float):
-                value = str(value)
+            elif isinstance(variable, ObjectSegment):
+                value = parse_dict(variable.value)
             else:
-                value = str(value)
+                value = variable.text
 
-            variables[variable] = value
+            variables[variable_name] = value
 
         return variables
 
-    def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
+    def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
         inputs = {}
         prompt_template = node_data.prompt_template
 
@@ -363,14 +365,14 @@ class LLMNode(BaseNode[LLMNodeData]):
         if not node_data.context.variable_selector:
             return
 
-        context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
-        if context_value:
-            if isinstance(context_value, str):
-                yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
-            elif isinstance(context_value, list):
+        context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
+        if context_value_variable:
+            if isinstance(context_value_variable, StringSegment):
+                yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
+            elif isinstance(context_value_variable, ArraySegment):
                 context_str = ""
                 original_retriever_resource = []
-                for item in context_value:
+                for item in context_value_variable.value:
                     if isinstance(item, str):
                         context_str += item + "\n"
                     else:
@@ -484,11 +486,12 @@ class LLMNode(BaseNode[LLMNodeData]):
             return None
 
         # get conversation id
-        conversation_id = self.graph_runtime_state.variable_pool.get_any(
+        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
             ["sys", SystemVariableKey.CONVERSATION_ID.value]
         )
-        if conversation_id is None:
+        if not isinstance(conversation_id_variable, StringSegment):
             return None
+        conversation_id = conversation_id_variable.value
 
         # get conversation
         conversation = (

+ 7 - 2
api/core/workflow/nodes/template_transform/template_transform_node.py

@@ -33,8 +33,13 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
         variables = {}
         for variable_selector in self.node_data.variables:
             variable_name = variable_selector.variable
-            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
-            variables[variable_name] = value
+            value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
+            if value is None:
+                return NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=f"Variable {variable_name} not found in variable pool",
+                )
+            variables[variable_name] = value.to_object()
         # Run code
         try:
             result = CodeExecutor.execute_workflow_code_template(

+ 7 - 7
api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py

@@ -19,27 +19,27 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
 
         if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
             for selector in self.node_data.variables:
-                variable = self.graph_runtime_state.variable_pool.get_any(selector)
+                variable = self.graph_runtime_state.variable_pool.get(selector)
                 if variable is not None:
-                    outputs = {"output": variable}
+                    outputs = {"output": variable.to_object()}
 
-                    inputs = {".".join(selector[1:]): variable}
+                    inputs = {".".join(selector[1:]): variable.to_object()}
                     break
         else:
             for group in self.node_data.advanced_settings.groups:
                 for selector in group.variables:
-                    variable = self.graph_runtime_state.variable_pool.get_any(selector)
+                    variable = self.graph_runtime_state.variable_pool.get(selector)
 
                     if variable is not None:
-                        outputs[group.group_name] = {"output": variable}
-                        inputs[".".join(selector[1:])] = variable
+                        outputs[group.group_name] = {"output": variable.to_object()}
+                        inputs[".".join(selector[1:])] = variable.to_object()
                         break
 
         return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
-        cls, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
+        cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: VariableAssignerNodeData
     ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping

+ 4 - 0
api/tests/integration_tests/workflow/nodes/test_code.py

@@ -102,6 +102,8 @@ def test_execute_code(setup_code_executor_mock):
     }
 
     node = init_code_node(code_config)
+    node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
+    node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
 
     # execute node
     result = node._run()
@@ -146,6 +148,8 @@ def test_execute_code_output_validator(setup_code_executor_mock):
     }
 
     node = init_code_node(code_config)
+    node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1)
+    node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2)
 
     # execute node
     result = node._run()