llm_generator.py

import logging

from langchain.schema import OutputParserException

from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
    GENERATOR_QA_PROMPT

class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        """Generate a short conversation title from the opening query."""
        prompt = CONVERSATION_TITLE_PROMPT

        # Truncate very long queries, keeping the head and tail.
        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        prompt = prompt.format(query=query)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=50
            )
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
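
    # Usage sketch (not part of the original file; tenant ID and texts are
    # hypothetical, and a text-generation provider must be configured for
    # the tenant so ModelFactory can resolve a model):
    #
    #   name = LLMGenerator.generate_conversation_name(
    #       tenant_id="tenant-123",
    #       query="How do I rotate my API keys?",
    #       answer="Open Settings > API Keys and click Rotate.",
    #   )
    #   # name might be something like "API Key Rotation"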

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        """Summarize a conversation, packing as much history as the context window allows."""
        max_tokens = 200

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=max_tokens
            )
        )

        # Budget: context window minus the empty-prompt overhead and the
        # completion budget, with 1 token of slack.
        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
        max_context_token_length = model_instance.model_rules.max_tokens.max
        max_context_token_length = max_context_token_length if max_context_token_length else 1500
        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

        context = ''
        for message in messages:
            if not message.answer:
                continue

            # Truncate overly long queries and answers, keeping head and tail.
            if len(message.query) > 2000:
                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
            else:
                query = message.query

            if len(message.answer) > 2000:
                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
            else:
                answer = message.answer

            # Append a Q/A pair only if it still fits the token budget;
            # pairs that would overflow are skipped, not truncated further.
            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)
        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
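
    # Worked example of the token budget above (illustrative numbers only):
    # with max_context_token_length = 1500, prompt_tokens = 120 and
    # max_tokens = 200, rest_tokens = 1500 - 120 - 200 - 1 = 1179, so Q/A
    # pairs keep being appended only while the accumulated context plus the
    # next pair stays under 1179 tokens.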

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
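
    # Usage sketch (hypothetical pre_prompt; assumes a configured provider):
    #
    #   intro = LLMGenerator.generate_introduction(
    #       tenant_id="tenant-123",
    #       pre_prompt="You are a helpful travel-planning assistant.",
    #   )
    #   # intro is a model-written opening statement for the app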

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        """Suggest follow-up questions for a conversation; returns [] on any failure."""
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        try:
            model_instance = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id,
                model_kwargs=ModelKwargs(
                    max_tokens=256,
                    temperature=0
                )
            )
        except ProviderTokenNotInitError:
            # No provider credentials configured for this tenant.
            return []

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            questions = output_parser.parse(output.content)
        except LLMError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions
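
    # Usage sketch (hypothetical history string; degrades gracefully to []
    # when no provider token is configured or the call/parsing fails):
    #
    #   questions = LLMGenerator.generate_suggested_questions_after_answer(
    #       tenant_id="tenant-123",
    #       histories="Human: What is RAG?\nAssistant: Retrieval-augmented generation is...",
    #   )
    #   # e.g. ["How does retrieval work?", "Which embedding models are supported?"]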

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        """Generate an app rule config (prompt, variables, opening statement) from a use-case description."""
        output_parser = RuleConfigGeneratorOutputParser()

        # Literal braces in the format instructions are escaped via
        # partial_variables so they survive prompt formatting.
        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=512,
                temperature=0
            )
        )

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            rule_config = output_parser.parse(output.content)
        except LLMError as e:
            raise e
        except OutputParserException:
            raise ValueError('Please provide a valid intended audience and the problem you hope to solve.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
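
    # Usage sketch (hypothetical inputs; raises ValueError when the model
    # output cannot be parsed into the expected rule-config structure):
    #
    #   rule_config = LLMGenerator.generate_rule_config(
    #       tenant_id="tenant-123",
    #       audiences="junior Python developers",
    #       hoping_to_solve="explaining tracebacks in plain language",
    #   )
    #   # rule_config has keys: "prompt", "variables", "opening_statement"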

    @classmethod
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=2000
            )
        )

        prompts = [
            PromptMessage(content=prompt, type=MessageType.SYSTEM),
            PromptMessage(content=query)
        ]

        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
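

# Usage sketch (not part of the original file; document text and language
# are hypothetical). The QA prompt goes in as a system message and the
# document text as the user message:
#
#   qa_text = LLMGenerator.generate_qa_document(
#       tenant_id="tenant-123",
#       query="Photosynthesis converts light energy into chemical energy...",
#       document_language="English",
#   )
#   # qa_text contains model-generated Q/A pairs in the requested language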