retrieval_service.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from typing import Optional
  2. from flask import Flask, current_app
  3. from langchain.embeddings.base import Embeddings
  4. from core.index.vector_index.vector_index import VectorIndex
  5. from core.model_manager import ModelManager
  6. from core.model_runtime.entities.model_entities import ModelType
  7. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  8. from core.rerank.rerank import RerankRunner
  9. from extensions.ext_database import db
  10. from models.dataset import Dataset
  11. default_retrieval_model = {
  12. 'search_method': 'semantic_search',
  13. 'reranking_enable': False,
  14. 'reranking_model': {
  15. 'reranking_provider_name': '',
  16. 'reranking_model_name': ''
  17. },
  18. 'top_k': 2,
  19. 'score_threshold_enabled': False
  20. }
  21. class RetrievalService:
  22. @classmethod
  23. def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
  24. top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
  25. all_documents: list, search_method: str, embeddings: Embeddings):
  26. with flask_app.app_context():
  27. dataset = db.session.query(Dataset).filter(
  28. Dataset.id == dataset_id
  29. ).first()
  30. vector_index = VectorIndex(
  31. dataset=dataset,
  32. config=current_app.config,
  33. embeddings=embeddings
  34. )
  35. documents = vector_index.search(
  36. query,
  37. search_type='similarity_score_threshold',
  38. search_kwargs={
  39. 'k': top_k,
  40. 'score_threshold': score_threshold,
  41. 'filter': {
  42. 'group_id': [dataset.id]
  43. }
  44. }
  45. )
  46. if documents:
  47. if reranking_model and search_method == 'semantic_search':
  48. try:
  49. model_manager = ModelManager()
  50. rerank_model_instance = model_manager.get_model_instance(
  51. tenant_id=dataset.tenant_id,
  52. provider=reranking_model['reranking_provider_name'],
  53. model_type=ModelType.RERANK,
  54. model=reranking_model['reranking_model_name']
  55. )
  56. except InvokeAuthorizationError:
  57. return
  58. rerank_runner = RerankRunner(rerank_model_instance)
  59. all_documents.extend(rerank_runner.run(
  60. query=query,
  61. documents=documents,
  62. score_threshold=score_threshold,
  63. top_n=len(documents)
  64. ))
  65. else:
  66. all_documents.extend(documents)
  67. @classmethod
  68. def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
  69. top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
  70. all_documents: list, search_method: str, embeddings: Embeddings):
  71. with flask_app.app_context():
  72. dataset = db.session.query(Dataset).filter(
  73. Dataset.id == dataset_id
  74. ).first()
  75. vector_index = VectorIndex(
  76. dataset=dataset,
  77. config=current_app.config,
  78. embeddings=embeddings
  79. )
  80. documents = vector_index.search_by_full_text_index(
  81. query,
  82. search_type='similarity_score_threshold',
  83. top_k=top_k
  84. )
  85. if documents:
  86. if reranking_model and search_method == 'full_text_search':
  87. try:
  88. model_manager = ModelManager()
  89. rerank_model_instance = model_manager.get_model_instance(
  90. tenant_id=dataset.tenant_id,
  91. provider=reranking_model['reranking_provider_name'],
  92. model_type=ModelType.RERANK,
  93. model=reranking_model['reranking_model_name']
  94. )
  95. except InvokeAuthorizationError:
  96. return
  97. rerank_runner = RerankRunner(rerank_model_instance)
  98. all_documents.extend(rerank_runner.run(
  99. query=query,
  100. documents=documents,
  101. score_threshold=score_threshold,
  102. top_n=len(documents)
  103. ))
  104. else:
  105. all_documents.extend(documents)