base_agent_runner.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. import json
  2. import logging
  3. import uuid
  4. from datetime import datetime, timezone
  5. from typing import Optional, Union, cast
  6. from core.agent.entities import AgentEntity, AgentToolEntity
  7. from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
  8. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  9. from core.app.apps.base_app_queue_manager import AppQueueManager
  10. from core.app.apps.base_app_runner import AppRunner
  11. from core.app.entities.app_invoke_entities import (
  12. AgentChatAppGenerateEntity,
  13. ModelConfigWithCredentialsEntity,
  14. )
  15. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  16. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  17. from core.file.message_file_parser import MessageFileParser
  18. from core.memory.token_buffer_memory import TokenBufferMemory
  19. from core.model_manager import ModelInstance
  20. from core.model_runtime.entities.llm_entities import LLMUsage
  21. from core.model_runtime.entities.message_entities import (
  22. AssistantPromptMessage,
  23. PromptMessage,
  24. PromptMessageTool,
  25. SystemPromptMessage,
  26. TextPromptMessageContent,
  27. ToolPromptMessage,
  28. UserPromptMessage,
  29. )
  30. from core.model_runtime.entities.model_entities import ModelFeature
  31. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  32. from core.model_runtime.utils.encoders import jsonable_encoder
  33. from core.tools.entities.tool_entities import (
  34. ToolInvokeMessage,
  35. ToolParameter,
  36. ToolRuntimeVariablePool,
  37. )
  38. from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
  39. from core.tools.tool.tool import Tool
  40. from core.tools.tool_manager import ToolManager
  41. from extensions.ext_database import db
  42. from models.model import Conversation, Message, MessageAgentThought
  43. from models.tools import ToolConversationVariables
  44. logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
    """
    Base class for agent runners: loads tools, organizes conversation
    history into prompt messages, and persists agent thoughts.
    """

    def __init__(self, tenant_id: str,
                 application_generate_entity: AgentChatAppGenerateEntity,
                 conversation: Conversation,
                 app_config: AgentChatAppConfig,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: AgentEntity,
                 queue_manager: AppQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[list[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 model_instance: ModelInstance = None
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: app generate entity for this invocation
        :param conversation: conversation record
        :param app_config: agent chat app config
        :param model_config: model config with credentials
        :param config: agent config entity
        :param queue_manager: queue manager
        :param message: current message record
        :param user_id: user id
        :param memory: token buffer memory
        :param prompt_messages: prior prompt messages used to seed history
        :param variables_pool: tool runtime variable pool
        :param db_variables: persisted tool conversation variables
        :param model_instance: model instance; NOTE(review): dereferenced
            unconditionally below, so it is required in practice despite the
            None default — confirm callers always pass it
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        # Build prompt-message history up front; seeds from any system
        # messages in prompt_messages, then replays prior messages/thoughts.
        self.history_prompt_messages = self.organize_agent_history(
            prompt_messages=prompt_messages or []
        )
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        # (used to continue the `position` sequence in create_agent_thought)
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

        # check if model supports vision; files are only attached when it does
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []
  125. def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
  126. -> AgentChatAppGenerateEntity:
  127. """
  128. Repack app generate entity
  129. """
  130. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  131. app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
  132. return app_generate_entity
  133. def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
  134. """
  135. Handle tool response
  136. """
  137. result = ''
  138. for response in tool_response:
  139. if response.type == ToolInvokeMessage.MessageType.TEXT:
  140. result += response.message
  141. elif response.type == ToolInvokeMessage.MessageType.LINK:
  142. result += f"result link: {response.message}. please tell user to check it."
  143. elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
  144. response.type == ToolInvokeMessage.MessageType.IMAGE:
  145. result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
  146. else:
  147. result += f"tool response: {response.message}."
  148. return result
  149. def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
  150. """
  151. convert tool to prompt message tool
  152. """
  153. tool_entity = ToolManager.get_agent_tool_runtime(
  154. tenant_id=self.tenant_id,
  155. app_id=self.app_config.app_id,
  156. agent_tool=tool,
  157. )
  158. tool_entity.load_variables(self.variables_pool)
  159. message_tool = PromptMessageTool(
  160. name=tool.tool_name,
  161. description=tool_entity.description.llm,
  162. parameters={
  163. "type": "object",
  164. "properties": {},
  165. "required": [],
  166. }
  167. )
  168. parameters = tool_entity.get_all_runtime_parameters()
  169. for parameter in parameters:
  170. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  171. continue
  172. parameter_type = 'string'
  173. enum = []
  174. if parameter.type == ToolParameter.ToolParameterType.STRING:
  175. parameter_type = 'string'
  176. elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
  177. parameter_type = 'boolean'
  178. elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
  179. parameter_type = 'number'
  180. elif parameter.type == ToolParameter.ToolParameterType.SELECT:
  181. for option in parameter.options:
  182. enum.append(option.value)
  183. parameter_type = 'string'
  184. else:
  185. raise ValueError(f"parameter type {parameter.type} is not supported")
  186. message_tool.parameters['properties'][parameter.name] = {
  187. "type": parameter_type,
  188. "description": parameter.llm_description or '',
  189. }
  190. if len(enum) > 0:
  191. message_tool.parameters['properties'][parameter.name]['enum'] = enum
  192. if parameter.required:
  193. message_tool.parameters['required'].append(parameter.name)
  194. return message_tool, tool_entity
  195. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  196. """
  197. convert dataset retriever tool to prompt message tool
  198. """
  199. prompt_tool = PromptMessageTool(
  200. name=tool.identity.name,
  201. description=tool.description.llm,
  202. parameters={
  203. "type": "object",
  204. "properties": {},
  205. "required": [],
  206. }
  207. )
  208. for parameter in tool.get_runtime_parameters():
  209. parameter_type = 'string'
  210. prompt_tool.parameters['properties'][parameter.name] = {
  211. "type": parameter_type,
  212. "description": parameter.llm_description or '',
  213. }
  214. if parameter.required:
  215. if parameter.name not in prompt_tool.parameters['required']:
  216. prompt_tool.parameters['required'].append(parameter.name)
  217. return prompt_tool
  218. def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
  219. """
  220. Init tools
  221. """
  222. tool_instances = {}
  223. prompt_messages_tools = []
  224. for tool in self.app_config.agent.tools if self.app_config.agent else []:
  225. try:
  226. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  227. except Exception:
  228. # api tool may be deleted
  229. continue
  230. # save tool entity
  231. tool_instances[tool.tool_name] = tool_entity
  232. # save prompt tool
  233. prompt_messages_tools.append(prompt_tool)
  234. # convert dataset tools into ModelRuntime Tool format
  235. for dataset_tool in self.dataset_tools:
  236. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  237. # save prompt tool
  238. prompt_messages_tools.append(prompt_tool)
  239. # save tool entity
  240. tool_instances[dataset_tool.identity.name] = dataset_tool
  241. return tool_instances, prompt_messages_tools
  242. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  243. """
  244. update prompt message tool
  245. """
  246. # try to get tool runtime parameters
  247. tool_runtime_parameters = tool.get_runtime_parameters() or []
  248. for parameter in tool_runtime_parameters:
  249. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  250. continue
  251. parameter_type = 'string'
  252. enum = []
  253. if parameter.type == ToolParameter.ToolParameterType.STRING:
  254. parameter_type = 'string'
  255. elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
  256. parameter_type = 'boolean'
  257. elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
  258. parameter_type = 'number'
  259. elif parameter.type == ToolParameter.ToolParameterType.SELECT:
  260. for option in parameter.options:
  261. enum.append(option.value)
  262. parameter_type = 'string'
  263. else:
  264. raise ValueError(f"parameter type {parameter.type} is not supported")
  265. prompt_tool.parameters['properties'][parameter.name] = {
  266. "type": parameter_type,
  267. "description": parameter.llm_description or '',
  268. }
  269. if len(enum) > 0:
  270. prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
  271. if parameter.required:
  272. if parameter.name not in prompt_tool.parameters['required']:
  273. prompt_tool.parameters['required'].append(parameter.name)
  274. return prompt_tool
    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: list[str]
                             ) -> MessageAgentThought:
        """
        Create and persist a new agent thought row.

        :param message_id: id of the message this thought belongs to
        :param message: prompt message text for this step
        :param tool_name: tool name(s) about to be invoked
        :param tool_input: serialized tool input
        :param messages_ids: ids of message files attached to this step
        :return: the freshly persisted MessageAgentThought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_labels_str='{}',
            tool_meta_str='{}',
            tool_input=tool_input,
            message=message,
            # Token/price fields start at zero; save_agent_thought fills them
            # in once LLM usage is known.
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            # 1-based position within the message's thought sequence.
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            # NOTE(review): role is hard-coded to 'account' even for end-user
            # invocations — confirm this is intended.
            created_by_role='account',
            created_by=self.user_id,
        )

        # Persist, then refresh so generated fields (id, timestamps) are
        # populated before the session is closed.
        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought
  313. def save_agent_thought(self,
  314. agent_thought: MessageAgentThought,
  315. tool_name: str,
  316. tool_input: Union[str, dict],
  317. thought: str,
  318. observation: Union[str, dict],
  319. tool_invoke_meta: Union[str, dict],
  320. answer: str,
  321. messages_ids: list[str],
  322. llm_usage: LLMUsage = None) -> MessageAgentThought:
  323. """
  324. Save agent thought
  325. """
  326. agent_thought = db.session.query(MessageAgentThought).filter(
  327. MessageAgentThought.id == agent_thought.id
  328. ).first()
  329. if thought is not None:
  330. agent_thought.thought = thought
  331. if tool_name is not None:
  332. agent_thought.tool = tool_name
  333. if tool_input is not None:
  334. if isinstance(tool_input, dict):
  335. try:
  336. tool_input = json.dumps(tool_input, ensure_ascii=False)
  337. except Exception as e:
  338. tool_input = json.dumps(tool_input)
  339. agent_thought.tool_input = tool_input
  340. if observation is not None:
  341. if isinstance(observation, dict):
  342. try:
  343. observation = json.dumps(observation, ensure_ascii=False)
  344. except Exception as e:
  345. observation = json.dumps(observation)
  346. agent_thought.observation = observation
  347. if answer is not None:
  348. agent_thought.answer = answer
  349. if messages_ids is not None and len(messages_ids) > 0:
  350. agent_thought.message_files = json.dumps(messages_ids)
  351. if llm_usage:
  352. agent_thought.message_token = llm_usage.prompt_tokens
  353. agent_thought.message_price_unit = llm_usage.prompt_price_unit
  354. agent_thought.message_unit_price = llm_usage.prompt_unit_price
  355. agent_thought.answer_token = llm_usage.completion_tokens
  356. agent_thought.answer_price_unit = llm_usage.completion_price_unit
  357. agent_thought.answer_unit_price = llm_usage.completion_unit_price
  358. agent_thought.tokens = llm_usage.total_tokens
  359. agent_thought.total_price = llm_usage.total_price
  360. # check if tool labels is not empty
  361. labels = agent_thought.tool_labels or {}
  362. tools = agent_thought.tool.split(';') if agent_thought.tool else []
  363. for tool in tools:
  364. if not tool:
  365. continue
  366. if tool not in labels:
  367. tool_label = ToolManager.get_tool_label(tool)
  368. if tool_label:
  369. labels[tool] = tool_label.to_dict()
  370. else:
  371. labels[tool] = {'en_US': tool, 'zh_Hans': tool}
  372. agent_thought.tool_labels_str = json.dumps(labels)
  373. if tool_invoke_meta is not None:
  374. if isinstance(tool_invoke_meta, dict):
  375. try:
  376. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  377. except Exception as e:
  378. tool_invoke_meta = json.dumps(tool_invoke_meta)
  379. agent_thought.tool_meta_str = tool_invoke_meta
  380. db.session.commit()
  381. db.session.close()
  382. def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
  383. """
  384. convert tool variables to db variables
  385. """
  386. db_variables = db.session.query(ToolConversationVariables).filter(
  387. ToolConversationVariables.conversation_id == self.message.conversation_id,
  388. ).first()
  389. db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
  390. db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
  391. db.session.commit()
  392. db.session.close()
    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Rebuild the conversation history as prompt messages.

        Carries over any system messages from *prompt_messages*, then replays
        every prior message in the conversation: the user prompt, followed by
        each agent thought expanded into an assistant message with tool calls
        plus the matching tool responses (or a plain assistant message when no
        tool was used / no thoughts exist).

        :param prompt_messages: seed messages; only SystemPromptMessages are kept
        :return: ordered prompt message history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
            # Skip the message currently being generated.
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        # `tool` stores a `;`-separated list of tool names.
                        tools = tools.split(';')
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        # Stored input/observation may not be valid JSON (e.g.
                        # legacy rows); fall back to per-tool defaults.
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception as e:
                            tool_inputs = { tool: {} for tool in tools }
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception as e:
                            tool_responses = { tool: agent_thought.observation for tool in tools }

                        for tool in tools:
                            # generate a uuid for tool call
                            # (original call ids were not persisted; fresh ids
                            # keep call/response pairs linked)
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type='function',
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                )
                            ))
                            tool_call_response.append(ToolPromptMessage(
                                content=tool_responses.get(tool, agent_thought.observation),
                                name=tool,
                                tool_call_id=tool_call_id,
                            ))

                        # Assistant message with its tool calls, then the
                        # tool responses, mirroring the original exchange.
                        result.extend([
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response
                        ])
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
  455. def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
  456. message_file_parser = MessageFileParser(
  457. tenant_id=self.tenant_id,
  458. app_id=self.app_config.app_id,
  459. )
  460. files = message.message_files
  461. if files:
  462. file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
  463. if file_extra_config:
  464. file_objs = message_file_parser.transform_message_files(
  465. files,
  466. file_extra_config
  467. )
  468. else:
  469. file_objs = []
  470. if not file_objs:
  471. return UserPromptMessage(content=message.query)
  472. else:
  473. prompt_message_contents = [TextPromptMessageContent(data=message.query)]
  474. for file_obj in file_objs:
  475. prompt_message_contents.append(file_obj.prompt_message_content)
  476. return UserPromptMessage(content=prompt_message_contents)
  477. else:
  478. return UserPromptMessage(content=message.query)