# llm_generator.py

import logging

from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException, BaseMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT

# gpt-3.5-turbo does not work well for these generation tasks, so use text-davinci-003 as the base model.
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        # Generate a short title for a conversation from the user's first query.
        prompt = CONVERSATION_TITLE_PROMPT
        prompt = prompt.format(query=query)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            max_tokens=50
        )

        # Chat models expect a list of messages rather than a plain string prompt.
        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        # Summarize a conversation while keeping the prompt within the model's context window.
        max_tokens = 200
        model = 'gpt-3.5-turbo'

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
        rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1

        # Build the context from answered messages, adding only as much as the token budget allows.
        context = ''
        for message in messages:
            if not message.answer:
                continue

            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
            if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        # Generate an opening statement (introduction) from the app's pre-prompt.
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        # Suggest follow-up questions based on the conversation history.
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            # Chat models return a message object; unwrap it to the raw text before parsing.
            if isinstance(output, BaseMessage):
                output = output.content
            questions = output_parser.parse(output)
        except Exception:
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        # Generate a rule config (prompt, variables, opening statement) for the given audience and problem.
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=512
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            rule_config = output_parser.parse(output)
        except OutputParserException:
            raise ValueError('Please provide a valid description of the intended audience and the problem to solve.')
        except Exception:
            logging.exception("Error generating prompt")
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
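

if __name__ == '__main__':
    # Illustrative usage sketch only; not part of the original module. It assumes a
    # tenant whose LLM provider credentials are already configured. The values of
    # `example_tenant_id` and the query/answer strings below are placeholders.
    example_tenant_id = 'your-tenant-id'

    name = LLMGenerator.generate_conversation_name(
        tenant_id=example_tenant_id,
        query='How do I deploy this app with Docker?',
        answer='Use the provided docker-compose file and run docker compose up.'
    )
    print(name)

    questions = LLMGenerator.generate_suggested_questions_after_answer(
        tenant_id=example_tenant_id,
        histories='Human: How do I deploy this app with Docker?\n'
                  'AI: Use the provided docker-compose file and run docker compose up.'
    )
    print(questions)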