
Fix/agent external knowledge retrieval (#9241)

Jyong · 6 months ago
commit 42b02b3a5f

+ 17 - 0
api/configs/middleware/__init__.py

@@ -191,6 +191,22 @@ class CeleryConfig(DatabaseConfig):
         return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
 
 
+class InternalTestConfig(BaseSettings):
+    """
+    Configuration settings for Internal Test
+    """
+
+    AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
+        description="Internal test AWS secret access key",
+        default=None,
+    )
+
+    AWS_ACCESS_KEY_ID: Optional[str] = Field(
+        description="Internal test AWS access key ID",
+        default=None,
+    )
+
+
 class MiddlewareConfig(
     # place the configs in alphabet order
     CeleryConfig,
@@ -224,5 +240,6 @@ class MiddlewareConfig(
     TiDBVectorConfig,
     WeaviateConfig,
     ElasticsearchConfig,
+    InternalTestConfig,
 ):
     pass
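
The two new fields are mixed into MiddlewareConfig, so once the matching environment variables are set the internal-test credentials become readable from the global dify_config object. A minimal standalone sketch of the pattern, assuming pydantic-settings resolves each field from an environment variable with the same name:

```python
# Illustration only: mirrors the new InternalTestConfig fields and shows how
# pydantic-settings falls back to None when the environment variable is unset.
import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class InternalTestConfig(BaseSettings):
    AWS_SECRET_ACCESS_KEY: Optional[str] = Field(default=None)
    AWS_ACCESS_KEY_ID: Optional[str] = Field(default=None)


os.environ["AWS_ACCESS_KEY_ID"] = "AKIA-EXAMPLE"   # hypothetical test value
print(InternalTestConfig().AWS_ACCESS_KEY_ID)      # -> AKIA-EXAMPLE
print(InternalTestConfig().AWS_SECRET_ACCESS_KEY)  # -> None
```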

+ 24 - 0
api/controllers/console/datasets/external.py

@@ -13,6 +13,7 @@ from libs.login import login_required
 from services.dataset_service import DatasetService
 from services.external_knowledge_service import ExternalDatasetService
 from services.hit_testing_service import HitTestingService
+from services.knowledge_service import ExternalDatasetTestService
 
 
 def _validate_name(name):
@@ -232,8 +233,31 @@ class ExternalKnowledgeHitTestingApi(Resource):
             raise InternalServerError(str(e))
 
 
+class BedrockRetrievalApi(Resource):
+    # this api is only for internal testing
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
+        parser.add_argument(
+            "query",
+            nullable=False,
+            required=True,
+            type=str,
+        )
+        parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
+        args = parser.parse_args()
+
+        # Call the knowledge retrieval service
+        result = ExternalDatasetTestService.knowledge_retrieval(
+            args["retrieval_setting"], args["query"], args["knowledge_id"]
+        )
+        return result, 200
+
+
 api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
 api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
 api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
 api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
 api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
+# this api is only for internal test
+api.add_resource(BedrockRetrievalApi, "/test/retrieval")
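
The new resource forwards retrieval_setting, query and knowledge_id straight to ExternalDatasetTestService and returns its result. A hedged request sketch; the host, route prefix and knowledge base ID are placeholders, and the call only succeeds when the internal-test AWS credentials above are configured:

```python
# Hypothetical call against a local console API; URL prefix and auth handling
# are assumptions, and EXAMPLEKBID is a placeholder Bedrock knowledge base id.
import requests

payload = {
    "retrieval_setting": {"top_k": 5, "score_threshold": 0.5},
    "query": "What is Dify?",
    "knowledge_id": "EXAMPLEKBID",
}
resp = requests.post(
    "http://localhost:5001/console/api/test/retrieval",
    json=payload,
    timeout=30,
)
print(resp.status_code, resp.json())
```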

+ 1 - 1
api/core/rag/retrieval/dataset_retrieval.py

@@ -539,7 +539,7 @@ class DatasetRetrieval:
                 continue
 
             # pass if dataset is not available
-            if dataset and dataset.available_document_count == 0:
+            if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
                 continue
 
             available_datasets.append(dataset)
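
This is the core of the agent-side fix: external datasets keep their documents in the remote provider, so available_document_count is 0 for them and the old guard silently dropped every external dataset from the list of candidates. Scoping the emptiness check to non-external datasets lets them through while still skipping genuinely empty local datasets.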

+ 130 - 88
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -1,10 +1,12 @@
 from pydantic import BaseModel, Field
 
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document as RetrievalDocument
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
+from services.external_knowledge_service import ExternalDatasetService
 
 default_retrieval_model = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -53,97 +55,137 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
 
         for hit_callback in self.hit_callbacks:
             hit_callback.on_query(query, dataset.id)
-
-        # get retrieval model , if the model is not setting , using default
-        retrieval_model = dataset.retrieval_model or default_retrieval_model
-        if dataset.indexing_technique == "economy":
-            # use keyword table query
-            documents = RetrievalService.retrieve(
-                retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
+        if dataset.provider == "external":
+            results = []
+            external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
+                tenant_id=dataset.tenant_id,
+                dataset_id=dataset.id,
+                query=query,
+                external_retrieval_parameters=dataset.retrieval_model,
             )
-            return str("\n".join([document.page_content for document in documents]))
+            for external_document in external_documents:
+                document = RetrievalDocument(
+                    page_content=external_document.get("content"),
+                    metadata=external_document.get("metadata"),
+                    provider="external",
+                )
+                document.metadata["score"] = external_document.get("score")
+                document.metadata["title"] = external_document.get("title")
+                document.metadata["dataset_id"] = dataset.id
+                document.metadata["dataset_name"] = dataset.name
+                results.append(document)
+            # deal with external documents
+            context_list = []
+            for position, item in enumerate(results, start=1):
+                source = {
+                    "position": position,
+                    "dataset_id": item.metadata.get("dataset_id"),
+                    "dataset_name": item.metadata.get("dataset_name"),
+                    "document_name": item.metadata.get("title"),
+                    "data_source_type": "external",
+                    "retriever_from": self.retriever_from,
+                    "score": item.metadata.get("score"),
+                    "title": item.metadata.get("title"),
+                    "content": item.page_content,
+                }
+                context_list.append(source)
+            for hit_callback in self.hit_callbacks:
+                hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join([item.page_content for item in results]))
         else:
-            if self.top_k > 0:
-                # retrieval source
+            # get retrieval model , if the model is not setting , using default
+            retrieval_model = dataset.retrieval_model or default_retrieval_model
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
                 documents = RetrievalService.retrieve(
-                    retrieval_method=retrieval_model.get("search_method", "semantic_search"),
-                    dataset_id=dataset.id,
-                    query=query,
-                    top_k=self.top_k,
-                    score_threshold=retrieval_model.get("score_threshold", 0.0)
-                    if retrieval_model["score_threshold_enabled"]
-                    else 0.0,
-                    reranking_model=retrieval_model.get("reranking_model", None)
-                    if retrieval_model["reranking_enable"]
-                    else None,
-                    reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
-                    weights=retrieval_model.get("weights", None),
+                    retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
                 )
+                return str("\n".join([document.page_content for document in documents]))
             else:
-                documents = []
-
-            for hit_callback in self.hit_callbacks:
-                hit_callback.on_tool_end(documents)
-            document_score_list = {}
-            if dataset.indexing_technique != "economy":
-                for item in documents:
-                    if item.metadata.get("score"):
-                        document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
-            document_context_list = []
-            index_node_ids = [document.metadata["doc_id"] for document in documents]
-            segments = DocumentSegment.query.filter(
-                DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.completed_at.isnot(None),
-                DocumentSegment.status == "completed",
-                DocumentSegment.enabled == True,
-                DocumentSegment.index_node_id.in_(index_node_ids),
-            ).all()
-
-            if segments:
-                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
-                sorted_segments = sorted(
-                    segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
-                )
-                for segment in sorted_segments:
-                    if segment.answer:
-                        document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
-                    else:
-                        document_context_list.append(segment.get_sign_content())
-                if self.return_resource:
-                    context_list = []
-                    resource_number = 1
+                if self.top_k > 0:
+                    # retrieval source
+                    documents = RetrievalService.retrieve(
+                        retrieval_method=retrieval_model.get("search_method", "semantic_search"),
+                        dataset_id=dataset.id,
+                        query=query,
+                        top_k=self.top_k,
+                        score_threshold=retrieval_model.get("score_threshold", 0.0)
+                        if retrieval_model["score_threshold_enabled"]
+                        else 0.0,
+                        reranking_model=retrieval_model.get("reranking_model", None)
+                        if retrieval_model["reranking_enable"]
+                        else None,
+                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
+                        weights=retrieval_model.get("weights", None),
+                    )
+                else:
+                    documents = []
+
+                for hit_callback in self.hit_callbacks:
+                    hit_callback.on_tool_end(documents)
+                document_score_list = {}
+                if dataset.indexing_technique != "economy":
+                    for item in documents:
+                        if item.metadata.get("score"):
+                            document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
+                document_context_list = []
+                index_node_ids = [document.metadata["doc_id"] for document in documents]
+                segments = DocumentSegment.query.filter(
+                    DocumentSegment.dataset_id == self.dataset_id,
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.index_node_id.in_(index_node_ids),
+                ).all()
+
+                if segments:
+                    index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+                    sorted_segments = sorted(
+                        segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
+                    )
                     for segment in sorted_segments:
-                        context = {}
-                        document = Document.query.filter(
-                            Document.id == segment.document_id,
-                            Document.enabled == True,
-                            Document.archived == False,
-                        ).first()
-                        if dataset and document:
-                            source = {
-                                "position": resource_number,
-                                "dataset_id": dataset.id,
-                                "dataset_name": dataset.name,
-                                "document_id": document.id,
-                                "document_name": document.name,
-                                "data_source_type": document.data_source_type,
-                                "segment_id": segment.id,
-                                "retriever_from": self.retriever_from,
-                                "score": document_score_list.get(segment.index_node_id, None),
-                            }
-                            if self.retriever_from == "dev":
-                                source["hit_count"] = segment.hit_count
-                                source["word_count"] = segment.word_count
-                                source["segment_position"] = segment.position
-                                source["index_node_hash"] = segment.index_node_hash
-                            if segment.answer:
-                                source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
-                            else:
-                                source["content"] = segment.content
-                            context_list.append(source)
-                        resource_number += 1
-
-                    for hit_callback in self.hit_callbacks:
-                        hit_callback.return_retriever_resource_info(context_list)
-
-            return str("\n".join(document_context_list))
+                        if segment.answer:
+                            document_context_list.append(
+                                f"question:{segment.get_sign_content()} answer:{segment.answer}"
+                            )
+                        else:
+                            document_context_list.append(segment.get_sign_content())
+                    if self.return_resource:
+                        context_list = []
+                        resource_number = 1
+                        for segment in sorted_segments:
+                            context = {}
+                            document = Document.query.filter(
+                                Document.id == segment.document_id,
+                                Document.enabled == True,
+                                Document.archived == False,
+                            ).first()
+                            if dataset and document:
+                                source = {
+                                    "position": resource_number,
+                                    "dataset_id": dataset.id,
+                                    "dataset_name": dataset.name,
+                                    "document_id": document.id,
+                                    "document_name": document.name,
+                                    "data_source_type": document.data_source_type,
+                                    "segment_id": segment.id,
+                                    "retriever_from": self.retriever_from,
+                                    "score": document_score_list.get(segment.index_node_id, None),
+                                }
+                                if self.retriever_from == "dev":
+                                    source["hit_count"] = segment.hit_count
+                                    source["word_count"] = segment.word_count
+                                    source["segment_position"] = segment.position
+                                    source["index_node_hash"] = segment.index_node_hash
+                                if segment.answer:
+                                    source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
+                                else:
+                                    source["content"] = segment.content
+                                context_list.append(source)
+                            resource_number += 1
+
+                        for hit_callback in self.hit_callbacks:
+                            hit_callback.return_retriever_resource_info(context_list)
+
+                return str("\n".join(document_context_list))
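
The tool now branches on dataset.provider. External datasets are queried through ExternalDatasetService.fetch_external_knowledge_retrieval, the returned records are wrapped in RetrievalDocument objects carrying score, title and dataset metadata, and a retriever-resource context list is emitted through the hit callbacks; the existing keyword and semantic paths are moved, otherwise unchanged, under the else branch. A hedged standalone sketch of the external branch's wrapping and context assembly, with a stand-in for the real Document model:

```python
# Standalone sketch (illustration only): wrap remote records, build the
# retriever-resource context list, and join the page contents for the answer.
from dataclasses import dataclass, field


@dataclass
class RetrievalDocument:  # stand-in for core.rag.models.document.Document
    page_content: str
    metadata: dict = field(default_factory=dict)
    provider: str = "external"


def format_external_records(
    records: list[dict], dataset_id: str, dataset_name: str, retriever_from: str = "workflow"
) -> tuple[str, list[dict]]:
    results, context_list = [], []
    for position, record in enumerate(records, start=1):
        doc = RetrievalDocument(
            page_content=record.get("content"),
            metadata=record.get("metadata") or {},
        )
        doc.metadata.update(
            score=record.get("score"),
            title=record.get("title"),
            dataset_id=dataset_id,
            dataset_name=dataset_name,
        )
        results.append(doc)
        context_list.append({
            "position": position,
            "dataset_id": dataset_id,
            "dataset_name": dataset_name,
            "document_name": record.get("title"),
            "data_source_type": "external",
            "retriever_from": retriever_from,
            "score": record.get("score"),
            "title": record.get("title"),
            "content": doc.page_content,
        })
    return "\n".join(doc.page_content for doc in results), context_list
```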

+ 9 - 5
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -79,8 +79,9 @@ class KnowledgeRetrievalNode(BaseNode):
 
         results = (
             db.session.query(Dataset)
-            .join(subquery, Dataset.id == subquery.c.dataset_id)
+            .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
             .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
+            .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
             .all()
         )
 
@@ -121,10 +122,13 @@ class KnowledgeRetrievalNode(BaseNode):
                 )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
             if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
-                reranking_model = {
-                    "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
-                    "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
-                }
+                if node_data.multiple_retrieval_config.reranking_model:
+                    reranking_model = {
+                        "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
+                        "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
+                    }
+                else:
+                    reranking_model = None
                 weights = None
             elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
                 reranking_model = None
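
Two fixes here. The inner join dropped any dataset without a row in the document-count subquery, which is exactly the situation for external datasets; the outer join plus the explicit filter keeps them while still excluding empty local datasets. The second hunk guards against multiple_retrieval_config.reranking_model being unset and falls back to reranking_model = None instead of raising an AttributeError. A self-contained SQLAlchemy sketch of the join change, using stand-in models:

```python
# Illustration of join vs. outerjoin with the new filter; Dataset and Document
# here are minimal stand-ins, not the real Dify models.
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Dataset(Base):
    __tablename__ = "datasets"
    id = Column(Integer, primary_key=True)
    provider = Column(String)


class Document(Base):
    __tablename__ = "documents"
    id = Column(Integer, primary_key=True)
    dataset_id = Column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        Dataset(id=1, provider="vendor"),    # has one local document
        Dataset(id=2, provider="external"),  # external: no local documents at all
        Document(id=1, dataset_id=1),
    ])
    session.commit()

    subquery = (
        select(Document.dataset_id, func.count(Document.id).label("available_document_count"))
        .group_by(Document.dataset_id)
        .subquery()
    )
    results = (
        session.query(Dataset)
        .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
        .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
        .all()
    )
    print([d.id for d in results])  # -> [1, 2]; an inner join would return only [1]
```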

+ 1 - 0
api/services/dataset_service.py

@@ -234,6 +234,7 @@ class DatasetService:
             dataset.name = data.get("name", dataset.name)
             dataset.description = data.get("description", "")
             external_knowledge_id = data.get("external_knowledge_id", None)
+            dataset.permission = data.get("permission")
             db.session.add(dataset)
             if not external_knowledge_id:
                 raise ValueError("External knowledge id is required.")
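
A small behavioural note on the added line: data.get("permission") returns None when the update payload omits the field, so such a request would overwrite the stored permission with None. A minimal illustration of the difference, with the fallback form shown only as a possible defensive variant:

```python
# dict.get behaviour behind the new assignment (hypothetical payload)
data = {"name": "my-external-dataset"}    # update request without "permission"
print(data.get("permission"))             # -> None, overwrites the stored value
print(data.get("permission", "only_me"))  # -> keeps a fallback value instead
```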

+ 45 - 0
api/services/knowledge_service.py

@@ -0,0 +1,45 @@
+import boto3
+
+from configs import dify_config
+
+
+class ExternalDatasetTestService:
+    # this service is only for internal testing
+    @staticmethod
+    def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
+        # get bedrock client
+        client = boto3.client(
+            "bedrock-agent-runtime",
+            aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
+            aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
+            # example: us-east-1
+            region_name="us-east-1",
+        )
+        # fetch external knowledge retrieval
+        response = client.retrieve(
+            knowledgeBaseId=knowledge_id,
+            retrievalConfiguration={
+                "vectorSearchConfiguration": {
+                    "numberOfResults": retrieval_setting.get("top_k"),
+                    "overrideSearchType": "HYBRID",
+                }
+            },
+            retrievalQuery={"text": query},
+        )
+        # parse response
+        results = []
+        if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
+            if response.get("retrievalResults"):
+                retrieval_results = response.get("retrievalResults")
+                for retrieval_result in retrieval_results:
+                    # filter out results with score less than threshold
+                    if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
+                        continue
+                    result = {
+                        "metadata": retrieval_result.get("metadata"),
+                        "score": retrieval_result.get("score"),
+                        "title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
+                        "content": retrieval_result.get("content").get("text"),
+                    }
+                    results.append(result)
+        return {"records": results}
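
The service wraps the Bedrock Agent Runtime retrieve call and normalizes the response into a {"records": [...]} payload with metadata, score, title (the knowledge-base source URI) and content, dropping results below the score threshold. A hedged usage sketch; the knowledge base ID is a placeholder and a real call needs the internal-test AWS credentials from the middleware config:

```python
# Hypothetical invocation of the internal test service defined above.
from services.knowledge_service import ExternalDatasetTestService

result = ExternalDatasetTestService.knowledge_retrieval(
    retrieval_setting={"top_k": 4, "score_threshold": 0.5},
    query="How do I configure external knowledge?",
    knowledge_id="EXAMPLEKBID",  # placeholder Bedrock knowledge base id
)
# Expected shape:
# {"records": [{"metadata": {...}, "score": 0.87,
#               "title": "s3://bucket/doc.pdf", "content": "..."}, ...]}
print(result["records"])
```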