Explorar o código

feat: update the xinf tool's API key to optional (#9073)

zhuhao hai 6 meses
pai
achega
5213650fed

+ 9 - 6
api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py

@@ -104,14 +104,15 @@ class StableDiffusionTool(BuiltinTool):
         model = self.runtime.credentials.get("model", None)
         if not model:
             return self.create_text_message("Please input model")
-
+        api_key = self.runtime.credentials.get("api_key") or "abc"
+        headers = {"Authorization": f"Bearer {api_key}"}
         # set model
         try:
             url = str(URL(base_url) / "sdapi" / "v1" / "options")
             response = post(
                 url,
                 json={"sd_model_checkpoint": model},
-                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
+                headers=headers,
             )
             if response.status_code != 200:
                 raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model")
@@ -257,14 +258,15 @@ class StableDiffusionTool(BuiltinTool):
             draw_options["prompt"] = f"{lora},{prompt}"
         else:
             draw_options["prompt"] = prompt
-
+        api_key = self.runtime.credentials.get("api_key") or "abc"
+        headers = {"Authorization": f"Bearer {api_key}"}
         try:
             url = str(URL(base_url) / "sdapi" / "v1" / "img2img")
             response = post(
                 url,
                 json=draw_options,
                 timeout=120,
-                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
+                headers=headers,
             )
             if response.status_code != 200:
                 return self.create_text_message("Failed to generate image")
@@ -298,14 +300,15 @@ class StableDiffusionTool(BuiltinTool):
         else:
             draw_options["prompt"] = prompt
         draw_options["override_settings"]["sd_model_checkpoint"] = model
-
+        api_key = self.runtime.credentials.get("api_key") or "abc"
+        headers = {"Authorization": f"Bearer {api_key}"}
         try:
             url = str(URL(base_url) / "sdapi" / "v1" / "txt2img")
             response = post(
                 url,
                 json=draw_options,
                 timeout=120,
-                headers={"Authorization": f"Bearer {self.runtime.credentials['api_key']}"},
+                headers=headers,
             )
             if response.status_code != 200:
                 return self.create_text_message("Failed to generate image")

+ 10 - 4
api/core/tools/provider/builtin/xinference/xinference.py

@@ -6,12 +6,18 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 
 class XinferenceProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict) -> None:
-        base_url = credentials.get("base_url")
-        api_key = credentials.get("api_key")
-        model = credentials.get("model")
+        base_url = credentials.get("base_url", "").removesuffix("/")
+        api_key = credentials.get("api_key", "")
+        if not api_key:
+            api_key = "abc"
+            credentials["api_key"] = api_key
+        model = credentials.get("model", "")
+        if not base_url or not model:
+            raise ToolProviderCredentialValidationError("Xinference base_url and model is required")
+        headers = {"Authorization": f"Bearer {api_key}"}
         res = requests.post(
             f"{base_url}/sdapi/v1/options",
-            headers={"Authorization": f"Bearer {api_key}"},
+            headers=headers,
             json={"sd_model_checkpoint": model},
         )
         if res.status_code != 200:

+ 1 - 1
api/core/tools/provider/builtin/xinference/xinference.yaml

@@ -31,7 +31,7 @@ credentials_for_provider:
       zh_Hans: 请输入你的模型名称
   api_key:
     type: secret-input
-    required: true
+    required: false
     label:
       en_US: API Key
       zh_Hans: Xinference 服务器的 API Key