| 
					
				 | 
			
			
				@@ -1,3 +1,5 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import io 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import random 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import uuid 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -6,45 +8,48 @@ import httpx 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from websocket import WebSocket 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from yarl import URL 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.file.file_manager import _get_encoded_string 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.file.models import File 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class ComfyUiClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(self, base_url: str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.base_url = URL(base_url) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def get_history(self, prompt_id: str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def get_history(self, prompt_id: str) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         history = res.json()[prompt_id] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return history 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def get_image(self, filename: str, subfolder: str, folder_type: str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         response = httpx.get( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             str(self.base_url / "view"), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             params={"filename": filename, "subfolder": subfolder, "type": folder_type}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return response.content 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # plan to support img2img in dify 0.10.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        with open(input_path, "rb") as file: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            files = {"image": (name, file, "image/png")} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            data = {"type": image_type, "overwrite": str(overwrite).lower()} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return res 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def upload_image(self, image_file: File) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        image_content = base64.b64decode(_get_encoded_string(image_file)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        file = io.BytesIO(image_content) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        res = httpx.post(str(self.base_url / "upload/image"), files=files) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return res.json() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def queue_prompt(self, client_id: str, prompt: dict): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def queue_prompt(self, client_id: str, prompt: dict) -> str: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         prompt_id = res.json()["prompt_id"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return prompt_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def open_websocket_connection(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def open_websocket_connection(self) -> tuple[WebSocket, str]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         client_id = str(uuid.uuid4()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ws = WebSocket() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ws.connect(ws_address) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return ws, client_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def set_prompt( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         find the first KSampler, then can find the prompt node through it. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -58,6 +63,10 @@ class ComfyUiClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if negative_prompt != "": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if image_name != "": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            prompt.get(image_loader)["inputs"]["image"] = image_name 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return prompt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -89,7 +98,7 @@ class ComfyUiClient: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def generate_image_by_prompt(self, prompt: dict): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def generate_image_by_prompt(self, prompt: dict) -> list[bytes]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ws, client_id = self.open_websocket_connection() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             prompt_id = self.queue_prompt(client_id, prompt) 
			 |