import json
import logging
from typing import Optional, List, Union

from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False, retriever_from: str = 'dev'):
  26. """
  27. errors: ProviderTokenNotInitError
  28. """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming,
            model_instance=final_model_instance
        )

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            # parse sensitive_word_avoidance_chain
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
                final_model_instance, [chain_callback])
            if sensitive_word_avoidance_chain:
                try:
                    query = sensitive_word_avoidance_chain.run(query)
                except SensitiveWordAvoidanceError as ex:
                    cls.run_final_llm(
                        model_instance=final_model_instance,
                        mode=app.mode,
                        app_model_config=app_model_config,
                        query=query,
                        inputs=inputs,
                        agent_execute_result=None,
                        conversation_message_task=conversation_message_task,
                        memory=memory,
                        fake_response=ex.message
                    )
                    return

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre prompt is specified, the agent output can be used directly
            # as the main output content without calling the LLM again.
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # Stream interrupted by the LLM provider (e.g. OpenAI); end the task gracefully.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
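        """
        Return the query passed to the agent: the user query in chat mode, or the value of
        the configured dataset query variable from inputs in completion mode.
        """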
        if app.mode != 'completion':
            return query

        return inputs.get(app_model_config.dataset_query_variable, "")

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
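        """
        Build the final prompt from the app config, inputs, query, agent context and memory,
        recalculate max_tokens to fit the model limit, then run the model. A non-empty
        fake_response is used as the answer instead of a fresh generation.
        """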
        # get llm prompt
        prompt_messages, stop_words = model_instance.get_prompt(
            mode=mode,
            pre_prompt=app_model_config.pre_prompt,
            inputs=inputs,
            query=query,
            context=agent_execute_result.output if agent_execute_result else None,
            memory=memory
        )

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
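        """
        Build a read-only, token-limited memory buffer over the conversation history; keyword
        arguments override the default buffer settings (token limit, keys, message limit).
        """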
        # this model instance is only used for token counting inside the memory buffer
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
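        """
        Return the number of tokens left for context and memory after reserving the prompt
        (built without context or memory) and the configured max_tokens; -1 means the model
        has no token limit. Raises LLMBadRequestError if nothing is left.
        """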
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt without memory and context
        prompt_messages, _ = model_instance.get_prompt(
            mode=mode,
            pre_prompt=app_model_config.pre_prompt,
            inputs=inputs,
            query=query,
            context=None,
            memory=None
        )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long. You can reduce the prefix prompt, "
                                     "shrink the max tokens, or switch to an LLM with a larger token limit.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
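        """
        Generate an alternative answer for an existing message by wrapping the original
        prompt and its completion in the MORE_LIKE_THIS_GENERATE_PROMPT template.
        """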
        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        old_prompt_messages, _ = final_model_instance.get_prompt(
            mode='completion',
            pre_prompt=pre_prompt,
            inputs=message.inputs,
            query=message.query,
            context=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)

        prompt_messages = [PromptMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming,
            model_instance=final_model_instance
        )

        cls.recale_llm_max_tokens(
            model_instance=final_model_instance,
            prompt_messages=prompt_messages
        )

        final_model_instance.run(
            messages=prompt_messages,
            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
        )