llm_generator.py

import json
import logging

from langchain.schema import OutputParserException

from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
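
# LLMGenerator bundles small utility generations against the tenant's default
# LLM (resolved through ModelManager): conversation titles, suggested
# follow-up questions, rule configs, and QA documents.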
class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query):
        prompt = CONVERSATION_TITLE_PROMPT

        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        query = query.replace("\n", " ")
        prompt += query + "\n"

        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )

        prompts = [UserPromptMessage(content=prompt)]
        response = model_instance.invoke_llm(
            prompt_messages=prompts,
            model_parameters={
                "max_tokens": 100,
                "temperature": 1
            },
            stream=False
        )

        answer = response.message.content
        result_dict = json.loads(answer)
        answer = result_dict['Your Output']
        name = answer.strip()

        if len(name) > 75:
            name = name[:75] + '...'

        return name
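
    # Suggest follow-up questions from the conversation history. If the default
    # model cannot be resolved or invoked, degrade gracefully to an empty list.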
    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt_template = PromptTemplateParser(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n"
        )

        prompt = prompt_template.format({
            "histories": histories,
            "format_instructions": format_instructions
        })

        try:
            model_manager = ModelManager()
            model_instance = model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
            )
        except InvokeAuthorizationError:
            return []

        prompt_messages = [UserPromptMessage(content=prompt)]

        try:
            response = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters={
                    "max_tokens": 256,
                    "temperature": 0
                },
                stream=False
            )

            questions = output_parser.parse(response.message.content)
        except InvokeError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions
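
    # Generate a draft rule config (prompt, variables, opening statement) from
    # the target audience and the problem to solve. Literal template variables
    # are passed through unexpanded (remove_template_variables=False) so they
    # survive into the generated prompt.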
    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt_template = PromptTemplateParser(
            template=output_parser.get_format_instructions()
        )

        prompt = prompt_template.format(
            inputs={
                "audiences": audiences,
                "hoping_to_solve": hoping_to_solve,
                "variable": "{{variable}}",
                "lanA": "{{lanA}}",
                "lanB": "{{lanB}}",
                "topic": "{{topic}}"
            },
            remove_template_variables=False
        )

        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )

        prompt_messages = [UserPromptMessage(content=prompt)]

        try:
            response = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters={
                    "max_tokens": 512,
                    "temperature": 0
                },
                stream=False
            )

            rule_config = output_parser.parse(response.message.content)
        except InvokeError as e:
            raise e
        except OutputParserException:
            raise ValueError('Please provide a valid intended audience and the problem you hope to solve.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
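
    # Generate question/answer content from a source text in the requested
    # document language, using GENERATOR_QA_PROMPT as the system prompt.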
    @classmethod
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
        )

        prompt_messages = [
            SystemPromptMessage(content=prompt),
            UserPromptMessage(content=query)
        ]

        response = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters={
                "max_tokens": 2000
            },
            stream=False
        )

        answer = response.message.content
        return answer.strip()
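
# Example usage (illustrative sketch; assumes the tenant has a default LLM
# configured, and "tenant-123" stands in for a real tenant id):
#
#   name = LLMGenerator.generate_conversation_name(
#       tenant_id="tenant-123",
#       query="How do I rotate my API keys?",
#   )
#   questions = LLMGenerator.generate_suggested_questions_after_answer(
#       tenant_id="tenant-123",
#       histories="Human: How do I rotate my API keys?\nAssistant: ...",
#   )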