rerank.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from typing import List, Optional
  2. from langchain.schema import Document
  3. from core.model_manager import ModelInstance
  class RerankRunner:
      """Rerank a list of documents against a query using a rerank model instance."""

      def __init__(self, rerank_model_instance: ModelInstance) -> None:
          """
          :param rerank_model_instance: model instance whose ``invoke_rerank``
              method is called by :meth:`run`
          """
          self.rerank_model_instance = rerank_model_instance

      def run(self, query: str, documents: List[Document], score_threshold: Optional[float] = None,
              top_n: Optional[int] = None, user: Optional[str] = None) -> List[Document]:
          """
          Run rerank model

          Documents are first de-duplicated by ``metadata['doc_id']`` (the first
          occurrence of each id is kept, in input order). The page contents of the
          remaining documents are sent to the rerank model, and every returned
          result is mapped back to its source document through ``result.index``,
          which is used as a position in the de-duplicated list.

          :param query: search query
          :param documents: documents for reranking
          :param score_threshold: score threshold, forwarded unchanged to the model
          :param top_n: top n, forwarded unchanged to the model
          :param user: unique user id if needed
          :return: new documents, in the order the model returned them, whose
              ``page_content`` is the result text and whose metadata holds only
              ``doc_id``, ``doc_hash``, ``document_id``, ``dataset_id`` (copied
              from the source document) and ``score`` (from the result)
          :raises KeyError: if a document's metadata lacks ``doc_id``, or a
              reranked document's metadata lacks one of the copied keys
          """
          # Track seen ids in a set so each membership test is O(1) instead of a
          # linear scan of a list. NOTE(review): this requires doc_id values to be
          # hashable -- they are presumably string ids; confirm against callers.
          seen_doc_ids = set()
          unique_documents = []
          for document in documents:
              doc_id = document.metadata['doc_id']
              if doc_id in seen_doc_ids:
                  continue
              seen_doc_ids.add(doc_id)
              unique_documents.append(document)

          # The texts are sent in the same order as unique_documents, so the
          # indices in the result can be used to look up the source document.
          rerank_result = self.rerank_model_instance.invoke_rerank(
              query=query,
              docs=[document.page_content for document in unique_documents],
              score_threshold=score_threshold,
              top_n=top_n,
              user=user
          )

          rerank_documents = []
          for result in rerank_result.docs:
              # Look the source metadata up once per result rather than per key.
              source_metadata = unique_documents[result.index].metadata
              # format document
              rerank_document = Document(
                  page_content=result.text,
                  metadata={
                      "doc_id": source_metadata['doc_id'],
                      "doc_hash": source_metadata['doc_hash'],
                      "document_id": source_metadata['document_id'],
                      "dataset_id": source_metadata['dataset_id'],
                      'score': result.score
                  }
              )
              rerank_documents.append(rerank_document)

          return rerank_documents