# base_agent_runner.py
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import MessageFileParser
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)

class BaseAgentRunner(AppRunner):
    def __init__(
        self,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        model_instance: Optional[ModelInstance] = None,
    ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: application generate entity
        :param conversation: conversation
        :param app_config: app config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param memory: memory
        :param prompt_messages: prompt messages
        :param variables_pool: variables pool
        :param db_variables: db variables
        :param model_instance: model instance
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

        # check if model supports vision
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []

        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        Convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        Convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        Update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str,
        tool_input: Union[str, dict],
        thought: str,
        observation: Union[str, dict],
        tool_invoke_meta: Union[str, dict],
        answer: str,
        messages_ids: list[str],
        llm_usage: Optional[LLMUsage] = None,
    ) -> None:
        """
        Save agent thought
        """
        agent_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        Convert tool variables to db variables
        """
        db_variables = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )

        db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        message_file_parser = MessageFileParser(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
        )

        files = message.message_files
        if files:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

            if file_extra_config:
                file_objs = message_file_parser.transform_message_files(files, file_extra_config)
            else:
                file_objs = []

            if not file_objs:
                return UserPromptMessage(content=message.query)
            else:
                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
                for file_obj in file_objs:
                    prompt_message_contents.append(file_obj.prompt_message_content)

                return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)
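
# ---------------------------------------------------------------------------
# Usage sketch (editor's illustration, not part of the original module).
# Concrete agent runners subclass BaseAgentRunner; the commented outline below
# shows roughly how the helpers above fit together in such a subclass. The
# class name ExampleAgentRunner and its run() signature are hypothetical.
#
#     class ExampleAgentRunner(BaseAgentRunner):
#         def run(self, message: Message, query: str):
#             # resolve configured tools plus dataset tools into runtimes and
#             # their PromptMessageTool schemas for the LLM
#             tool_instances, prompt_messages_tools = self._init_prompt_tools()
#
#             # record a MessageAgentThought row before invoking the model
#             agent_thought = self.create_agent_thought(
#                 message_id=message.id,
#                 message="",
#                 tool_name="",
#                 tool_input="",
#                 messages_ids=[],
#             )
#
#             # ... invoke the LLM, execute the selected tool, then persist the
#             # observation, answer and token usage on the same row
#             self.save_agent_thought(
#                 agent_thought=agent_thought,
#                 tool_name="example_tool",
#                 tool_input={},
#                 thought="",
#                 observation="",
#                 tool_invoke_meta={},
#                 answer="",
#                 messages_ids=[],
#                 llm_usage=None,
#             )
# ---------------------------------------------------------------------------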