|
@@ -0,0 +1,72 @@
|
|
|
+import random
|
|
|
+from typing import Any, Union
|
|
|
+
|
|
|
+from openai import OpenAI
|
|
|
+from yarl import URL
|
|
|
+
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+
|
|
|
+class StepfunTool(BuiltinTool):
|
|
|
+ """ Stepfun Image Generation Tool """
|
|
|
+ def _invoke(self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any],
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ invoke tools
|
|
|
+ """
|
|
|
+ base_url = self.runtime.credentials.get('stepfun_base_url', None)
|
|
|
+ if not base_url:
|
|
|
+ base_url = None
|
|
|
+ else:
|
|
|
+ base_url = str(URL(base_url) / 'v1')
|
|
|
+
|
|
|
+ client = OpenAI(
|
|
|
+ api_key=self.runtime.credentials['stepfun_api_key'],
|
|
|
+ base_url=base_url,
|
|
|
+ )
|
|
|
+
|
|
|
+ extra_body = {}
|
|
|
+ model = tool_parameters.get('model', 'step-1x-medium')
|
|
|
+ if not model:
|
|
|
+ return self.create_text_message('Please input model name')
|
|
|
+ # prompt
|
|
|
+ prompt = tool_parameters.get('prompt', '')
|
|
|
+ if not prompt:
|
|
|
+ return self.create_text_message('Please input prompt')
|
|
|
+
|
|
|
+ seed = tool_parameters.get('seed', 0)
|
|
|
+ if seed > 0:
|
|
|
+ extra_body['seed'] = seed
|
|
|
+ steps = tool_parameters.get('steps', 0)
|
|
|
+ if steps > 0:
|
|
|
+ extra_body['steps'] = steps
|
|
|
+ negative_prompt = tool_parameters.get('negative_prompt', '')
|
|
|
+ if negative_prompt:
|
|
|
+ extra_body['negative_prompt'] = negative_prompt
|
|
|
+
|
|
|
+ # call openapi stepfun model
|
|
|
+ response = client.images.generate(
|
|
|
+ prompt=prompt,
|
|
|
+ model=model,
|
|
|
+ size=tool_parameters.get('size', '1024x1024'),
|
|
|
+ n=tool_parameters.get('n', 1),
|
|
|
+ extra_body= extra_body
|
|
|
+ )
|
|
|
+ print(response)
|
|
|
+
|
|
|
+ result = []
|
|
|
+ for image in response.data:
|
|
|
+ result.append(self.create_image_message(image=image.url))
|
|
|
+ result.append(self.create_json_message({
|
|
|
+ "url": image.url,
|
|
|
+ }))
|
|
|
+ return result
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _generate_random_id(length=8):
|
|
|
+ characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
|
|
+ random_id = ''.join(random.choices(characters, k=length))
|
|
|
+ return random_id
|