completion.py

import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple

from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
                 auto_generate_name: bool = True):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptTemplateParser.remove_template_variables(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            files=files,
            streaming=streaming,
            model_instance=final_model_instance,
            auto_generate_name=auto_generate_name
        )

        prompt_message_files = [file.prompt_message_file for file in files]

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs,
            files=prompt_message_files
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)

            try:
                # run input moderation (sensitive_word_avoidance)
                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
            except ModerationException as e:
                cls.run_final_llm(
                    model_instance=final_model_instance,
                    mode=app.mode,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    files=prompt_message_files,
                    agent_execute_result=None,
                    conversation_message_task=conversation_message_task,
                    memory=memory,
                    fake_response=str(e)
                )
                return

            # fill in variable inputs from external data tools, if any are configured
            external_data_tools = app_model_config.external_data_tools_list
            if external_data_tools:
                inputs = cls.fill_in_inputs_from_external_data_tools(
                    tenant_id=app.tenant_id,
                    app_id=app.id,
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query
                )

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                tenant_id=app.tenant_id,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre-prompt is specified, the agent's output can be used
            # directly as the final answer without calling the LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                files=prompt_message_files,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            return
        except ChunkedEncodingError as e:
            # the streamed response was cut off by the LLM provider (e.g. OpenAI); end the task gracefully
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
                              query: str):
        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
            return inputs, query

        type = app_model_config.sensitive_word_avoidance_dict['type']

        moderation = ModerationFactory(type, app_id, tenant_id,
                                       app_model_config.sensitive_word_avoidance_dict['config'])
        moderation_result = moderation.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
            return inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDED:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return inputs, query
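
    # Illustrative sketch: based on the keys read above, sensitive_word_avoidance_dict
    # is assumed to be a plain dict of roughly the following shape; the "keywords"
    # type name and the config contents are placeholders, not confirmed values.
    #
    #     {
    #         "enabled": True,
    #         "type": "keywords",            # moderation backend name (assumed)
    #         "config": {"keywords": "..."}  # backend-specific config (assumed)
    #     }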

    @classmethod
    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
                                                inputs: dict, query: str) -> dict:
        """
        Fill in variable inputs from external data tools, if any are configured.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tool configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        # Group tools by type and config
        grouped_tools = {}
        for tool in external_data_tools:
            if not tool.get("enabled"):
                continue

            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
            grouped_tools.setdefault(tool_key, []).append(tool)

        results = {}
        with ThreadPoolExecutor() as executor:
            futures = {}
            for tool in external_data_tools:
                if not tool.get("enabled"):
                    continue

                future = executor.submit(
                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
                    inputs, query
                )
                futures[future] = tool

            for future in concurrent.futures.as_completed(futures):
                tool_variable, result = future.result()
                results[tool_variable] = result

        inputs.update(results)
        return inputs
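
    # Illustrative sketch: each entry in external_data_tools is a plain dict; based
    # on the keys read here and in query_external_data_tool below, a single enabled
    # tool might look like the following (the "api" type name, the "weather" variable
    # and the config shape are assumptions, not confirmed values). Each enabled tool
    # runs in its own thread and its result is written back to inputs[tool["variable"]].
    #
    #     {
    #         "enabled": True,
    #         "type": "api",            # tool type name (assumed)
    #         "variable": "weather",    # the inputs key this tool fills in (assumed)
    #         "config": {}              # tool-specific config passed to the factory
    #     }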

    @classmethod
    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
        with flask_app.app_context():
            tool_variable = external_data_tool.get("variable")
            tool_type = external_data_tool.get("type")
            tool_config = external_data_tool.get("config")

            external_data_tool_factory = ExternalDataToolFactory(
                name=tool_type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=tool_variable,
                config=tool_config
            )

            # query external data tool
            result = external_data_tool_factory.query(
                inputs=inputs,
                query=query
            )

            return tool_variable, result

    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
        if app.mode != 'completion':
            return query

        return inputs.get(app_model_config.dataset_query_variable, "")

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      files: List[PromptMessageFile],
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        prompt_transform = PromptTransform()

        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )

            # advanced prompts don't return stop words, so take them from the model config
            model_config = app_model_config.model_dict
            completion_params = model_config.get("completion_params", {})
            stop_words = completion_params.get("stop", [])

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words if stop_words else None,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # this model instance is only used for token counting in the memory buffer
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long. Reduce the prefix prompt, "
                                     "lower the max tokens, or switch to an LLM with a larger token limit.")

        return rest_tokens
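
    # Worked example of the budget above (illustrative numbers): with a model limit
    # of 4096 tokens, a configured max_tokens of 512 and a prompt of 1000 tokens,
    # the remaining budget for context and memory is 4096 - 512 - 1000 = 2584 tokens;
    # a negative result raises LLMBadRequestError.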

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)
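
    # Worked example of the clamp above (illustrative numbers): with a 4096-token
    # model limit, a 4000-token prompt and max_tokens=512, 4000 + 512 > 4096, so
    # max_tokens is reduced to max(4096 - 4000, 16) = 96 before the model kwargs
    # are updated.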