# query.py — multi-query RAG: expand a user question into alternatives,
# retrieve context from the vector DB, and answer via a local Ollama model.
  1. import os
  2. from langchain_community.chat_models import ChatOllama
  3. from langchain.prompts import ChatPromptTemplate, PromptTemplate
  4. from langchain_core.output_parsers import StrOutputParser
  5. from langchain_core.runnables import RunnablePassthrough
  6. from langchain.retrievers.multi_query import MultiQueryRetriever
  7. from llm_model.get_vector_db import get_vector_db
  8. LLM_MODEL = os.getenv('LLM_MODEL', 'qwen2:7b')
  9. # Function to get the prompt templates for generating alternative questions and answering based on context
  10. def get_prompt():
  11. QUERY_PROMPT = PromptTemplate(
  12. input_variables=["question"],
  13. template="""你是一名AI语言模型助理。你的任务是生成三个
  14. 从中检索相关文档的给定用户问题的不同版本
  15. 矢量数据库。通过对用户问题生成多个视角
  16. 目标是帮助用户克服基于距离的一些局限性
  17. 相似性搜索。请提供这些用换行符分隔的备选问题。
  18. Original question: {question}""",
  19. )
  20. template = """仅根据以下上下文用中文回答问题:
  21. {context},请严格以markdown格式输出并保障寄送格式正确无误,
  22. Question: {question}
  23. """
  24. # Question: {question}
  25. prompt = ChatPromptTemplate.from_template(template)
  26. return QUERY_PROMPT, prompt
  27. # Main function to handle the query process
  28. def query(input):
  29. if input:
  30. # Initialize the language model with the specified model name
  31. llm = ChatOllama(model=LLM_MODEL,keep_alive=-1,num_gpu=0)
  32. # Get the vector database instance
  33. db = get_vector_db()
  34. # Get the prompt templates
  35. QUERY_PROMPT, prompt = get_prompt()
  36. # Set up the retriever to generate multiple queries using the language model and the query prompt
  37. retriever = MultiQueryRetriever.from_llm(
  38. db.as_retriever(),
  39. llm,
  40. prompt=QUERY_PROMPT
  41. )
  42. # Define the processing chain to retrieve context, generate the answer, and parse the output
  43. chain = (
  44. {"context": retriever, "question": RunnablePassthrough()}
  45. | prompt
  46. | llm
  47. | StrOutputParser()
  48. )
  49. response = chain.invoke(input)
  50. return response
  51. return None