|
@@ -22,7 +22,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
|
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
|
|
-from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
|
|
+from core.variables import (
|
|
|
+ ArrayAnySegment,
|
|
|
+ ArrayFileSegment,
|
|
|
+ ArraySegment,
|
|
|
+ FileSegment,
|
|
|
+ NoneSegment,
|
|
|
+ ObjectSegment,
|
|
|
+ StringSegment,
|
|
|
+)
|
|
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
|
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
|
|
from core.workflow.enums import SystemVariableKey
|
|
@@ -263,50 +271,44 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|
|
return variables
|
|
|
|
|
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
|
|
- variable = variable_selector.variable
|
|
|
- value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
|
|
+ variable_name = variable_selector.variable
|
|
|
+ variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
|
|
+ if variable is None:
|
|
|
+ raise ValueError(f"Variable {variable_selector.variable} not found")
|
|
|
|
|
|
- def parse_dict(d: dict) -> str:
|
|
|
+ def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
|
|
"""
|
|
|
Parse dict into string
|
|
|
"""
|
|
|
# check if it's a context structure
|
|
|
- if "metadata" in d and "_source" in d["metadata"] and "content" in d:
|
|
|
- return d["content"]
|
|
|
+ if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
|
|
+ return input_dict["content"]
|
|
|
|
|
|
# else, parse the dict
|
|
|
try:
|
|
|
- return json.dumps(d, ensure_ascii=False)
|
|
|
+ return json.dumps(input_dict, ensure_ascii=False)
|
|
|
except Exception:
|
|
|
- return str(d)
|
|
|
+ return str(input_dict)
|
|
|
|
|
|
- if isinstance(value, str):
|
|
|
- value = value
|
|
|
- elif isinstance(value, list):
|
|
|
+ if isinstance(variable, ArraySegment):
|
|
|
result = ""
|
|
|
- for item in value:
|
|
|
+ for item in variable.value:
|
|
|
if isinstance(item, dict):
|
|
|
result += parse_dict(item)
|
|
|
- elif isinstance(item, str):
|
|
|
- result += item
|
|
|
- elif isinstance(item, int | float):
|
|
|
- result += str(item)
|
|
|
else:
|
|
|
result += str(item)
|
|
|
result += "\n"
|
|
|
value = result.strip()
|
|
|
- elif isinstance(value, dict):
|
|
|
- value = parse_dict(value)
|
|
|
- elif isinstance(value, int | float):
|
|
|
- value = str(value)
|
|
|
+ elif isinstance(variable, ObjectSegment):
|
|
|
+ value = parse_dict(variable.value)
|
|
|
else:
|
|
|
- value = str(value)
|
|
|
+ value = variable.text
|
|
|
|
|
|
- variables[variable] = value
|
|
|
+ variables[variable_name] = value
|
|
|
|
|
|
return variables
|
|
|
|
|
|
- def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
|
|
+ def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
|
|
|
inputs = {}
|
|
|
prompt_template = node_data.prompt_template
|
|
|
|
|
@@ -363,14 +365,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|
|
if not node_data.context.variable_selector:
|
|
|
return
|
|
|
|
|
|
- context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
|
|
|
- if context_value:
|
|
|
- if isinstance(context_value, str):
|
|
|
- yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
|
|
- elif isinstance(context_value, list):
|
|
|
+ context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
|
|
+ if context_value_variable:
|
|
|
+ if isinstance(context_value_variable, StringSegment):
|
|
|
+ yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
|
|
+ elif isinstance(context_value_variable, ArraySegment):
|
|
|
context_str = ""
|
|
|
original_retriever_resource = []
|
|
|
- for item in context_value:
|
|
|
+ for item in context_value_variable.value:
|
|
|
if isinstance(item, str):
|
|
|
context_str += item + "\n"
|
|
|
else:
|
|
@@ -484,11 +486,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|
|
return None
|
|
|
|
|
|
# get conversation id
|
|
|
- conversation_id = self.graph_runtime_state.variable_pool.get_any(
|
|
|
+ conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
|
|
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
|
|
)
|
|
|
- if conversation_id is None:
|
|
|
+ if not isinstance(conversation_id_variable, StringSegment):
|
|
|
return None
|
|
|
+ conversation_id = conversation_id_variable.value
|
|
|
|
|
|
# get conversation
|
|
|
conversation = (
|