dataset_retriever_tool.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import re
  2. from typing import Type
  3. from flask import current_app
  4. from langchain.embeddings import OpenAIEmbeddings
  5. from langchain.tools import BaseTool
  6. from pydantic import Field, BaseModel
  7. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  8. from core.embedding.cached_embedding import CacheEmbedding
  9. from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
  10. from core.index.vector_index.vector_index import VectorIndex
  11. from core.llm.llm_builder import LLMBuilder
  12. from extensions.ext_database import db
  13. from models.dataset import Dataset
  # Argument schema for DatasetRetrieverTool: the two required string
  # arguments the tool's _run method receives. Documented with comments
  # rather than a class docstring so the model's generated schema is unchanged.
  class DatasetRetrieverToolInput(BaseModel):
      # Identifier of the dataset to search; required (no default value).
      dataset_id: str = Field(
          ...,
          description="ID of dataset to be queried. MUST be UUID format.",
      )
      # Text to search the dataset for; required (no default value).
      query: str = Field(
          ...,
          description="Query for the dataset to be used to retrieve the dataset.",
      )
  class DatasetRetrieverTool(BaseTool):
      """Tool for querying a Dataset."""

      # Tool name; also embedded in the failure message returned by _run.
      name: str = "dataset"
      # Schema describing the (dataset_id, query) arguments passed to _run.
      args_schema: Type[BaseModel] = DatasetRetrieverToolInput
      # Fallback description; from_dataset() supplies a dataset-specific one.
      description: str = "use this to retrieve a dataset. "

      # Tenant that the dataset lookup in _run is restricted to.
      tenant_id: str
      # ID of the dataset this instance was built for. Not read by _run,
      # which works from its own dataset_id argument.
      dataset_id: str
      # Number of results requested from the index search.
      k: int = 3

      @classmethod
      def from_dataset(cls, dataset: Dataset, **kwargs):
          """Build a tool instance for ``dataset``.

          The tool description is the dataset's own description, or a generic
          sentence built from the dataset name when the description is empty,
          with line breaks removed and a final line stating the dataset ID.

          Args:
              dataset: Dataset supplying ``tenant_id``, ``id``, ``name`` and
                  ``description``.
              **kwargs: Extra field values forwarded to the constructor.

          Returns:
              A new instance of ``cls``.
          """
          description = dataset.description
          if not description:
              description = 'useful for when you want to answer queries about the ' + dataset.name

          # Remove line breaks so the only newline left is the one that
          # introduces the ID line appended below.
          description = description.replace('\n', '').replace('\r', '')
          description += '\nID of dataset MUST be ' + dataset.id

          return cls(
              tenant_id=dataset.tenant_id,
              dataset_id=dataset.id,
              description=description,
              **kwargs
          )

      def _run(self, dataset_id: str, query: str) -> str:
          """Search the dataset named by ``dataset_id`` for ``query``.

          Args:
              dataset_id: Text containing the dataset UUID. The first UUID
                  found in it (case-insensitive) is used, so surrounding text
                  is tolerated.
              query: Text to search the dataset for.

          Returns:
              The ``page_content`` of the retrieved documents joined with
              newlines, or a bracketed failure message when ``dataset_id``
              contains no UUID or no dataset with that ID exists for this
              tool's tenant.
          """
          pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
          match = re.search(pattern, dataset_id, re.IGNORECASE)

          # The input schema requires a UUID, so text without one cannot name
          # a dataset. Only query the database when a UUID was extracted,
          # instead of sending the raw argument text to it.
          dataset = None
          if match:
              dataset_id = match.group()
              dataset = db.session.query(Dataset).filter(
                  Dataset.tenant_id == self.tenant_id,
                  Dataset.id == dataset_id
              ).first()

          if not dataset:
              return f'[{self.name} failed to find dataset with id {dataset_id}.]'

          if dataset.indexing_technique == "economy":
              # use keyword table query
              kw_table_index = KeywordTableIndex(
                  dataset=dataset,
                  config=KeywordTableConfig(
                      max_keywords_per_chunk=5
                  )
              )

              documents = kw_table_index.search(query, search_kwargs={'k': self.k})
          else:
              # Any other indexing technique goes through the vector index,
              # using embeddings built from the tenant's credentials for the
              # 'text-embedding-ada-002' model.
              model_credentials = LLMBuilder.get_model_credentials(
                  tenant_id=dataset.tenant_id,
                  model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
                  model_name='text-embedding-ada-002'
              )

              embeddings = CacheEmbedding(OpenAIEmbeddings(
                  **model_credentials
              ))

              vector_index = VectorIndex(
                  dataset=dataset,
                  config=current_app.config,
                  embeddings=embeddings
              )

              if self.k > 0:
                  documents = vector_index.search(
                      query,
                      search_type='similarity',
                      search_kwargs={
                          'k': self.k
                      }
                  )
              else:
                  # A non-positive k skips the similarity search entirely.
                  documents = []

              # Pass the retrieved documents to the index callback handler
              # (vector branch only).
              hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
              hit_callback.on_tool_end(documents)

          return "\n".join(document.page_content for document in documents)

      async def _arun(self, *args, **kwargs) -> str:
          """Async execution is not supported.

          Accepts any arguments so that a call made with the ``dataset_id`` /
          ``query`` arguments declared by ``args_schema`` raises
          ``NotImplementedError`` rather than a ``TypeError`` about
          unexpected arguments.

          Raises:
              NotImplementedError: Always.
          """
          raise NotImplementedError()