Browse Source

feat: support tencent vector db (#3568)

quicksand 10 months ago
parent
commit
4080f7b8ad

+ 9 - 0
api/.env.example

@@ -99,6 +99,15 @@ RELYT_USER=postgres
 RELYT_PASSWORD=postgres
 RELYT_DATABASE=postgres
 
+# Tencent configuration
+TENCENT_VECTOR_DB_URL=http://127.0.0.1
+TENCENT_VECTOR_DB_API_KEY=dify
+TENCENT_VECTOR_DB_TIMEOUT=30
+TENCENT_VECTOR_DB_USERNAME=dify
+TENCENT_VECTOR_DB_DATABASE=dify
+TENCENT_VECTOR_DB_SHARD=1
+TENCENT_VECTOR_DB_REPLICAS=2
+
 # PGVECTO_RS configuration
 PGVECTO_RS_HOST=localhost
 PGVECTO_RS_PORT=5431

+ 8 - 0
api/commands.py

@@ -309,6 +309,14 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == VectorType.TENCENT:
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": VectorType.TENCENT,
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                 elif vector_type == VectorType.PGVECTOR:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)

+ 10 - 0
api/config.py

@@ -288,6 +288,16 @@ class Config:
         self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
         self.RELYT_DATABASE = get_env('RELYT_DATABASE')
 
+
+        # tencent settings
+        self.TENCENT_VECTOR_DB_URL = get_env('TENCENT_VECTOR_DB_URL')
+        self.TENCENT_VECTOR_DB_API_KEY = get_env('TENCENT_VECTOR_DB_API_KEY')
+        self.TENCENT_VECTOR_DB_TIMEOUT = get_env('TENCENT_VECTOR_DB_TIMEOUT')
+        self.TENCENT_VECTOR_DB_USERNAME = get_env('TENCENT_VECTOR_DB_USERNAME')
+        self.TENCENT_VECTOR_DB_DATABASE = get_env('TENCENT_VECTOR_DB_DATABASE')
+        self.TENCENT_VECTOR_DB_SHARD = get_env('TENCENT_VECTOR_DB_SHARD')
+        self.TENCENT_VECTOR_DB_REPLICAS = get_env('TENCENT_VECTOR_DB_REPLICAS')
+
         # pgvecto rs settings
         self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
         self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')

+ 2 - 3
api/controllers/console/datasets/datasets.py

