llm_generator.py

import logging

from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
    GENERATOR_QA_PROMPT

# gpt-3.5-turbo does not work well for these generation tasks
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT

        # Keep only the head and tail of very long queries to stay within the prompt budget
        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        prompt = prompt.format(query=query)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            max_tokens=50,
            timeout=600
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200
        model = 'gpt-3.5-turbo'

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)

        # Tokens left for conversation context after reserving room for the prompt and the completion
        rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1

        context = ''
        for message in messages:
            if not message.answer:
                continue

            if len(message.query) > 2000:
                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
            else:
                query = message.query

            if len(message.answer) > 2000:
                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
            else:
                answer = message.answer

            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer

            # Append the Q&A pair only if it still fits within the remaining token budget
            if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            if isinstance(output, BaseMessage):
                output = output.content
            questions = output_parser.parse(output)
        except Exception:
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=512
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            rule_config = output_parser.parse(output)
        except OutputParserException:
            raise ValueError('Please provide a valid intended audience or the problem you hope to solve.')
        except Exception:
            logging.exception("Error generating prompt")
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config

    @classmethod
    async def generate_qa_document(cls, llm: StreamableOpenAI, query):
        prompt = GENERATOR_QA_PROMPT

        if isinstance(llm, BaseChatModel):
            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_qa_document_sync(cls, llm: StreamableOpenAI, query):
        prompt = GENERATOR_QA_PROMPT

        if isinstance(llm, BaseChatModel):
            prompt = [SystemMessage(content=prompt), HumanMessage(content=query)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()
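

# Hypothetical usage sketch (not part of the original module): shows how
# LLMGenerator.generate_conversation_name might be called. The tenant id and
# the query/answer strings below are placeholders; a configured tenant with
# valid provider credentials is assumed.
if __name__ == "__main__":
    example_tenant_id = "tenant-123"  # placeholder, replace with a real tenant id
    conversation_name = LLMGenerator.generate_conversation_name(
        tenant_id=example_tenant_id,
        query="How do I reset my password?",
        answer="You can reset it from the account settings page.",
    )
    print(conversation_name)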