import json
import logging

from langchain.schema import OutputParserException

from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
    GENERATOR_QA_PROMPT


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT

        # Keep only the head and tail of very long queries to bound prompt size.
        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        query = query.replace("\n", " ")

        prompt += query + "\n"

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                temperature=1,
                max_tokens=100
            )
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        # The title prompt asks the model to reply as JSON with a 'Your Output' key.
        result_dict = json.loads(answer)
        answer = result_dict['Your Output']

        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=max_tokens
            )
        )

        # Budget the context window: reserve room for the prompt scaffolding and
        # the completion, then fill the remainder with as much Q&A history as fits.
        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
        max_context_token_length = model_instance.model_rules.max_tokens.max
        max_context_token_length = max_context_token_length if max_context_token_length else 1500
        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

        context = ''
        for message in messages:
            if not message.answer:
                continue

            if len(message.query) > 2000:
                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
            else:
                query = message.query

            if len(message.answer) > 2000:
                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
            else:
                answer = message.answer

            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)
        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        # Suggested questions are optional, so fail soft if no provider
        # credentials are configured for this tenant.
        try:
            model_instance = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id,
                model_kwargs=ModelKwargs(
                    max_tokens=256,
                    temperature=0
                )
            )
        except ProviderTokenNotInitError:
            return []

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            questions = output_parser.parse(output.content)
        except LLMError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            # Re-inject literal {placeholders} from the format instructions so
            # they are not treated as template input variables.
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=512,
                temperature=0
            )
        )

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            rule_config = output_parser.parse(output.content)
        except LLMError as e:
            raise e
        except OutputParserException:
            raise ValueError('Please provide valid input for the intended audience or the problem you hope to solve.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config

    @classmethod
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=2000
            )
        )

        # The system message carries the QA-generation instructions; the
        # document text itself is sent as the user message.
        prompts = [
            PromptMessage(content=prompt, type=MessageType.SYSTEM),
            PromptMessage(content=query)
        ]

        response = model_instance.run(prompts)
        answer = response.content
        return answer.strip()
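
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how callers might drive these class methods; the
# tenant id and strings below are hypothetical placeholders, and in the real
# application they would come from the request context and persisted records.
#
#     name = LLMGenerator.generate_conversation_name(
#         tenant_id="tenant-123",  # hypothetical id
#         query="How do I reset my password?",
#         answer="Use the 'Forgot password' link on the login page.",
#     )
#
#     questions = LLMGenerator.generate_suggested_questions_after_answer(
#         tenant_id="tenant-123",
#         histories="Human: How do I reset my password?\nAssistant: ...",
#     )
#
# Note the differing failure modes: suggested questions degrade to an empty
# list on provider or LLM errors, while generate_rule_config re-raises
# LLMError and converts parse failures into a ValueError for the caller.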