| 
					
				 | 
			
			
				@@ -0,0 +1,125 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import pytest 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.app.entities.app_invoke_entities import InvokeFrom 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.file import File, FileTransferMethod, FileType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_runtime.entities.message_entities import ImagePromptMessageContent 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.workflow.entities.variable_pool import VariablePool 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.workflow.nodes.answer import AnswerStreamGenerateRoute 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.workflow.nodes.end import EndStreamParam 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.workflow.nodes.llm.node import LLMNode 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.enums import UserFrom 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.workflow import WorkflowType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class TestLLMNode: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    @pytest.fixture 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def llm_node(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        data = LLMNodeData( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            title="Test LLM", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            prompt_template=[], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            memory=None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            context=ContextConfig(enabled=False), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            vision=VisionConfig( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                enabled=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                configs=VisionConfigOptions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    variable_selector=["sys", "files"], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    detail=ImagePromptMessageContent.DETAIL.HIGH, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        variable_pool = VariablePool( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            system_variables={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            user_inputs={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        node = LLMNode( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            config={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "id": "1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "data": data.model_dump(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            graph_init_params=GraphInitParams( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                tenant_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                app_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                workflow_type=WorkflowType.WORKFLOW, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                workflow_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                graph_config={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                user_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                user_from=UserFrom.ACCOUNT, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                invoke_from=InvokeFrom.SERVICE_API, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                call_depth=0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            graph=Graph( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                root_node_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                answer_stream_generate_routes=AnswerStreamGenerateRoute( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    answer_dependencies={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    answer_generate_route={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                end_stream_param=EndStreamParam( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    end_dependencies={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    end_stream_variable_selector_mapping={}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            graph_runtime_state=GraphRuntimeState( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                variable_pool=variable_pool, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                start_at=0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return node 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_fetch_files_with_file_segment(self, llm_node): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        file = File( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            tenant_id="test", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            type=FileType.IMAGE, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            filename="test.jpg", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            transfer_method=FileTransferMethod.LOCAL_FILE, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            related_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result = llm_node._fetch_files(selector=["sys", "files"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert result == [file] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_fetch_files_with_array_file_segment(self, llm_node): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        files = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            File( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                tenant_id="test", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                type=FileType.IMAGE, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                filename="test1.jpg", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                transfer_method=FileTransferMethod.LOCAL_FILE, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                related_id="1", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            File( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                id="2", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                tenant_id="test", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                type=FileType.IMAGE, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                filename="test2.jpg", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                transfer_method=FileTransferMethod.LOCAL_FILE, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                related_id="2", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result = llm_node._fetch_files(selector=["sys", "files"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert result == files 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_fetch_files_with_none_segment(self, llm_node): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result = llm_node._fetch_files(selector=["sys", "files"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert result == [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_fetch_files_with_array_any_segment(self, llm_node): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result = llm_node._fetch_files(selector=["sys", "files"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert result == [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_fetch_files_with_non_existent_variable(self, llm_node): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result = llm_node._fetch_files(selector=["sys", "files"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        assert result == [] 
			 |