llm_generator.py

import logging

from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
# gpt-3.5-turbo does not work well for these generation tasks,
# so default to a completion model instead.
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT

        # Truncate very long queries, keeping the head and tail for context.
        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        prompt = prompt.format(query=query)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            max_tokens=50
        )

        # Chat models expect a list of messages rather than a plain string.
        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200
        model = 'gpt-3.5-turbo'

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)

        # Token budget for the conversation context: the model's context window
        # minus the bare prompt and the reserved completion tokens.
        rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1

        context = ''
        for message in messages:
            if not message.answer:
                continue

            # Truncate very long turns, keeping the head and tail of each text.
            if len(message.query) > 2000:
                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
            else:
                query = message.query

            if len(message.answer) > 2000:
                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
            else:
                answer = message.answer

            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer

            # Append the turn only if it still fits within the remaining budget.
            if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)

            # Chat models return a message object; extract its text before parsing.
            if isinstance(output, BaseMessage):
                output = output.content

            questions = output_parser.parse(output)
        except Exception:
            # Fall back to no suggestions rather than failing the request.
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                # Keep these placeholders literal so they survive formatting.
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=512
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)

            # Chat models return a message object; extract its text before parsing.
            if isinstance(output, BaseMessage):
                output = output.content

            rule_config = output_parser.parse(output)
        except OutputParserException:
            raise ValueError('Please provide a valid description of the intended audience or the problems you hope to solve.')
        except Exception:
            # Fall back to an empty config rather than failing the request.
            logging.exception("Error generating rule config")
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
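

# A minimal usage sketch (illustrative, not part of the original module). All
# methods are classmethods, so callers invoke them directly on LLMGenerator.
# Here `tenant_id` is assumed to identify a tenant with a configured model
# provider, and `messages` is assumed to be an iterable of objects exposing
# `.query` and `.answer` attributes:
#
#     name = LLMGenerator.generate_conversation_name(tenant_id, query, answer)
#     summary = LLMGenerator.generate_conversation_summary(tenant_id, messages)
#     intro = LLMGenerator.generate_introduction(tenant_id, pre_prompt)
#     questions = LLMGenerator.generate_suggested_questions_after_answer(tenant_id, histories)
#     rule_config = LLMGenerator.generate_rule_config(tenant_id, audiences, hoping_to_solve)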