llm_generator.py

import logging

from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, HumanMessage

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator

from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT

# gpt-3.5-turbo does not work well for these generation tasks
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        # Fill the title-generation template with the first exchange.
        prompt = CONVERSATION_TITLE_PROMPT
        prompt = prompt.format(query=query, answer=answer)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=50
        )

        # Chat models expect a list of messages rather than a raw string.
        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()
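
    # Note: LLMBuilder.to_llm apparently returns either a completion-style LLM
    # or a chat model depending on tenant configuration, which is why each
    # generator here wraps the prompt in a HumanMessage list for chat models.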

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200

        prompt = CONVERSATION_SUMMARY_PROMPT

        # Measure the template with an empty context to find out how many
        # tokens remain for the conversation history itself.
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
        rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens

        # Append Q/A pairs only while they still fit in the token budget.
        context = ''
        for message in messages:
            if not message.answer:
                continue

            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
            if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
                context += message_qa_text

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()
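
    # Worked example of the token budget above, with purely hypothetical
    # numbers (the real limit comes from llm_constant.max_context_token_length):
    # given a 4,000-token window, a 60-token empty-context template and
    # max_tokens=200, rest_tokens = 4000 - 60 - 200 = 3740, so Q/A pairs are
    # appended only while the accumulated context stays within 3,740 tokens.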

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        # Ask the model to write an opening statement for the given app prompt.
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text

        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = OutLinePromptTemplate(
            template="{histories}\n{format_instructions}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            # Chat models return a message object; extract its text before parsing.
            if isinstance(output, BaseMessage):
                output = output.content
            questions = output_parser.parse(output)
        except Exception:
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions
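

# --- Illustrative usage: a minimal sketch, not part of the original module. ---
# The tenant id and texts below are hypothetical placeholders; a real call
# presumably needs a tenant with provider credentials resolvable by LLMBuilder.
if __name__ == '__main__':
    title = LLMGenerator.generate_conversation_name(
        tenant_id='tenant-123',  # hypothetical tenant id
        query='How do I reset my password?',
        answer='Go to Settings > Security and choose "Reset password".'
    )
    print(title)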