completion.py

import logging
from typing import Optional, List, Union, Tuple

from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from requests.exceptions import ChunkedEncodingError

from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        cls.validate_query_tokens(app.tenant_id, app_model_config, query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # the response stream was interrupted by the LLM provider (e.g. OpenAI); end the task gracefully
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return
    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt], stop_words)

        return response
    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str, List[BaseMessage]], Optional[List[str]]]:
        # disable template strings in the query
        query_params = OutLinePromptTemplate.from_template(template=query).input_variables
        if query_params:
            for query_param in query_params:
                if query_param not in inputs:
                    inputs[query_param] = '{' + query_param + '}'

        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
        if mode == 'completion':
            prompt_template = OutLinePromptTemplate.from_template(
                template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
{context}
[END CONTEXT]

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
                + (pre_prompt + "\n" if pre_prompt else "")
                + "{query}\n"
            )

            if chain_output:
                inputs['context'] = chain_output
                context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
                if context_params:
                    for context_param in context_params:
                        if context_param not in inputs:
                            inputs[context_param] = '{' + context_param + '}'

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)], None
            else:
                return prompt_content, None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if chain_output:
                human_inputs['context'] = chain_output
                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
[END CONTEXT]

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\nHuman: {query}\nAI: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )

                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
                    - memory.llm.max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                # disable template strings in the histories
                histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
                if histories_params:
                    for histories_param in histories_params:
                        if histories_param not in human_inputs:
                            human_inputs[histories_param] = '{' + histories_param + '}'

                human_message_prompt += "\n\n" + histories

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)

            return messages, ['\nHuman:']
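    # Illustrative only (hypothetical values, not part of the original flow): in chat mode,
    # with a pre_prompt of "You are a helpful assistant.", a non-empty chain_output and an
    # existing conversation, the human message assembled above has roughly this shape:
    #
    #   Use the following CONTEXT as your learned knowledge.
    #   [CONTEXT]
    #   <chain_output>
    #   [END CONTEXT]
    #   ...
    #   You are a helpful assistant.
    #
    #   <chat histories truncated to rest_tokens>
    #   Human: <query>
    #   AI: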
    @classmethod
    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                                 streaming: bool,
                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]

        return CallbackManager(callback_handlers)
    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]
    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only used for counting tokens in memory
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory
    @classmethod
    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
        max_tokens = llm.max_tokens

        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
            raise LLMBadRequestError("Query is too long")
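    # Worked example (hypothetical numbers): with a 4,096-token context window and
    # max_tokens of 512, any query longer than 3,584 tokens makes the expression
    # above negative and raises LLMBadRequestError.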
    @classmethod
    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
        max_tokens = final_llm.max_tokens

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_messages_tokens(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens
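    # Worked example (hypothetical numbers): with a 4,096-token context window, a
    # 3,900-token prompt and max_tokens of 512, the sum exceeds the limit, so
    # max_tokens is lowered to max(4,096 - 3,900, 16) = 196 before generation.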
    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=app.tenant_id,
            model_name='gpt-3.5-turbo',
            streaming=streaming
        )

        # get llm prompt
        original_prompt, _ = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
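
# Usage sketch (hypothetical objects, not part of this module): a caller such as a
# service layer would load the App, AppModelConfig and Account records and invoke:
#
#   Completion.generate(
#       task_id=str(uuid.uuid4()),
#       app=app,
#       app_model_config=app_model_config,
#       query="Summarize this document.",
#       inputs={},
#       user=account,
#       conversation=None,
#       streaming=True
#   )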