```python
import os

from langchain_community.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever

from llm_model.get_vector_db import get_vector_db

LLM_MODEL = os.getenv('LLM_MODEL', 'qwen2:7b')

# Build the two prompt templates: one that asks the model to generate
# alternative phrasings of the user question, and one that answers the
# question from the retrieved context.
def get_prompt():
    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to
generate three different versions of the given user question to retrieve
relevant documents from a vector database. By generating multiple
perspectives on the user question, your goal is to help the user overcome
some of the limitations of distance-based similarity search. Provide these
alternative questions separated by newlines.
Original question: {question}""",
    )

    template = """Answer the question in Chinese based only on the following context:
{context}
Output strictly in Markdown and make sure the formatting is correct.
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    return QUERY_PROMPT, prompt

# Main entry point: answer a user question with retrieval-augmented
# generation over the local vector database.
def query(user_input):
    if user_input:
        # Initialize the language model; keep_alive=-1 keeps the model
        # resident in memory between calls, num_gpu=0 forces CPU inference.
        llm = ChatOllama(model=LLM_MODEL, keep_alive=-1, num_gpu=0)
        # Get the vector database instance
        db = get_vector_db()
        # Get the prompt templates
        QUERY_PROMPT, prompt = get_prompt()

        # Set up the retriever: the LLM expands the question into several
        # variants, and documents matching any variant are retrieved.
        retriever = MultiQueryRetriever.from_llm(
            db.as_retriever(),
            llm,
            prompt=QUERY_PROMPT,
        )

        # Processing chain: retrieve context, fill the answer prompt,
        # call the model, and parse the output into a plain string.
        chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )

        response = chain.invoke(user_input)
        return response

    return None
```
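
The listing imports `get_vector_db` from a local module that is not shown above. For reference, here is a minimal sketch of what that helper might look like, assuming the index was built with Chroma and Ollama embeddings; the embedding model, collection name, and persist directory below are illustrative placeholders, not the author's actual configuration:

```python
# Hypothetical get_vector_db sketch, assuming a Chroma index built with
# Ollama embeddings. The EMBED_MODEL default, collection name, and
# persist directory are placeholders for illustration only.
import os

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma


def get_vector_db():
    embeddings = OllamaEmbeddings(
        model=os.getenv('EMBED_MODEL', 'nomic-embed-text')  # assumed embedding model
    )
    return Chroma(
        collection_name='local-rag',      # placeholder collection name
        persist_directory='./chroma_db',  # placeholder storage path
        embedding_function=embeddings,
    )


# Example call (assuming the query module is importable as llm_model.query):
#   from llm_model.query import query
#   print(query("What topics does the knowledge base cover?"))
```

Note that `query` returns `None` for empty input, so callers should handle that case explicitly.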