|
@@ -0,0 +1,60 @@
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+from httpx import post
|
|
|
+
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+
|
|
|
+class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
|
|
+ """
|
|
|
+ This class is responsible for providing the stable diffusion tool.
|
|
|
+ """
|
|
|
+ model_endpoint_map = {
|
|
|
+ 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
|
|
+ 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
|
|
+ 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
|
|
|
+ }
|
|
|
+
|
|
|
+ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
|
|
+ """
|
|
|
+ Invoke the tool.
|
|
|
+ """
|
|
|
+ payload = {
|
|
|
+ 'prompt': tool_parameters.get('prompt', ''),
|
|
|
+ 'aspect_radio': tool_parameters.get('aspect_radio', '16:9'),
|
|
|
+ 'mode': 'text-to-image',
|
|
|
+ 'seed': tool_parameters.get('seed', 0),
|
|
|
+ 'output_format': 'png',
|
|
|
+ }
|
|
|
+
|
|
|
+ model = tool_parameters.get('model', 'core')
|
|
|
+
|
|
|
+ if model in ['sd3', 'sd3-turbo']:
|
|
|
+ payload['model'] = tool_parameters.get('model')
|
|
|
+
|
|
|
+ if not model == 'sd3-turbo':
|
|
|
+ payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
|
|
|
+
|
|
|
+ response = post(
|
|
|
+ self.model_endpoint_map[tool_parameters.get('model', 'core')],
|
|
|
+ headers={
|
|
|
+ 'accept': 'image/*',
|
|
|
+ **self.generate_authorization_headers(self.runtime.credentials),
|
|
|
+ },
|
|
|
+ files={
|
|
|
+ key: (None, str(value)) for key, value in payload.items()
|
|
|
+ },
|
|
|
+ timeout=(5, 30)
|
|
|
+ )
|
|
|
+
|
|
|
+ if not response.status_code == 200:
|
|
|
+ raise Exception(response.text)
|
|
|
+
|
|
|
+ return self.create_blob_message(
|
|
|
+ blob=response.content, meta={
|
|
|
+ 'mime_type': 'image/png'
|
|
|
+ },
|
|
|
+ save_as=self.VARIABLE_KEY.IMAGE.value
|
|
|
+ )
|