Browse Source

Optimization stable diffusion verify (#2322)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Charlie.Wei 1 year ago
parent
commit
5929e84036

+ 2 - 10
api/core/tools/provider/builtin/stablediffusion/stablediffusion.py

@@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S
 
 
 from typing import Any, Dict
 from typing import Any, Dict
 
 
+
 class StableDiffusionProvider(BuiltinToolProviderController):
 class StableDiffusionProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
     def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
         try:
         try:
@@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController):
                 meta={
                 meta={
                     "credentials": credentials,
                     "credentials": credentials,
                 }
                 }
-            ).invoke(
-                user_id='',
-                tool_parameters={
-                    "prompt": "cat",
-                    "lora": "",
-                    "steps": 1,
-                    "width": 512,
-                    "height": 512,
-                },
-            )
+            ).validate_models()
         except Exception as e:
         except Exception as e:
             raise ToolProviderCredentialValidationError(str(e))
             raise ToolProviderCredentialValidationError(str(e))

+ 29 - 5
api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py

@@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.errors import ToolProviderCredentialValidationError
 from core.tools.errors import ToolProviderCredentialValidationError
 
 
 from typing import Any, Dict, List, Union
 from typing import Any, Dict, List, Union
-from httpx import post
+from httpx import post, get
 from os.path import join
 from os.path import join
 from base64 import b64decode, b64encode
 from base64 import b64decode, b64encode
 from PIL import Image
 from PIL import Image
@@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
     "alwayson_scripts": {}
     "alwayson_scripts": {}
 }
 }
 
 
+
 class StableDiffusionTool(BuiltinTool):
 class StableDiffusionTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
     def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
         -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
         -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
@@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool):
                              width=width,
                              width=width,
                              height=height,
                              height=height,
                              steps=steps)
                              steps=steps)
-        
+
+    def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
+        """
+            validate models
+        """
+        try:
+            base_url = self.runtime.credentials.get('base_url', None)
+            if not base_url:
+                raise ToolProviderCredentialValidationError('Please input base_url')
+            model = self.runtime.credentials.get('model', None)
+            if not model:
+                raise ToolProviderCredentialValidationError('Please input model')
+
+            response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
+            if response.status_code != 200:
+                raise ToolProviderCredentialValidationError('Failed to get models')
+            else:
+                models = [d['model_name'] for d in response.json()]
+                if len([d for d in models if d == model]) > 0:
+                    return self.create_text_message(json.dumps(models))
+                else:
+                    raise ToolProviderCredentialValidationError(f'model {model} does not exist')
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
+
     def img2img(self, base_url: str, lora: str, image_binary: bytes, 
     def img2img(self, base_url: str, lora: str, image_binary: bytes, 
                 prompt: str, negative_prompt: str,
                 prompt: str, negative_prompt: str,
                 width: int, height: int, steps: int) \
                 width: int, height: int, steps: int) \
@@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool):
         except Exception as e:
         except Exception as e:
             return self.create_text_message('Failed to generate image')
             return self.create_text_message('Failed to generate image')
 
 
-
     def get_runtime_parameters(self) -> List[ToolParameter]:
     def get_runtime_parameters(self) -> List[ToolParameter]:
         parameters = [
         parameters = [
-            ToolParameter(name='prompt', 
+            ToolParameter(name='prompt',
                          label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
                          label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
                          human_description=I18nObject(
                          human_description=I18nObject(
                              en_US='Image prompt, you can check the official documentation of Stable Diffusion',
                              en_US='Image prompt, you can check the official documentation of Stable Diffusion',
@@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool):
         ]
         ]
         if len(self.list_default_image_variables()) != 0:
         if len(self.list_default_image_variables()) != 0:
             parameters.append(
             parameters.append(
-                ToolParameter(name='image_id', 
+                ToolParameter(name='image_id',
                              label=I18nObject(en_US='image_id', zh_Hans='image_id'),
                              label=I18nObject(en_US='image_id', zh_Hans='image_id'),
                              human_description=I18nObject(
                              human_description=I18nObject(
                                 en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
                                 en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',