|  | @@ -1,5 +1,3 @@
 | 
	
		
			
				|  |  | -import base64
 | 
	
		
			
				|  |  | -import io
 | 
	
		
			
				|  |  |  import json
 | 
	
		
			
				|  |  |  import random
 | 
	
		
			
				|  |  |  import uuid
 | 
	
	
		
			
				|  | @@ -8,7 +6,7 @@ import httpx
 | 
	
		
			
				|  |  |  from websocket import WebSocket
 | 
	
		
			
				|  |  |  from yarl import URL
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from core.file.file_manager import _get_encoded_string
 | 
	
		
			
				|  |  | +from core.file.file_manager import download
 | 
	
		
			
				|  |  |  from core.file.models import File
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -29,8 +27,7 @@ class ComfyUiClient:
 | 
	
		
			
				|  |  |          return response.content
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def upload_image(self, image_file: File) -> dict:
 | 
	
		
			
				|  |  | -        image_content = base64.b64decode(_get_encoded_string(image_file))
 | 
	
		
			
				|  |  | -        file = io.BytesIO(image_content)
 | 
	
		
			
				|  |  | +        file = download(image_file)
 | 
	
		
			
				|  |  |          files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
 | 
	
		
			
				|  |  |          res = httpx.post(str(self.base_url / "upload/image"), files=files)
 | 
	
		
			
				|  |  |          return res.json()
 | 
	
	
		
			
				|  | @@ -47,12 +44,7 @@ class ComfyUiClient:
 | 
	
		
			
				|  |  |          ws.connect(ws_address)
 | 
	
		
			
				|  |  |          return ws, client_id
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def set_prompt(
 | 
	
		
			
				|  |  | -        self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
 | 
	
		
			
				|  |  | -    ) -> dict:
 | 
	
		
			
				|  |  | -        """
 | 
	
		
			
				|  |  | -        find the first KSampler, then can find the prompt node through it.
 | 
	
		
			
				|  |  | -        """
 | 
	
		
			
				|  |  | +    def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
 | 
	
		
			
				|  |  |          prompt = origin_prompt.copy()
 | 
	
		
			
				|  |  |          id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
 | 
	
		
			
				|  |  |          k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
 | 
	
	
		
			
				|  | @@ -64,9 +56,20 @@ class ComfyUiClient:
 | 
	
		
			
				|  |  |              negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
 | 
	
		
			
				|  |  |              prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        if image_name != "":
 | 
	
		
			
				|  |  | -            image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
 | 
	
		
			
				|  |  | -            prompt.get(image_loader)["inputs"]["image"] = image_name
 | 
	
		
			
				|  |  | +        return prompt
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
 | 
	
		
			
				|  |  | +        prompt = origin_prompt.copy()
 | 
	
		
			
				|  |  | +        for index, image_node_id in enumerate(image_ids):
 | 
	
		
			
				|  |  | +            prompt[image_node_id]["inputs"]["image"] = image_names[index]
 | 
	
		
			
				|  |  | +        return prompt
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
 | 
	
		
			
				|  |  | +        prompt = origin_prompt.copy()
 | 
	
		
			
				|  |  | +        id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
 | 
	
		
			
				|  |  | +        load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"]
 | 
	
		
			
				|  |  | +        for load_image, image_name in zip(load_image_nodes, image_names):
 | 
	
		
			
				|  |  | +            prompt.get(load_image)["inputs"]["image"] = image_name
 | 
	
		
			
				|  |  |          return prompt
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
 |