# llm_generator.py

import json
import logging

from langchain.schema import OutputParserException

from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT


class LLMGenerator:
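    """Helpers that run small one-off prompts against a tenant's default
    text-generation model: naming conversations, suggesting follow-up
    questions, generating rule configs, and producing Q&A documents."""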
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        # 'answer' is currently unused; the variable is overwritten with the
        # model response below.
        prompt = CONVERSATION_TITLE_PROMPT

        # Cap very long queries: keep the head and tail, drop the middle.
        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        query = query.replace("\n", " ")
        prompt += query + "\n"

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                temperature=1,
                max_tokens=100
            )
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        # The title prompt instructs the model to reply with JSON containing
        # a 'Your Output' key.
        result_dict = json.loads(answer)
        answer = result_dict['Your Output']
        return answer.strip()
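
    # Minimal usage sketch (assumes the tenant's model provider is already
    # configured; "tenant-123" is a placeholder):
    #   name = LLMGenerator.generate_conversation_name("tenant-123", query, answer)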

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt_template = PromptTemplateParser(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n"
        )

        prompt = prompt_template.format({
            "histories": histories,
            "format_instructions": format_instructions
        })

        try:
            model_instance = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id,
                model_kwargs=ModelKwargs(
                    max_tokens=256,
                    temperature=0
                )
            )
        except ProviderTokenNotInitError:
            # Suggestions are optional: if no provider token is configured
            # for this tenant, fail quietly with an empty list.
            return []

        prompt_messages = [PromptMessage(content=prompt)]

        try:
            output = model_instance.run(prompt_messages)
            questions = output_parser.parse(output.content)
        except LLMError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions
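
    # Minimal sketch (placeholder tenant id and history; returns [] when no
    # provider token is configured):
    #   questions = LLMGenerator.generate_suggested_questions_after_answer(
    #       tenant_id="tenant-123", histories="Human: ...\nAssistant: ...")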

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt_template = PromptTemplateParser(
            template=output_parser.get_format_instructions()
        )

        # Keep the literal {{variable}}/{{lanA}}/{{lanB}}/{{topic}}
        # placeholders in the rendered prompt so the model can echo them back.
        prompt = prompt_template.format(
            inputs={
                "audiences": audiences,
                "hoping_to_solve": hoping_to_solve,
                "variable": "{{variable}}",
                "lanA": "{{lanA}}",
                "lanB": "{{lanB}}",
                "topic": "{{topic}}"
            },
            remove_template_variables=False
        )

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=512,
                temperature=0
            )
        )

        prompt_messages = [PromptMessage(content=prompt)]

        try:
            output = model_instance.run(prompt_messages)
            rule_config = output_parser.parse(output.content)
        except LLMError as e:
            raise e
        except OutputParserException:
            raise ValueError('Please provide a valid intended audience and the problem you hope to solve.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
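
    # Sketch (hypothetical inputs): a parse failure surfaces as ValueError;
    # other unexpected errors fall back to an empty rule config:
    #   rule_config = LLMGenerator.generate_rule_config(
    #       tenant_id="tenant-123",
    #       audiences="indie game developers",
    #       hoping_to_solve="writing patch notes")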

    @classmethod
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=2000
            )
        )

        # The Q&A prompt goes in as the system message; the source text is
        # the user turn.
        prompts = [
            PromptMessage(content=prompt, type=MessageType.SYSTEM),
            PromptMessage(content=query)
        ]

        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()
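

# A minimal end-to-end sketch, assuming a tenant whose default
# text-generation model is already configured; "tenant-123" and the texts
# below are placeholders.
if __name__ == "__main__":
    name = LLMGenerator.generate_conversation_name(
        tenant_id="tenant-123",
        query="How do I rotate my API keys?",
        answer="Open Settings > API Keys, then click Rotate."
    )
    print(name)

    qa_document = LLMGenerator.generate_qa_document(
        tenant_id="tenant-123",
        query="Long source text to turn into Q&A pairs...",
        document_language="English"
    )
    print(qa_document)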