import json
import threading
from typing import Type, Optional, List

from flask import current_app
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document
from services.retrieval_service import RetrievalService

# Retrieval settings used when a dataset has no retrieval_model configured.
# NOTE: 'top_k' is not read by DatasetRetrieverTool; the tool's own top_k field
# controls how many results are requested.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class DatasetRetrieverToolInput(BaseModel):
    """Arguments accepted by DatasetRetrieverTool."""
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""
    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    top_k: int = 2
    # NOTE(review): not read anywhere in this class; thresholds come from the
    # dataset's retrieval model instead.
    score_threshold: Optional[float] = None
    conversation_message_task: ConversationMessageTask
    return_resource: bool
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        """Create a tool bound to ``dataset``.

        The tool is named ``dataset-<id>`` and described by the dataset's own
        description (or a generated one when empty), with line breaks removed.

        :param dataset: the dataset to query.
        :param kwargs: remaining tool fields (e.g. ``conversation_message_task``).
        """
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        # keep the description on a single line
        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        """Retrieve dataset content relevant to ``query``.

        Economy datasets are searched through the keyword table and return the
        raw page contents directly. Other datasets are searched by semantic
        and/or full-text search according to the dataset's retrieval model.
        The hits are reported to the hit callback and resolved to completed,
        enabled segments. When ``return_resource`` is set, source metadata for
        those segments is also reported.

        :param query: the search query.
        :return: newline-joined context text; '' when the dataset does not
                 exist, its embedding model cannot be loaded, or no usable
                 segment was found.
        """
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return ''

        # fall back to the default retrieval settings when none are configured
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

        if dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
            return "\n".join(document.page_content for document in documents)

        try:
            embedding_model = ModelFactory.get_embedding_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=dataset.embedding_model_provider,
                model_name=dataset.embedding_model
            )
        except (LLMBadRequestError, ProviderTokenNotInitError):
            # the embedding model is unavailable, so nothing can be retrieved
            return ''

        embeddings = CacheEmbedding(embedding_model)

        if self.top_k > 0:
            documents = self._search_documents(dataset, retrieval_model, query, embeddings)
        else:
            documents = []

        hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
        hit_callback.on_tool_end(documents)

        # index node id -> retrieval score, reported alongside each resource.
        # A score of 0.0 is a real score, so only missing/None scores are skipped.
        score_by_node_id = {}
        for item in documents:
            score = item.metadata.get('score')
            if score is not None:
                score_by_node_id[item.metadata['doc_id']] = score

        sorted_segments = self._get_sorted_segments(documents)
        if not sorted_segments:
            return ''

        document_context_list = [
            f'question:{segment.content} answer:{segment.answer}' if segment.answer else segment.content
            for segment in sorted_segments
        ]

        if self.return_resource:
            context_list = self._build_resources(dataset, sorted_segments, score_by_node_id)
            hit_callback.return_retriever_resource_info(context_list)

        return "\n".join(document_context_list)

    def _search_documents(self, dataset: Dataset, retrieval_model: dict, query: str,
                          embeddings: CacheEmbedding) -> list:
        """Run the configured search method(s) and return the retrieved documents.

        Semantic and full-text searches each run in their own thread and append
        their hits to one shared list. Hybrid search runs both searches, then
        reranks the combined list. An unrecognised search method yields [].
        """
        search_method = retrieval_model['search_method']
        # .get: stored retrieval models may omit the threshold even when the flag is set
        score_threshold = (retrieval_model.get('score_threshold')
                           if retrieval_model.get('score_threshold_enabled') else None)
        reranking_model = (retrieval_model.get('reranking_model')
                           if retrieval_model.get('reranking_enable') else None)

        # worker threads need the real app object, not the context-local proxy
        flask_app = current_app._get_current_object()
        documents = []
        threads = []

        # retrieval source with semantic
        if search_method in ('semantic_search', 'hybrid_search'):
            threads.append(threading.Thread(target=RetrievalService.embedding_search, kwargs={
                'flask_app': flask_app,
                'dataset_id': str(dataset.id),
                'query': query,
                'top_k': self.top_k,
                'score_threshold': score_threshold,
                'reranking_model': reranking_model,
                'all_documents': documents,
                'search_method': search_method,
                'embeddings': embeddings
            }))

        # retrieval source with full text
        if search_method in ('full_text_search', 'hybrid_search'):
            threads.append(threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                'flask_app': flask_app,
                'dataset_id': str(dataset.id),
                'query': query,
                'search_method': search_method,
                'embeddings': embeddings,
                'score_threshold': score_threshold,
                'top_k': self.top_k,
                'reranking_model': reranking_model,
                'all_documents': documents
            }))

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # hybrid search: rerank after all documents have been searched
        if search_method == 'hybrid_search':
            hybrid_rerank = ModelFactory.get_reranking_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
                model_name=retrieval_model['reranking_model']['reranking_model_name']
            )
            documents = hybrid_rerank.rerank(query, documents, score_threshold, self.top_k)

        return documents

    def _get_sorted_segments(self, documents: list) -> list:
        """Load the usable segments for ``documents``, in retrieval order.

        Only completed, enabled segments of this tool's dataset are returned.
        Segments are matched by the documents' ``doc_id`` metadata.
        """
        index_node_ids = [document.metadata['doc_id'] for document in documents]
        if not index_node_ids:
            # nothing retrieved: skip the pointless empty IN () query
            return []

        segments = DocumentSegment.query.filter(
            DocumentSegment.dataset_id == self.dataset_id,
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,  # noqa: E712 -- SQLAlchemy column comparison
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()

        position_by_node_id = {node_id: position for position, node_id in enumerate(index_node_ids)}
        return sorted(
            segments,
            key=lambda segment: position_by_node_id.get(segment.index_node_id, float('inf'))
        )

    def _build_resources(self, dataset: Dataset, sorted_segments: list,
                         score_by_node_id: dict) -> List[dict]:
        """Build source-attribution entries for the retrieved segments.

        Segments whose document is missing, disabled or archived are skipped.
        ``position`` is the segment's 1-based rank in ``sorted_segments``, so a
        skipped segment still consumes its position number.
        """
        context_list = []
        for position, segment in enumerate(sorted_segments, start=1):
            document = Document.query.filter(
                Document.id == segment.document_id,
                Document.enabled == True,  # noqa: E712 -- SQLAlchemy column comparison
                Document.archived == False,  # noqa: E712 -- SQLAlchemy column comparison
            ).first()
            if not document:
                continue

            source = {
                'position': position,
                'dataset_id': dataset.id,
                'dataset_name': dataset.name,
                'document_id': document.id,
                'document_name': document.name,
                'data_source_type': document.data_source_type,
                'segment_id': segment.id,
                'retriever_from': self.retriever_from,
                'score': score_by_node_id.get(segment.index_node_id, None)
            }

            # extra segment details are only exposed to the 'dev' retriever source
            if self.retriever_from == 'dev':
                source['hit_count'] = segment.hit_count
                source['word_count'] = segment.word_count
                source['segment_position'] = segment.position
                source['index_node_hash'] = segment.index_node_hash

            if segment.answer:
                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
            else:
                source['content'] = segment.content

            context_list.append(source)

        return context_list

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError()