# query.py — multi-query RAG answering over a local vector DB via Ollama.
  1. import os
  2. from langchain_community.chat_models import ChatOllama
  3. from langchain.prompts import ChatPromptTemplate, PromptTemplate
  4. from langchain_core.output_parsers import StrOutputParser
  5. from langchain_core.runnables import RunnablePassthrough
  6. from langchain.retrievers.multi_query import MultiQueryRetriever
  7. from get_vector_db import get_vector_db
  8. LLM_MODEL = os.getenv('LLM_MODEL', 'qwen2:7b')
  9. # Function to get the prompt templates for generating alternative questions and answering based on context
  10. def get_prompt():
  11. QUERY_PROMPT = PromptTemplate(
  12. input_variables=["question"],
  13. template="""You are an AI language model assistant. Your task is to generate five
  14. different versions of the given user question to retrieve relevant documents from
  15. a vector database. By generating multiple perspectives on the user question, your
  16. goal is to help the user overcome some of the limitations of the distance-based
  17. similarity search. Provide these alternative questions separated by newlines.
  18. Original question: {question}""",
  19. )
  20. template = """Answer the question in Chinese based ONLY on the following context:
  21. {context}
  22. Question: {question}
  23. """
  24. prompt = ChatPromptTemplate.from_template(template)
  25. return QUERY_PROMPT, prompt
  26. # Main function to handle the query process
  27. def query(input):
  28. if input:
  29. # Initialize the language model with the specified model name
  30. llm = ChatOllama(model=LLM_MODEL)
  31. # Get the vector database instance
  32. db = get_vector_db()
  33. # Get the prompt templates
  34. QUERY_PROMPT, prompt = get_prompt()
  35. # Set up the retriever to generate multiple queries using the language model and the query prompt
  36. retriever = MultiQueryRetriever.from_llm(
  37. db.as_retriever(),
  38. llm,
  39. prompt=QUERY_PROMPT
  40. )
  41. # Define the processing chain to retrieve context, generate the answer, and parse the output
  42. chain = (
  43. {"context": retriever, "question": RunnablePassthrough()}
  44. | prompt
  45. | llm
  46. | StrOutputParser()
  47. )
  48. response = chain.invoke(input)
  49. return response
  50. return None