workflow_tool.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import json
  2. import logging
  3. from copy import deepcopy
  4. from typing import Any, Optional, Union
  5. from core.file.file_obj import FileTransferMethod, FileVar
  6. from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
  7. from core.tools.tool.tool import Tool
  8. from extensions.ext_database import db
  9. from models.account import Account
  10. from models.model import App, EndUser
  11. from models.workflow import Workflow
  12. logger = logging.getLogger(__name__)
  13. class WorkflowTool(Tool):
  14. workflow_app_id: str
  15. version: str
  16. workflow_entities: dict[str, Any]
  17. workflow_call_depth: int
  18. thread_pool_id: Optional[str] = None
  19. label: str
  20. """
  21. Workflow tool.
  22. """
  23. def tool_provider_type(self) -> ToolProviderType:
  24. """
  25. get the tool provider type
  26. :return: the tool provider type
  27. """
  28. return ToolProviderType.WORKFLOW
  29. def _invoke(
  30. self, user_id: str, tool_parameters: dict[str, Any]
  31. ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
  32. """
  33. invoke the tool
  34. """
  35. app = self._get_app(app_id=self.workflow_app_id)
  36. workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
  37. # transform the tool parameters
  38. tool_parameters, files = self._transform_args(tool_parameters)
  39. from core.app.apps.workflow.app_generator import WorkflowAppGenerator
  40. generator = WorkflowAppGenerator()
  41. result = generator.generate(
  42. app_model=app,
  43. workflow=workflow,
  44. user=self._get_user(user_id),
  45. args={"inputs": tool_parameters, "files": files},
  46. invoke_from=self.runtime.invoke_from,
  47. stream=False,
  48. call_depth=self.workflow_call_depth + 1,
  49. workflow_thread_pool_id=self.thread_pool_id,
  50. )
  51. data = result.get("data", {})
  52. if data.get("error"):
  53. raise Exception(data.get("error"))
  54. result = []
  55. outputs = data.get("outputs")
  56. if outputs == None:
  57. outputs = {}
  58. else:
  59. outputs, files = self._extract_files(outputs)
  60. for file in files:
  61. result.append(self.create_file_var_message(file))
  62. result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
  63. result.append(self.create_json_message(outputs))
  64. return result
  65. def _get_user(self, user_id: str) -> Union[EndUser, Account]:
  66. """
  67. get the user by user id
  68. """
  69. user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
  70. if not user:
  71. user = db.session.query(Account).filter(Account.id == user_id).first()
  72. if not user:
  73. raise ValueError("user not found")
  74. return user
  75. def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool":
  76. """
  77. fork a new tool with meta data
  78. :param meta: the meta data of a tool call processing, tenant_id is required
  79. :return: the new tool
  80. """
  81. return self.__class__(
  82. identity=deepcopy(self.identity),
  83. parameters=deepcopy(self.parameters),
  84. description=deepcopy(self.description),
  85. runtime=Tool.Runtime(**runtime),
  86. workflow_app_id=self.workflow_app_id,
  87. workflow_entities=self.workflow_entities,
  88. workflow_call_depth=self.workflow_call_depth,
  89. version=self.version,
  90. label=self.label,
  91. )
  92. def _get_workflow(self, app_id: str, version: str) -> Workflow:
  93. """
  94. get the workflow by app id and version
  95. """
  96. if not version:
  97. workflow = (
  98. db.session.query(Workflow)
  99. .filter(Workflow.app_id == app_id, Workflow.version != "draft")
  100. .order_by(Workflow.created_at.desc())
  101. .first()
  102. )
  103. else:
  104. workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
  105. if not workflow:
  106. raise ValueError("workflow not found or not published")
  107. return workflow
  108. def _get_app(self, app_id: str) -> App:
  109. """
  110. get the app by app id
  111. """
  112. app = db.session.query(App).filter(App.id == app_id).first()
  113. if not app:
  114. raise ValueError("app not found")
  115. return app
  116. def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
  117. """
  118. transform the tool parameters
  119. :param tool_parameters: the tool parameters
  120. :return: tool_parameters, files
  121. """
  122. parameter_rules = self.get_all_runtime_parameters()
  123. parameters_result = {}
  124. files = []
  125. for parameter in parameter_rules:
  126. if parameter.type == ToolParameter.ToolParameterType.FILE:
  127. file = tool_parameters.get(parameter.name)
  128. if file:
  129. try:
  130. file_var_list = [FileVar(**f) for f in file]
  131. for file_var in file_var_list:
  132. file_dict = {
  133. "transfer_method": file_var.transfer_method.value,
  134. "type": file_var.type.value,
  135. }
  136. if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
  137. file_dict["tool_file_id"] = file_var.related_id
  138. elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
  139. file_dict["upload_file_id"] = file_var.related_id
  140. elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
  141. file_dict["url"] = file_var.preview_url
  142. files.append(file_dict)
  143. except Exception as e:
  144. logger.exception(e)
  145. else:
  146. parameters_result[parameter.name] = tool_parameters.get(parameter.name)
  147. return parameters_result, files
  148. def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
  149. """
  150. extract files from the result
  151. :param result: the result
  152. :return: the result, files
  153. """
  154. files = []
  155. result = {}
  156. for key, value in outputs.items():
  157. if isinstance(value, list):
  158. has_file = False
  159. for item in value:
  160. if isinstance(item, dict) and item.get("__variant") == "FileVar":
  161. try:
  162. files.append(FileVar(**item))
  163. has_file = True
  164. except Exception as e:
  165. pass
  166. if has_file:
  167. continue
  168. result[key] = value
  169. return result, files