llm_generator.py

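"""Prompt-driven generation helpers.

LLMGenerator wraps the tenant's configured text-generation model to produce
conversation names, conversation summaries, app introductions, suggested
follow-up questions, rule configurations and QA documents.
"""
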
import json
import logging

from langchain.schema import OutputParserException

from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
    GENERATOR_QA_PROMPT


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT

        # Keep very long queries within the prompt budget by keeping only the
        # head and tail of the text.
        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        query = query.replace("\n", " ")
        prompt += query + "\n"

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                temperature=1,
                max_tokens=100
            )
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        # The model is expected to reply with JSON; the generated name is read
        # from its 'Your Output' field.
        result_dict = json.loads(answer)
        answer = result_dict['Your Output']

        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=max_tokens
            )
        )

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
        max_context_token_length = model_instance.model_rules.max_tokens.max
        max_context_token_length = max_context_token_length if max_context_token_length else 1500

        # Token budget left for the conversation context after reserving room
        # for the prompt template and the completion itself.
        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

        context = ''
        for message in messages:
            if not message.answer:
                continue

            if len(message.query) > 2000:
                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
            else:
                query = message.query

            if len(message.answer) > 2000:
                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
            else:
                answer = message.answer

            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer

            # Only append the exchange if the accumulated context still fits
            # within the remaining token budget.
            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)
        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        try:
            model_instance = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id,
                model_kwargs=ModelKwargs(
                    max_tokens=256,
                    temperature=0
                )
            )
        except ProviderTokenNotInitError:
            return []

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            questions = output_parser.parse(output.content)
        except LLMError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=512,
                temperature=0
            )
        )

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            rule_config = output_parser.parse(output.content)
        except LLMError as e:
            raise e
        except OutputParserException:
            raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config

    @classmethod
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=2000
            )
        )

        prompts = [
            PromptMessage(content=prompt, type=MessageType.SYSTEM),
            PromptMessage(content=query)
        ]

        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()
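

# Minimal usage sketch, assuming this runs inside an application context where
# the tenant's default text-generation model is already configured; the tenant
# id and the sample texts below are placeholders, not real data.
if __name__ == "__main__":
    name = LLMGenerator.generate_conversation_name(
        tenant_id="your-tenant-id",
        query="How do I reset my password?",
        answer="You can reset it from the account settings page.",
    )
    print(name)

    questions = LLMGenerator.generate_suggested_questions_after_answer(
        tenant_id="your-tenant-id",
        histories="Human: How do I reset my password?\n"
                  "Assistant: You can reset it from the account settings page.",
    )
    print(questions)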