Browse Source

feat: support relyt vector database (#3367)

Co-authored-by: jingsi <jingsi@leadincloud.com>
Jingpan Xiong 1 year ago
parent
commit
33397836a5

+ 8 - 1
api/.env.example

@@ -57,7 +57,7 @@ AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus
+# Vector database configuration, support: weaviate, qdrant, milvus, relyt
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -78,6 +78,13 @@ MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
 MILVUS_SECURE=false
 
+# Relyt configuration
+RELYT_HOST=127.0.0.1
+RELYT_PORT=5432
+RELYT_USER=postgres
+RELYT_PASSWORD=postgres
+RELYT_DATABASE=postgres
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5

+ 8 - 0
api/commands.py

@@ -297,6 +297,14 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "relyt":
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": 'relyt',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                 else:
                     raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
 

+ 8 - 1
api/config.py

@@ -198,7 +198,7 @@ class Config:
 
         # ------------------------
         # Vector Store Configurations.
-        # Currently, only support: qdrant, milvus, zilliz, weaviate
+        # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt
         # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
         self.KEYWORD_STORE = get_env('KEYWORD_STORE')
@@ -221,6 +221,13 @@ class Config:
         self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
         self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
 
+        # relyt settings
+        self.RELYT_HOST = get_env('RELYT_HOST')
+        self.RELYT_PORT = get_env('RELYT_PORT')
+        self.RELYT_USER = get_env('RELYT_USER')
+        self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
+        self.RELYT_DATABASE = get_env('RELYT_DATABASE')
+
         # ------------------------
         # Mail Configurations.
         # ------------------------

+ 0 - 0
api/core/rag/datasource/vdb/relyt/__init__.py


+ 169 - 0
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -0,0 +1,169 @@
+import logging
+from typing import Any
+
+from pgvecto_rs.sdk import PGVectoRs, Record
+from pydantic import BaseModel, root_validator
+from sqlalchemy import text as sql_text
+from sqlalchemy.orm import Session
+
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+
+logger = logging.getLogger(__name__)
+
class RelytConfig(BaseModel):
    """Connection settings for a Relyt (PostgreSQL-compatible) instance."""

    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject empty connection fields, naming the matching environment variable."""
        required_fields = {
            'host': 'RELYT_HOST',
            'port': 'RELYT_PORT',
            'user': 'RELYT_USER',
            'password': 'RELYT_PASSWORD',
            'database': 'RELYT_DATABASE',
        }
        for field_name, env_name in required_fields.items():
            if not values[field_name]:
                raise ValueError(f"config {env_name} is required")
        return values
+
+
class RelytVector(BaseVector):
    """Vector store backed by Relyt (PostgreSQL-compatible) via pgvecto.rs.

    Each collection is stored as a table named ``collection_<name>`` with
    columns ``(id UUID, text TEXT, meta JSONB, embedding vector(dim))`` and
    an HNSW index on ``embedding``.
    """

    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
        super().__init__(collection_name)
        self._client_config = config
        # psycopg2 DSN built from the validated config.
        self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
        self._client = PGVectoRs(
            db_url=self._url,
            collection_name=self._collection_name,
            dimension=dim
        )
        self._fields = []

    def get_type(self) -> str:
        """Return the vector-store type identifier used in index_struct."""
        return 'relyt'

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection (table + index) and insert the initial batch.

        The table dimension is inferred from the first embedding, so an empty
        batch is a no-op instead of an IndexError.
        """
        if not embeddings:
            return
        self.create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def create_collection(self, dimension: int):
        """Idempotently (re)create the collection table and its HNSW index.

        A Redis lock serializes concurrent creators; a cache key short-circuits
        repeat calls for one hour.
        """
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            index_name = f"{self._collection_name}_embedding_index"
            with Session(self._client._engine) as session:
                # Drop any stale table so the dimension always matches the
                # current embedding model.
                drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
                session.execute(drop_statement)
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
                        id UUID PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSONB NOT NULL,
                        embedding vector({dimension}) NOT NULL
                    ) using heap;
                """)
                session.execute(create_statement)
                # pgvecto.rs HNSW index; options use the TOML-ish $$ ... $$ syntax.
                index_statement = sql_text(f"""
                        CREATE INDEX {index_name}
                        ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
                        WITH (options = $$
                                optimizing.optimizing_threads = 30
                                segment.max_growing_segment_size = 2000
                                segment.max_sealed_segment_size = 30000000
                                [indexing.hnsw]
                                m=30
                                ef_construction=500
                                $$);
                    """)
                session.execute(index_statement)
                session.commit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Insert documents with their embeddings; return the generated record ids."""
        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
        pks = [str(r.id) for r in records]
        self._client.insert(records)
        return pks

    def delete_by_document_id(self, document_id: str):
        """Delete all records belonging to a source document."""
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._client.delete_by_ids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        """Return record ids whose JSONB meta field ``key`` equals ``value``.

        Returns None when there is no match (kept for caller compatibility).
        Values are bound as parameters instead of interpolated into the SQL.
        """
        with Session(self._client._engine) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>:key = :value; "
            )
            result = session.execute(select_statement, {'key': key, 'value': value}).fetchall()
        if result:
            return [item[0] for item in result]
        return None

    def delete_by_metadata_field(self, key: str, value: str):
        """Delete every record whose meta field ``key`` equals ``value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._client.delete_by_ids(ids)

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        """Delete records whose meta ``doc_id`` is one of ``doc_ids``.

        Bug fix: the previous version interpolated the Python list repr into a
        single quoted SQL literal (``IN ('['a', 'b']')``), which never matched
        any row. The list is now bound as an array parameter with ``= ANY``.
        """
        if not doc_ids:
            return
        with Session(self._client._engine) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = ANY(:doc_ids); "
            )
            result = session.execute(select_statement, {'doc_ids': doc_ids}).fetchall()
        if result:
            ids = [item[0] for item in result]
            self._client.delete_by_ids(ids)

    def delete(self) -> None:
        """Drop the whole collection table."""
        with Session(self._client._engine) as session:
            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
            session.commit()

    def text_exists(self, id: str) -> bool:
        """Return True if a record with meta ``doc_id`` == ``id`` exists."""
        with Session(self._client._engine) as session:
            select_statement = sql_text(
                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = :doc_id limit 1; "
            )
            result = session.execute(select_statement, {'doc_id': id}).fetchall()
        return len(result) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Return documents near ``query_vector``, annotated with a ``score``.

        Supported kwargs: ``top_k`` (default 4 — previously a missing value
        crashed ``int(None)``), ``filter`` (meta-contains filter dict) and
        ``score_threshold`` (default 0.0).
        """
        from pgvecto_rs.sdk import filters
        filter_condition = filters.meta_contains(kwargs.get('filter'))
        results = self._client.search(
            top_k=int(kwargs.get('top_k') or 4),
            embedding=query_vector,
            filter=filter_condition
        )

        # Hoist the threshold out of the loop; it is constant per query.
        score_threshold = kwargs.get('score_threshold') or 0.0
        docs = []
        for record, dis in results:
            # NOTE(review): `dis` is an L2 distance (vector_l2_ops), so
            # `dis > score_threshold` keeps the *farther* matches — confirm
            # whether a similarity score was intended here.
            if dis > score_threshold:
                metadata = record.meta
                metadata['score'] = dis
                docs.append(Document(page_content=record.text, metadata=metadata))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text (BM25) search is not supported by milvus/zilliz/relyt."""
        return []

+ 25 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -113,6 +113,31 @@ class Vector:
                     database=config.get('MILVUS_DATABASE'),
                 )
             )
+        elif vector_type == "relyt":
+            from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
+            if self._dataset.index_struct_dict:
+                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = self._dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                index_struct_dict = {
+                    "type": 'relyt',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
+            dim = len(self._embeddings.embed_query("hello relyt"))
+            return RelytVector(
+                collection_name=collection_name,
+                config=RelytConfig(
+                    host=config.get('RELYT_HOST'),
+                    port=config.get('RELYT_PORT'),
+                    user=config.get('RELYT_USER'),
+                    password=config.get('RELYT_PASSWORD'),
+                    database=config.get('RELYT_DATABASE'),
+                ),
+                dim=dim
+            )
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
 

+ 1 - 0
api/requirements.txt

@@ -79,3 +79,4 @@ azure-storage-blob==12.9.0
 azure-identity==1.15.0
 lxml==5.1.0
 xlrd~=2.0.1
+pgvecto-rs==0.1.4

+ 6 - 0
docker/docker-compose.yaml

@@ -223,6 +223,12 @@ services:
       # the api-key for resend (https://resend.com)
       RESEND_API_KEY: ''
       RESEND_API_URL: https://api.resend.com
+      # relyt configurations
+      RELYT_HOST: db
+      RELYT_PORT: 5432
+      RELYT_USER: postgres
+      RELYT_PASSWORD: difyai123456
+      RELYT_DATABASE: postgres
     depends_on:
       - db
       - redis