Sfoglia il codice sorgente

feat: add the workflow tool of comfyUI (#9447)

非法操作 1 anno fa
parent
commit
d3c06a3f76

+ 11 - 7
api/core/tools/provider/builtin/comfyui/comfyui.py

@@ -1,17 +1,21 @@
 from typing import Any
 
+import websocket
+from yarl import URL
+
 from core.tools.errors import ToolProviderCredentialValidationError
-from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool
 from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
 
 
 class ComfyUIProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        ws = websocket.WebSocket()
+        base_url = URL(credentials.get("base_url"))
+        ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
+
         try:
-            ComfyuiStableDiffusionTool().fork_tool_runtime(
-                runtime={
-                    "credentials": credentials,
-                }
-            ).validate_models()
+            ws.connect(ws_address)
         except Exception as e:
-            raise ToolProviderCredentialValidationError(str(e))
+            raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
+        finally:
+            ws.close()

+ 3 - 22
api/core/tools/provider/builtin/comfyui/comfyui.yaml

@@ -4,11 +4,9 @@ identity:
   label:
     en_US: ComfyUI
     zh_Hans: ComfyUI
-    pt_BR: ComfyUI
   description:
     en_US: ComfyUI is a tool for generating images which can be deployed locally.
     zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
-    pt_BR: ComfyUI is a tool for generating images which can be deployed locally.
   icon: icon.png
   tags:
     - image
@@ -17,26 +15,9 @@ credentials_for_provider:
     type: text-input
     required: true
     label:
-      en_US: Base URL
-      zh_Hans: ComfyUI服务器的Base URL
-      pt_BR: Base URL
+      en_US: The URL of ComfyUI Server
+      zh_Hans: ComfyUI服务器的URL
     placeholder:
       en_US: Please input your ComfyUI server's Base URL
       zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
-      pt_BR: Please input your ComfyUI server's Base URL
-  model:
-    type: text-input
-    required: true
-    label:
-      en_US: Model with suffix
-      zh_Hans: 模型, 需要带后缀
-      pt_BR: Model with suffix
-    placeholder:
-      en_US: Please input your model
-      zh_Hans: 请输入你的模型名称
-      pt_BR: Please input your model
-    help:
-      en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
-      zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors
-      pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
-    url: https://github.com/comfyanonymous/ComfyUI#installing
+    url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui

+ 105 - 0
api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py

@@ -0,0 +1,105 @@
+import json
+import random
+import uuid
+
+import httpx
+from websocket import WebSocket
+from yarl import URL
+
+
+class ComfyUiClient:
+    def __init__(self, base_url: str):
+        self.base_url = URL(base_url)
+
+    def get_history(self, prompt_id: str):
+        res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
+        history = res.json()[prompt_id]
+        return history
+
+    def get_image(self, filename: str, subfolder: str, folder_type: str):
+        response = httpx.get(
+            str(self.base_url / "view"),
+            params={"filename": filename, "subfolder": subfolder, "type": folder_type},
+        )
+        return response.content
+
+    def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
+        # plan to support img2img in dify 0.10.0
+        with open(input_path, "rb") as file:
+            files = {"image": (name, file, "image/png")}
+            data = {"type": image_type, "overwrite": str(overwrite).lower()}
+
+        res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files)
+        return res
+
+    def queue_prompt(self, client_id: str, prompt: dict):
+        res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
+        prompt_id = res.json()["prompt_id"]
+        return prompt_id
+
+    def open_websocket_connection(self):
+        client_id = str(uuid.uuid4())
+        ws = WebSocket()
+        ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
+        ws.connect(ws_address)
+        return ws, client_id
+
+    def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
+        """
+        find the first KSampler, then can find the prompt node through it.
+        """
+        prompt = origin_prompt.copy()
+        id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
+        k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
+        prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1)
+        positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0]
+        prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt
+
+        if negative_prompt != "":
+            negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
+            prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
+        return prompt
+
+    def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
+        node_ids = list(prompt.keys())
+        finished_nodes = []
+
+        while True:
+            out = ws.recv()
+            if isinstance(out, str):
+                message = json.loads(out)
+                if message["type"] == "progress":
+                    data = message["data"]
+                    current_step = data["value"]
+                    print("In K-Sampler -> Step: ", current_step, " of: ", data["max"])
+                if message["type"] == "execution_cached":
+                    data = message["data"]
+                    for itm in data["nodes"]:
+                        if itm not in finished_nodes:
+                            finished_nodes.append(itm)
+                            print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done")
+                if message["type"] == "executing":
+                    data = message["data"]
+                    if data["node"] not in finished_nodes:
+                        finished_nodes.append(data["node"])
+                        print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done")
+
+                    if data["node"] is None and data["prompt_id"] == prompt_id:
+                        break  # Execution is done
+            else:
+                continue
+
+    def generate_image_by_prompt(self, prompt: dict):
+        try:
+            ws, client_id = self.open_websocket_connection()
+            prompt_id = self.queue_prompt(client_id, prompt)
+            self.track_progress(prompt, ws, prompt_id)
+            history = self.get_history(prompt_id)
+            images = []
+            for output in history["outputs"].values():
+                for img in output.get("images", []):
+                    image_data = self.get_image(img["filename"], img["subfolder"], img["type"])
+                    images.append(image_data)
+            return images
+        finally:
+            ws.close()

+ 4 - 4
api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml

@@ -1,10 +1,10 @@
 identity:
-  name: txt2img workflow
+  name: txt2img
   author: Qun
   label:
-    en_US: Txt2Img Workflow
-    zh_Hans: Txt2Img Workflow
-    pt_BR: Txt2Img Workflow
+    en_US: Txt2Img
+    zh_Hans: Txt2Img
+    pt_BR: Txt2Img
 description:
   human:
     en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.

+ 32 - 0
api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py

@@ -0,0 +1,32 @@
+import json
+from typing import Any
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+from .comfyui_client import ComfyUiClient
+
+
+class ComfyUIWorkflowTool(BuiltinTool):
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
+
+        positive_prompt = tool_parameters.get("positive_prompt")
+        negative_prompt = tool_parameters.get("negative_prompt")
+        workflow = tool_parameters.get("workflow_json")
+
+        try:
+            origin_prompt = json.loads(workflow)
+        except:
+            return self.create_text_message("the Workflow JSON is not correct")
+
+        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
+        images = comfyui.generate_image_by_prompt(prompt)
+        result = []
+        for img in images:
+            result.append(
+                self.create_blob_message(
+                    blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
+                )
+            )
+        return result

+ 35 - 0
api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml

@@ -0,0 +1,35 @@
+identity:
+  name: workflow
+  author: hjlarry
+  label:
+    en_US: workflow
+    zh_Hans: 工作流
+description:
+  human:
+    en_US: Run ComfyUI workflow.
+    zh_Hans: 运行ComfyUI工作流。
+  llm: Run ComfyUI workflow.
+parameters:
+  - name: positive_prompt
+    type: string
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+    llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
+    form: llm
+  - name: negative_prompt
+    type: string
+    label:
+      en_US: Negative Prompt
+      zh_Hans: 负面提示词
+    llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
+    form: llm
+  - name: workflow_json
+    type: string
+    required: true
+    label:
+      en_US: Workflow JSON
+    human_description:
+      en_US: exported from ComfyUI workflow
+      zh_Hans: 从ComfyUI的工作流中导出
+    form: form