| 
					
				 | 
			
			
				@@ -1,5 +1,3 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import io 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import random 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import uuid 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -8,7 +6,7 @@ import httpx 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from websocket import WebSocket 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from yarl import URL 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.file.file_manager import _get_encoded_string 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.file.file_manager import download 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.file.models import File 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -29,8 +27,7 @@ class ComfyUiClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return response.content 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def upload_image(self, image_file: File) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        image_content = base64.b64decode(_get_encoded_string(image_file)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        file = io.BytesIO(image_content) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        file = download(image_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         res = httpx.post(str(self.base_url / "upload/image"), files=files) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return res.json() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -47,12 +44,7 @@ class ComfyUiClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ws.connect(ws_address) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return ws, client_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def set_prompt( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    ) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        find the first KSampler, then can find the prompt node through it. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         prompt = origin_prompt.copy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -64,9 +56,20 @@ class ComfyUiClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if image_name != "": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            prompt.get(image_loader)["inputs"]["image"] = image_name 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return prompt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        prompt = origin_prompt.copy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for index, image_node_id in enumerate(image_ids): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            prompt[image_node_id]["inputs"]["image"] = image_names[index] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return prompt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        prompt = origin_prompt.copy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for load_image, image_name in zip(load_image_nodes, image_names): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            prompt.get(load_image)["inputs"]["image"] = image_name 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return prompt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): 
			 |