dataset_retriever_tool.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import re
  2. from typing import Type
  3. from flask import current_app
  4. from langchain.tools import BaseTool
  5. from pydantic import Field, BaseModel
  6. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  7. from core.embedding.cached_embedding import CacheEmbedding
  8. from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
  9. from core.index.vector_index.vector_index import VectorIndex
  10. from core.model_providers.model_factory import ModelFactory
  11. from extensions.ext_database import db
  12. from models.dataset import Dataset, DocumentSegment
  13. class DatasetRetrieverToolInput(BaseModel):
  14. dataset_id: str = Field(..., description="ID of dataset to be queried. MUST be UUID format.")
  15. query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
  16. class DatasetRetrieverTool(BaseTool):
  17. """Tool for querying a Dataset."""
  18. name: str = "dataset"
  19. args_schema: Type[BaseModel] = DatasetRetrieverToolInput
  20. description: str = "use this to retrieve a dataset. "
  21. tenant_id: str
  22. dataset_id: str
  23. k: int = 3
  24. @classmethod
  25. def from_dataset(cls, dataset: Dataset, **kwargs):
  26. description = dataset.description
  27. if not description:
  28. description = 'useful for when you want to answer queries about the ' + dataset.name
  29. description = description.replace('\n', '').replace('\r', '')
  30. description += '\nID of dataset MUST be ' + dataset.id
  31. return cls(
  32. tenant_id=dataset.tenant_id,
  33. dataset_id=dataset.id,
  34. description=description,
  35. **kwargs
  36. )
  37. def _run(self, dataset_id: str, query: str) -> str:
  38. pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
  39. match = re.search(pattern, dataset_id, re.IGNORECASE)
  40. if match:
  41. dataset_id = match.group()
  42. dataset = db.session.query(Dataset).filter(
  43. Dataset.tenant_id == self.tenant_id,
  44. Dataset.id == dataset_id
  45. ).first()
  46. if not dataset:
  47. return f'[{self.name} failed to find dataset with id {dataset_id}.]'
  48. if dataset.indexing_technique == "economy":
  49. # use keyword table query
  50. kw_table_index = KeywordTableIndex(
  51. dataset=dataset,
  52. config=KeywordTableConfig(
  53. max_keywords_per_chunk=5
  54. )
  55. )
  56. documents = kw_table_index.search(query, search_kwargs={'k': self.k})
  57. return str("\n".join([document.page_content for document in documents]))
  58. else:
  59. embedding_model = ModelFactory.get_embedding_model(
  60. tenant_id=dataset.tenant_id
  61. )
  62. embeddings = CacheEmbedding(embedding_model)
  63. vector_index = VectorIndex(
  64. dataset=dataset,
  65. config=current_app.config,
  66. embeddings=embeddings
  67. )
  68. if self.k > 0:
  69. documents = vector_index.search(
  70. query,
  71. search_type='similarity',
  72. search_kwargs={
  73. 'k': self.k
  74. }
  75. )
  76. else:
  77. documents = []
  78. hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
  79. hit_callback.on_tool_end(documents)
  80. document_context_list = []
  81. index_node_ids = [document.metadata['doc_id'] for document in documents]
  82. segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
  83. DocumentSegment.status == 'completed',
  84. DocumentSegment.enabled == True,
  85. DocumentSegment.index_node_id.in_(index_node_ids)
  86. ).all()
  87. if segments:
  88. index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
  89. sorted_segments = sorted(segments,
  90. key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
  91. float('inf')))
  92. for segment in sorted_segments:
  93. if segment.answer:
  94. document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
  95. else:
  96. document_context_list.append(segment.content)
  97. return str("\n".join(document_context_list))
  98. async def _arun(self, tool_input: str) -> str:
  99. raise NotImplementedError()