@@ -480,9 +480,8 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
                 return {
                     'retrieval_method': [
                         'semantic_search'
@@ -504,7 +503,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCEN:
                 return {
                     'retrieval_method': [
                         'semantic_search'

+ 0 - 0
api/core/rag/datasource/vdb/tencent/__init__.py


+ 227 - 0
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -0,0 +1,227 @@
+import json
+from typing import Any, Optional
+
+from flask import current_app
+from pydantic import BaseModel
+from tcvectordb import VectorDBClient
+from tcvectordb.model import document, enum
+from tcvectordb.model import index as vdb_index
+from tcvectordb.model.document import Filter
+
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class TencentConfig(BaseModel):
+    url: str
+    api_key: Optional[str]
+    timeout: float = 30
+    username: Optional[str]
+    database: Optional[str]
+    index_type: str = "HNSW"
+    metric_type: str = "L2"
+    shard: int = 1,
+    replicas: int = 2,
+
+    def to_tencent_params(self):
+        return {
+            'url': self.url,
+            'username': self.username,
+            'key': self.api_key,
+            'timeout': self.timeout
+        }
+
+
+class TencentVector(BaseVector):
+    field_id: str = "id"
+    field_vector: str = "vector"
+    field_text: str = "text"
+    field_metadata: str = "metadata"
+
+    def __init__(self, collection_name: str, config: TencentConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = VectorDBClient(**self._client_config.to_tencent_params())
+        self._db = self._init_database()
+
+    def _init_database(self):
+        exists = False
+        for db in self._client.list_databases():
+            if db.database_name == self._client_config.database:
+                exists = True
+                break
+        if exists:
+            return self._client.database(self._client_config.database)
+        else:
+            return self._client.create_database(database_name=self._client_config.database)
+
+    def get_type(self) -> str:
+        return 'tencent'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self._collection_name}
+        }
+
+    def _has_collection(self) -> bool:
+        collections = self._db.list_collections()
+        for collection in collections:
+            if collection.collection_name == self._collection_name:
+                return True
+        return False
+
+    def _create_collection(self, dimension: int) -> None:
+        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            if self._has_collection():
+                return
+
+            self.delete()
+            index_type = None
+            for k, v in enum.IndexType.__members__.items():
+                if k == self._client_config.index_type:
+                    index_type = v
+            if index_type is None:
+                raise ValueError("unsupported index_type")
+            metric_type = None
+            for k, v in enum.MetricType.__members__.items():
+                if k == self._client_config.metric_type:
+                    metric_type = v
+            if metric_type is None:
+                raise ValueError("unsupported metric_type")
+            params = vdb_index.HNSWParams(m=16, efconstruction=200)
+            index = vdb_index.Index(
+                vdb_index.FilterIndex(
+                    self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
+                ),
+                vdb_index.VectorIndex(
+                    self.field_vector,
+                    dimension,
+                    index_type,
+                    metric_type,
+                    params,
+                ),
+                vdb_index.FilterIndex(
+                    self.field_text, enum.FieldType.String, enum.IndexType.FILTER
+                ),
+                vdb_index.FilterIndex(
+                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
+                ),
+            )
+
+            self._db.create_collection(
+                name=self._collection_name,
+                shard=self._client_config.shard,
+                replicas=self._client_config.replicas,
+                description="Collection for Dify",
+                index=index,
+            )
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self._create_collection(len(embeddings[0]))
+        self.add_texts(texts, embeddings)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        texts = [doc.page_content for doc in documents]
+        metadatas = [doc.metadata for doc in documents]
+        total_count = len(embeddings)
+        docs = []
+        for id in range(0, total_count):
+            if metadatas is None:
+                continue
+            metadata = json.dumps(metadatas[id])
+            doc = document.Document(
+                id=metadatas[id]["doc_id"],
+                vector=embeddings[id],
+                text=texts[id],
+                metadata=metadata,
+            )
+            docs.append(doc)
+        self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
+
+    def text_exists(self, id: str) -> bool:
+        docs = self._db.collection(self._collection_name).query(document_ids=[id])
+        if docs and len(docs) > 0:
+            return True
+        return False
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        self._db.collection(self._collection_name).delete(document_ids=ids)
+
+    def delete_by_metadata_field(self, key: str, value: str) -> None:
+        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+
+        res = self._db.collection(self._collection_name).search(vectors=[query_vector],
+                                                                params=document.HNSWSearchParams(
+                                                                    ef=kwargs.get("ef", 10)),
+                                                                retrieve_vector=False,
+                                                                limit=kwargs.get('top_k', 4),
+                                                                timeout=self._client_config.timeout,
+                                                                )
+        score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+        return self._get_search_res(res, score_threshold)
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        return []
+
+    def _get_search_res(self, res, score_threshold):
+        docs = []
+        if res is None or len(res) == 0:
+            return docs
+
+        for result in res[0]:
+            meta = result.get(self.field_metadata)
+            if meta is not None:
+                meta = json.loads(meta)
+            score = 1 - result.get("score", 0.0)
+            if score > score_threshold:
+                meta["score"] = score
+                doc = Document(page_content=result.get(self.field_text), metadata=meta)
+                docs.append(doc)
+
+        return docs
+
+    def delete(self) -> None:
+        self._db.drop_collection(name=self._collection_name)
+
+
+
+
+class TencentVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
+
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
+
+        config = current_app.config
+        return TencentVector(
+            collection_name=collection_name,
+            config=TencentConfig(
+                url=config.get('TENCENT_VECTOR_DB_URL'),
+                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
+                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
+                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
+                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
+                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
+                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
+            )
+        )

+ 3 - 1
api/core/rag/datasource/vdb/vector_factory.py

@@ -39,7 +39,6 @@ class Vector:
     def _init_vector(self) -> BaseVector:
         config = current_app.config
         vector_type = config.get('VECTOR_STORE')
-
         if self._dataset.index_struct_dict:
             vector_type = self._dataset.index_struct_dict['type']
 
@@ -76,6 +75,9 @@ class Vector:
             case VectorType.WEAVIATE:
                 from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
                 return WeaviateVectorFactory
+            case VectorType.TENCENT:
+                from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
+                return TencentVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -10,3 +10,4 @@ class VectorType(str, Enum):
     RELYT = 'relyt'
     TIDB_VECTOR = 'tidb_vector'
     WEAVIATE = 'weaviate'
+    TENCENT = 'tencent'

+ 44 - 1
api/poetry.lock

@@ -1439,6 +1439,23 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pill
 test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
 test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]
 
+[[package]]
+name = "cos-python-sdk-v5"
+version = "1.9.29"
+description = "cos-python-sdk-v5"
+optional = false
+python-versions = "*"
+files = [
+    {file = "cos-python-sdk-v5-1.9.29.tar.gz", hash = "sha256:1bb07022368d178e7a50a3cc42e0d6cbf4b0bef2af12a3bb8436904339cdec8e"},
+]
+
+[package.dependencies]
+crcmod = "*"
+pycryptodome = "*"
+requests = ">=2.8"
+six = "*"
+xmltodict = "*"
+
 [[package]]
 name = "coverage"
 version = "7.2.7"
@@ -7411,6 +7428,21 @@ files = [
 [package.extras]
 widechars = ["wcwidth"]
 
+[[package]]
+name = "tcvectordb"
+version = "1.3.2"
+description = "Tencent VectorDB Python SDK"
+optional = false
+python-versions = ">=3"
+files = [
+    {file = "tcvectordb-1.3.2-py3-none-any.whl", hash = "sha256:c4b6922d5df4cf14fcd3e61220d9374d1d53ec7270c254216ae35f8a752908f3"},
+    {file = "tcvectordb-1.3.2.tar.gz", hash = "sha256:2772f5871a69744ffc7c970b321312d626078533a721de3c744059a81aab419e"},
+]
+
+[package.dependencies]
+cos-python-sdk-v5 = ">=1.9.26"
+requests = "*"
+
 [[package]]
 name = "tenacity"
 version = "8.3.0"
@@ -8641,6 +8673,17 @@ files = [
     {file = "XlsxWriter-3.2.0.tar.gz", hash = "sha256:9977d0c661a72866a61f9f7a809e25ebbb0fb7036baa3b9fe74afcfca6b3cb8c"},
 ]
 
+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+optional = false
+python-versions = ">=3.4"
+files = [
+    {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+    {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
+
 [[package]]
 name = "yarl"
 version = "1.9.4"
@@ -8878,4 +8921,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "32a9ac027beabdb863fb33886bbf6f0000cbddf4d6089cbdb5c5dbfba23b29b4"
+content-hash = "e967aa4b61dc7c40f2f50eb325038da1dc0ff633d8f778e7a7560bdabce744dc"

+ 1 - 0
api/pyproject.toml

@@ -179,6 +179,7 @@ google-cloud-aiplatform = "1.49.0"
 vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
 kaleido = "0.2.1"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
+tcvectordb = "1.3.2"
 chromadb = "~0.5.0"
 
 [tool.poetry.group.dev]

+ 1 - 0
api/requirements.txt

@@ -78,6 +78,7 @@ lxml==5.1.0
 pydantic~=2.7.4
 pydantic_extra_types~=2.8.1
 pgvecto-rs==0.1.4
+tcvectordb==1.3.2
 firecrawl-py==0.0.5
 oss2==2.18.5
 pgvector==0.2.5

+ 0 - 0
api/tests/integration_tests/vdb/__mock/__init__.py


+ 132 - 0
api/tests/integration_tests/vdb/__mock/tcvectordb.py

@@ -0,0 +1,132 @@
+import os
+from typing import Optional
+
+import pytest
+from _pytest.monkeypatch import MonkeyPatch
+from requests.adapters import HTTPAdapter
+from tcvectordb import VectorDBClient
+from tcvectordb.model.database import Collection, Database
+from tcvectordb.model.document import Document, Filter
+from tcvectordb.model.enum import ReadConsistency
+from tcvectordb.model.index import Index
+from xinference_client.types import Embedding
+
+
+class MockTcvectordbClass:
+
+    def VectorDBClient(self, url=None, username='', key='',
+                       read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
+                       timeout=5,
+                       adapter: HTTPAdapter = None):
+        self._conn = None
+        self._read_consistency = read_consistency
+
+    def list_databases(self) -> list[Database]:
+        return [
+            Database(
+                conn=self._conn,
+                read_consistency=self._read_consistency,
+                name='dify',
+            )]
+
+    def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
+        return []
+
+    def drop_collection(self, name: str, timeout: Optional[float] = None):
+        return {
+            "code": 0,
+            "msg": "operation success"
+        }
+
+    def create_collection(
+            self,
+            name: str,
+            shard: int,
+            replicas: int,
+            description: str,
+            index: Index,
+            embedding: Embedding = None,
+            timeout: float = None,
+    ) -> Collection:
+        return Collection(self, name, shard, replicas, description, index, embedding=embedding,
+                          read_consistency=self._read_consistency, timeout=timeout)
+
+    def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
+        collection = Collection(
+            self,
+            name,
+            shard=1,
+            replicas=2,
+            description=name,
+            timeout=timeout
+        )
+        return collection
+
+    def collection_upsert(
+            self,
+            documents: list[Document],
+            timeout: Optional[float] = None,
+            build_index: bool = True,
+            **kwargs
+    ):
+        return {
+            "code": 0,
+            "msg": "operation success"
+        }
+
+    def collection_search(
+            self,
+            vectors: list[list[float]],
+            filter: Filter = None,
+            params=None,
+            retrieve_vector: bool = False,
+            limit: int = 10,
+            output_fields: Optional[list[str]] = None,
+            timeout: Optional[float] = None,
+    ) -> list[list[dict]]:
+        return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]
+
+    def collection_query(
+            self,
+            document_ids: Optional[list] = None,
+            retrieve_vector: bool = False,
+            limit: Optional[int] = None,
+            offset: Optional[int] = None,
+            filter: Optional[Filter] = None,
+            output_fields: Optional[list[str]] = None,
+            timeout: Optional[float] = None,
+    ) -> list[dict]:
+        return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]
+
+    def collection_delete(
+            self,
+            document_ids: list[str] = None,
+            filter: Filter = None,
+            timeout: float = None,
+    ):
+        return {
+            "code": 0,
+            "msg": "operation success"
+        }
+
+
+MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+@pytest.fixture
+def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
+        monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
+        monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
+        monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
+        monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
+        monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
+        monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
+        monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
+        monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
+        monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)
+
+    yield
+
+    if MOCK:
+        monkeypatch.undo()

+ 0 - 0
api/tests/integration_tests/vdb/tcvectordb/__init__.py


+ 35 - 0
api/tests/integration_tests/vdb/tcvectordb/test_tencent.py

@@ -0,0 +1,35 @@
+from unittest.mock import MagicMock
+
+from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
+from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
+
+mock_client = MagicMock()
+mock_client.list_databases.return_value = [{"name": "test"}]
+
+class TencentVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = TencentVector("dify", TencentConfig(
+            url="http://127.0.0.1",
+            api_key="dify",
+            timeout=30,
+            username="dify",
+            database="dify",
+            shard=1,
+            replicas=2,
+        ))
+
+    def search_by_vector(self):
+        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
+        assert len(hits_by_vector) == 1
+
+    def search_by_full_text(self):
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock):
+    TencentVectorTest().run_all_tests()
+
+
+

+ 8 - 0
docker/docker-compose.yaml

@@ -298,6 +298,14 @@ services:
       RELYT_USER: postgres
       RELYT_PASSWORD: difyai123456
       RELYT_DATABASE: postgres
+      # tencent configurations
+      TENCENT_VECTOR_DB_URL: http://127.0.0.1
+      TENCENT_VECTOR_DB_API_KEY: dify
+      TENCENT_VECTOR_DB_TIMEOUT: 30
+      TENCENT_VECTOR_DB_USERNAME: dify
+      TENCENT_VECTOR_DB_DATABASE: dify
+      TENCENT_VECTOR_DB_SHARD: 1
+      TENCENT_VECTOR_DB_REPLICAS: 2
       # pgvector configurations
       PGVECTOR_HOST: pgvector
       PGVECTOR_PORT: 5432