workflow_tool.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import json
  2. import logging
  3. from copy import deepcopy
  4. from typing import Any, Union
  5. from core.file.file_obj import FileTransferMethod, FileVar
  6. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
  7. from core.tools.tool.tool import Tool
  8. from extensions.ext_database import db
  9. from models.account import Account
  10. from models.model import App, EndUser
  11. from models.workflow import Workflow
  12. logger = logging.getLogger(__name__)
  13. class WorkflowTool(Tool):
  14. workflow_app_id: str
  15. version: str
  16. workflow_entities: dict[str, Any]
  17. workflow_call_depth: int
  18. label: str
  19. """
  20. Workflow tool.
  21. """
  22. def tool_provider_type(self) -> ToolProviderType:
  23. """
  24. get the tool provider type
  25. :return: the tool provider type
  26. """
  27. return ToolProviderType.WORKFLOW
  28. def _invoke(
  29. self, user_id: str, tool_parameters: dict[str, Any]
  30. ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
  31. """
  32. invoke the tool
  33. """
  34. app = self._get_app(app_id=self.workflow_app_id)
  35. workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
  36. # transform the tool parameters
  37. tool_parameters, files = self._transform_args(tool_parameters)
  38. from core.app.apps.workflow.app_generator import WorkflowAppGenerator
  39. generator = WorkflowAppGenerator()
  40. result = generator.generate(
  41. app_model=app,
  42. workflow=workflow,
  43. user=self._get_user(user_id),
  44. args={
  45. 'inputs': tool_parameters,
  46. 'files': files
  47. },
  48. invoke_from=self.runtime.invoke_from,
  49. stream=False,
  50. call_depth=self.workflow_call_depth + 1,
  51. )
  52. data = result.get('data', {})
  53. if data.get('error'):
  54. raise Exception(data.get('error'))
  55. result = []
  56. outputs = data.get('outputs', {})
  57. outputs, files = self._extract_files(outputs)
  58. for file in files:
  59. result.append(self.create_file_var_message(file))
  60. result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
  61. return result
  62. def _get_user(self, user_id: str) -> Union[EndUser, Account]:
  63. """
  64. get the user by user id
  65. """
  66. user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
  67. if not user:
  68. user = db.session.query(Account).filter(Account.id == user_id).first()
  69. if not user:
  70. raise ValueError('user not found')
  71. return user
  72. def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool':
  73. """
  74. fork a new tool with meta data
  75. :param meta: the meta data of a tool call processing, tenant_id is required
  76. :return: the new tool
  77. """
  78. return self.__class__(
  79. identity=deepcopy(self.identity),
  80. parameters=deepcopy(self.parameters),
  81. description=deepcopy(self.description),
  82. runtime=Tool.Runtime(**runtime),
  83. workflow_app_id=self.workflow_app_id,
  84. workflow_entities=self.workflow_entities,
  85. workflow_call_depth=self.workflow_call_depth,
  86. version=self.version,
  87. label=self.label
  88. )
  89. def _get_workflow(self, app_id: str, version: str) -> Workflow:
  90. """
  91. get the workflow by app id and version
  92. """
  93. if not version:
  94. workflow = db.session.query(Workflow).filter(
  95. Workflow.app_id == app_id,
  96. Workflow.version != 'draft'
  97. ).order_by(Workflow.created_at.desc()).first()
  98. else:
  99. workflow = db.session.query(Workflow).filter(
  100. Workflow.app_id == app_id,
  101. Workflow.version == version
  102. ).first()
  103. if not workflow:
  104. raise ValueError('workflow not found or not published')
  105. return workflow
  106. def _get_app(self, app_id: str) -> App:
  107. """
  108. get the app by app id
  109. """
  110. app = db.session.query(App).filter(App.id == app_id).first()
  111. if not app:
  112. raise ValueError('app not found')
  113. return app
  114. def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
  115. """
  116. transform the tool parameters
  117. :param tool_parameters: the tool parameters
  118. :return: tool_parameters, files
  119. """
  120. parameter_rules = self.get_all_runtime_parameters()
  121. parameters_result = {}
  122. files = []
  123. for parameter in parameter_rules:
  124. if parameter.type == ToolParameter.ToolParameterType.FILE:
  125. file = tool_parameters.get(parameter.name)
  126. if file:
  127. try:
  128. file_var_list = [FileVar(**f) for f in file]
  129. for file_var in file_var_list:
  130. file_dict = {
  131. 'transfer_method': file_var.transfer_method.value,
  132. 'type': file_var.type.value,
  133. }
  134. if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
  135. file_dict['tool_file_id'] = file_var.related_id
  136. elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
  137. file_dict['upload_file_id'] = file_var.related_id
  138. elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
  139. file_dict['url'] = file_var.preview_url
  140. files.append(file_dict)
  141. except Exception as e:
  142. logger.exception(e)
  143. else:
  144. parameters_result[parameter.name] = tool_parameters.get(parameter.name)
  145. return parameters_result, files
  146. def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
  147. """
  148. extract files from the result
  149. :param result: the result
  150. :return: the result, files
  151. """
  152. files = []
  153. result = {}
  154. for key, value in outputs.items():
  155. if isinstance(value, list):
  156. has_file = False
  157. for item in value:
  158. if isinstance(item, dict) and item.get('__variant') == 'FileVar':
  159. try:
  160. files.append(FileVar(**item))
  161. has_file = True
  162. except Exception as e:
  163. pass
  164. if has_file:
  165. continue
  166. result[key] = value
  167. return result, files