123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- import json
- import logging
- from copy import deepcopy
- from typing import Any, Union
- from core.file.file_obj import FileTransferMethod, FileVar
- from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
- from core.tools.tool.tool import Tool
- from extensions.ext_database import db
- from models.account import Account
- from models.model import App, EndUser
- from models.workflow import Workflow
logger = logging.getLogger(__name__)


class WorkflowTool(Tool):
    """
    Workflow tool.

    Exposes a workflow app as a tool: invoking it runs the app's workflow in
    blocking mode and converts the workflow outputs into tool invoke messages.
    """

    # id of the workflow app whose workflow this tool runs
    workflow_app_id: str
    # workflow version to run; empty selects the most recently created non-draft version
    version: str
    # copied by reference into forked tools (not otherwise used in this class)
    workflow_entities: dict[str, Any]
    # call depth of the caller; the workflow is launched with this value + 1
    workflow_call_depth: int
    label: str

    def tool_provider_type(self) -> ToolProviderType:
        """
        get the tool provider type

        :return: ToolProviderType.WORKFLOW
        """
        return ToolProviderType.WORKFLOW

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke the tool by running the underlying workflow in blocking mode

        :param user_id: id of the invoking EndUser or Account
        :param tool_parameters: raw tool parameters; file parameters become
            workflow file inputs, all other runtime parameters become inputs
        :return: one file message per file found in the workflow outputs,
            followed by one text message holding the remaining outputs as JSON
        :raises Exception: if the workflow run reports an error
        :raises ValueError: if the app, workflow or user cannot be found
        """
        app = self._get_app(app_id=self.workflow_app_id)
        workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)

        # transform the tool parameters into workflow inputs and files
        inputs, input_files = self._transform_args(tool_parameters)

        # function-scope import kept from the original; presumably avoids a
        # circular import at module load time — TODO confirm
        from core.app.apps.workflow.app_generator import WorkflowAppGenerator

        generator = WorkflowAppGenerator()
        response = generator.generate(
            app_model=app,
            workflow=workflow,
            user=self._get_user(user_id),
            args={
                'inputs': inputs,
                'files': input_files,
            },
            invoke_from=self.runtime.invoke_from,
            stream=False,
            # one level deeper than the caller; presumably used downstream to
            # limit nested workflow-tool calls — verify in the generator
            call_depth=self.workflow_call_depth + 1,
        )

        # `or {}` also covers keys that are present but explicitly None
        data = response.get('data') or {}
        if data.get('error'):
            raise Exception(data.get('error'))

        outputs, output_files = self._extract_files(data.get('outputs') or {})

        messages = [self.create_file_var_message(file) for file in output_files]
        messages.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
        return messages

    def _get_user(self, user_id: str) -> Union[EndUser, Account]:
        """
        get the user by user id, looking up EndUser first and then Account

        :param user_id: the user id
        :return: the matching EndUser or Account
        :raises ValueError: if neither table contains the id
        """
        user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
        if not user:
            user = db.session.query(Account).filter(Account.id == user_id).first()

        if not user:
            raise ValueError('user not found')

        return user

    def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
        """
        fork a new tool with the given runtime

        identity, parameters and description are deep-copied; the workflow
        settings are carried over as-is.

        :param runtime: keyword arguments for Tool.Runtime of the new tool
        :return: the new tool
        """
        return self.__class__(
            identity=deepcopy(self.identity),
            parameters=deepcopy(self.parameters),
            description=deepcopy(self.description),
            runtime=Tool.Runtime(**runtime),
            workflow_app_id=self.workflow_app_id,
            workflow_entities=self.workflow_entities,
            workflow_call_depth=self.workflow_call_depth,
            version=self.version,
            label=self.label
        )

    def _get_workflow(self, app_id: str, version: str) -> Workflow:
        """
        get the workflow by app id and version

        :param app_id: the app id
        :param version: exact version to load; when empty, the most recently
            created non-draft workflow of the app is used
        :return: the workflow
        :raises ValueError: if no matching workflow exists
        """
        if not version:
            workflow = db.session.query(Workflow).filter(
                Workflow.app_id == app_id,
                Workflow.version != 'draft'
            ).order_by(Workflow.created_at.desc()).first()
        else:
            workflow = db.session.query(Workflow).filter(
                Workflow.app_id == app_id,
                Workflow.version == version
            ).first()

        if not workflow:
            raise ValueError('workflow not found or not published')

        return workflow

    def _get_app(self, app_id: str) -> App:
        """
        get the app by app id

        :param app_id: the app id
        :return: the app
        :raises ValueError: if the app does not exist
        """
        app = db.session.query(App).filter(App.id == app_id).first()
        if not app:
            raise ValueError('app not found')

        return app

    def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
        """
        transform the tool parameters into workflow inputs and workflow files

        File-typed parameters are expected to hold a list of FileVar-compatible
        dicts; each is converted into a workflow file dict keyed by its transfer
        method. Every other runtime parameter is copied into the inputs (None
        when absent from tool_parameters).

        :param tool_parameters: the tool parameters
        :return: tool_parameters, files
        """
        parameter_rules = self.get_all_runtime_parameters()
        parameters_result = {}
        files = []
        for parameter in parameter_rules:
            if parameter.type != ToolParameter.ToolParameterType.FILE:
                parameters_result[parameter.name] = tool_parameters.get(parameter.name)
                continue

            file = tool_parameters.get(parameter.name)
            if not file:
                continue

            try:
                # parse all entries first so a malformed entry adds none of them
                file_var_list = [FileVar(**f) for f in file]
                for file_var in file_var_list:
                    file_dict = {
                        'transfer_method': file_var.transfer_method.value,
                        'type': file_var.type.value,
                    }
                    if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
                        file_dict['tool_file_id'] = file_var.related_id
                    elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
                        file_dict['upload_file_id'] = file_var.related_id
                    elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
                        file_dict['url'] = file_var.preview_url

                    files.append(file_dict)
            except Exception:
                # best-effort: an invalid file parameter is logged and skipped
                logger.exception('failed to transform file parameter %s', parameter.name)

        return parameters_result, files

    def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
        """
        extract files from the workflow outputs

        A list-valued output containing at least one item that parses as a
        FileVar (a dict whose '__variant' is 'FileVar') is removed from the
        result entirely; any non-file items in that same list are dropped too.

        :param outputs: the workflow outputs (None is treated as empty)
        :return: (outputs without the file-bearing entries, extracted files)
        """
        files = []
        result = {}
        for key, value in (outputs or {}).items():
            if isinstance(value, list):
                has_file = False
                for item in value:
                    if not (isinstance(item, dict) and item.get('__variant') == 'FileVar'):
                        continue
                    try:
                        file_var = FileVar(**item)
                    except Exception:
                        # best-effort: skip the malformed entry, but no longer silently
                        logger.warning('failed to parse file in workflow output %s', key, exc_info=True)
                    else:
                        files.append(file_var)
                        has_file = True

                if has_file:
                    continue

            result[key] = value

        return result, files
|