
improve: generalize vector factory classes and vector type (#5033)

Bowen Liang 10 months ago
commit bdad993901

+ 11 - 10
api/commands.py

@@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
 
 from constants.languages import languages
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
@@ -266,15 +267,15 @@ def migrate_knowledge_vector_database():
                         skipped_count = skipped_count + 1
                         continue
                 collection_name = ''
-                if vector_type == "weaviate":
+                if vector_type == VectorType.WEAVIATE:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'weaviate',
+                        "type": VectorType.WEAVIATE,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
-                elif vector_type == "qdrant":
+                elif vector_type == VectorType.QDRANT:
                     if dataset.collection_binding_id:
                         dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                             filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
@@ -287,20 +288,20 @@ def migrate_knowledge_vector_database():
                         dataset_id = dataset.id
                         collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'qdrant',
+                        "type": VectorType.QDRANT,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
 
-                elif vector_type == "milvus":
+                elif vector_type == VectorType.MILVUS:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'milvus',
+                        "type": VectorType.MILVUS,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
-                elif vector_type == "relyt":
+                elif vector_type == VectorType.RELYT:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
@@ -308,16 +309,16 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
-                elif vector_type == "pgvector":
+                elif vector_type == VectorType.PGVECTOR:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'pgvector',
+                        "type": VectorType.PGVECTOR,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
                 else:
-                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+                    raise ValueError(f"Vector store {vector_type} is not supported.")
 
                 vector = Vector(dataset)
                 click.echo(f"Start to migrate dataset {dataset.id}.")

+ 33 - 28
api/controllers/console/datasets/datasets.py

@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
@@ -476,20 +477,22 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
-            return {
-                'retrieval_method': [
-                    'semantic_search'
-                ]
-            }
-        elif vector_type in {"qdrant", "weaviate"}:
-            return {
-                'retrieval_method': [
-                    'semantic_search', 'full_text_search', 'hybrid_search'
-                ]
-            }
-        else:
-            raise ValueError("Unsupported vector db type.")
+
+        match vector_type:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.PGVECTO_RS | VectorType.TIDB_VECTOR:
+                return {
+                    'retrieval_method': [
+                        'semantic_search'
+                    ]
+                }
+            case VectorType.QDRANT | VectorType.WEAVIATE:
+                return {
+                    'retrieval_method': [
+                        'semantic_search', 'full_text_search', 'hybrid_search'
+                    ]
+                }
+            case _:
+                raise ValueError(f"Unsupported vector db type {vector_type}.")
 
 
 class DatasetRetrievalSettingMockApi(Resource):
@@ -497,20 +500,22 @@ class DatasetRetrievalSettingMockApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, vector_type):
-        if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}:
-            return {
-                'retrieval_method': [
-                    'semantic_search'
-                ]
-            }
-        elif vector_type in {'qdrant', 'weaviate'}:
-            return {
-                'retrieval_method': [
-                    'semantic_search', 'full_text_search', 'hybrid_search'
-                ]
-            }
-        else:
-            raise ValueError("Unsupported vector db type.")
+        match vector_type:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
+                return {
+                    'retrieval_method': [
+                        'semantic_search'
+                    ]
+                }
+            case VectorType.QDRANT | VectorType.WEAVIATE:
+                return {
+                    'retrieval_method': [
+                        'semantic_search', 'full_text_search', 'hybrid_search'
+                    ]
+                }
+            case _:
+                raise ValueError(f"Unsupported vector db type {vector_type}.")
+
 
 class DatasetErrorDocs(Resource):
     @setup_required
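
Both retrieval-setting endpoints now use structural pattern matching (Python 3.10+). Value patterns compare with ==, and a (str, Enum) member compares equal to its raw string, so a plain value such as 'qdrant' coming from config or the URL still selects the right branch. A standalone sketch mirroring the shape of the code above:

```python
from enum import Enum


class VectorType(str, Enum):
    MILVUS = 'milvus'
    QDRANT = 'qdrant'


def retrieval_methods(vector_type: str) -> list[str]:
    match vector_type:
        case VectorType.MILVUS:
            return ['semantic_search']
        case VectorType.QDRANT:
            # Or-patterns (A | B) behave the same way for grouped stores.
            return ['semantic_search', 'full_text_search', 'hybrid_search']
        case _:
            raise ValueError(f"Unsupported vector db type {vector_type}.")


assert retrieval_methods('qdrant') == ['semantic_search', 'full_text_search', 'hybrid_search']
```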

+ 34 - 2
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -1,14 +1,20 @@
+import json
 import logging
 from typing import Any, Optional
 from uuid import uuid4
 
+from flask import current_app
 from pydantic import BaseModel, root_validator
 from pymilvus import MilvusClient, MilvusException, connections
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +61,7 @@ class MilvusVector(BaseVector):
         self._fields = []
 
     def get_type(self) -> str:
-        return 'milvus'
+        return VectorType.MILVUS
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         index_params = {
@@ -254,10 +260,36 @@ class MilvusVector(BaseVector):
                                                            schema=schema, index_param=index_params,
                                                            consistency_level=self._consistency_level)
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
     def _init_client(self, config) -> MilvusClient:
         if config.secure:
             uri = "https://" + str(config.host) + ":" + str(config.port)
         else:
             uri = "http://" + str(config.host) + ":" + str(config.port)
-        client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
+        client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
         return client
+
+
+class MilvusVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
+
+        config = current_app.config
+        return MilvusVector(
+            collection_name=collection_name,
+            config=MilvusConfig(
+                host=config.get('MILVUS_HOST'),
+                port=config.get('MILVUS_PORT'),
+                user=config.get('MILVUS_USER'),
+                password=config.get('MILVUS_PASSWORD'),
+                secure=config.get('MILVUS_SECURE'),
+                database=config.get('MILVUS_DATABASE'),
+            )
+        )
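
The factory persists an index_struct the first time it sees a dataset without one, going through the shared AbstractVectorFactory.gen_index_struct_dict helper instead of the hand-rolled dicts the old code carried. A rough illustration with a made-up collection name (the real one comes from Dataset.gen_collection_name_by_id):

```python
import json

from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType

# Hypothetical collection name, for illustration only.
struct = AbstractVectorFactory.gen_index_struct_dict(VectorType.MILVUS, "collection_abc")
print(json.dumps(struct))
# {"type": "milvus", "vector_store": {"class_prefix": "collection_abc"}}
```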

+ 32 - 1
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -1,7 +1,9 @@
+import json
 import logging
 from typing import Any
 from uuid import UUID, uuid4
 
+from flask import current_app
 from numpy import ndarray
 from pgvecto_rs.sqlalchemy import Vector
 from pydantic import BaseModel, root_validator
@@ -10,10 +12,14 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -67,7 +73,7 @@ class PGVectoRS(BaseVector):
         self._distance_op = "<=>"
 
     def get_type(self) -> str:
-        return 'pgvecto-rs'
+        return VectorType.PGVECTO_RS
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         self.create_collection(len(embeddings[0]))
@@ -222,3 +228,28 @@ class PGVectoRS(BaseVector):
         #         docs.append(doc)
         #     return docs
         return []
+
+
+class PGVectoRSFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name))
+        dim = len(embeddings.embed_query("pgvecto_rs"))
+        config = current_app.config
+        return PGVectoRS(
+            collection_name=collection_name,
+            config=PgvectoRSConfig(
+                host=config.get('PGVECTO_RS_HOST'),
+                port=config.get('PGVECTO_RS_PORT'),
+                user=config.get('PGVECTO_RS_USER'),
+                password=config.get('PGVECTO_RS_PASSWORD'),
+                database=config.get('PGVECTO_RS_DATABASE'),
+            ),
+            dim=dim
+        )
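
PGVectoRSFactory sizes the vector column by embedding a short probe string once, so the schema follows whatever embedding model the dataset is bound to. A standalone sketch with a stub standing in for the real Embeddings object (which the Vector facade resolves through ModelManager):

```python
class StubEmbeddings:
    """Stand-in for core.rag.datasource.entity.embedding.Embeddings."""

    def embed_query(self, text: str) -> list[float]:
        return [0.0] * 768  # pretend the model returns 768-dimensional vectors


embeddings = StubEmbeddings()
dim = len(embeddings.embed_query("pgvecto_rs"))  # same probe the factory uses
assert dim == 768  # passed on as PGVectoRS(..., dim=dim)
```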

+ 30 - 1
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -5,11 +5,16 @@ from typing import Any
 
 import psycopg2.extras
 import psycopg2.pool
+from flask import current_app
 from pydantic import BaseModel, root_validator
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 
 class PGVectorConfig(BaseModel):
@@ -51,7 +56,7 @@ class PGVector(BaseVector):
         self.table_name = f"embedding_{collection_name}"
 
     def get_type(self) -> str:
-        return "pgvector"
+        return VectorType.PGVECTOR
 
     def _create_connection_pool(self, config: PGVectorConfig):
         return psycopg2.pool.SimpleConnectionPool(
@@ -167,3 +172,27 @@ class PGVector(BaseVector):
                 cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
                 # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class PGVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
+
+        config = current_app.config
+        return PGVector(
+            collection_name=collection_name,
+            config=PGVectorConfig(
+                host=config.get("PGVECTOR_HOST"),
+                port=config.get("PGVECTOR_PORT"),
+                user=config.get("PGVECTOR_USER"),
+                password=config.get("PGVECTOR_PASSWORD"),
+                database=config.get("PGVECTOR_DATABASE"),
+            ),
+        )
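
As the unchanged __init__ context above shows, PGVector maps the collection name onto a Postgres table by prefixing it, so the class_prefix stored in index_struct translates directly into a table name. Illustrative values only:

```python
collection_name = "vector_index_abc123_node"  # hypothetical class_prefix
table_name = f"embedding_{collection_name}"   # mirrors PGVector.__init__
assert table_name == "embedding_vector_index_abc123_node"
```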

+ 45 - 1
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -1,3 +1,4 @@
+import json
 import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
@@ -5,6 +6,7 @@ from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional, Union, cast
 
 import qdrant_client
+from flask import current_app
 from pydantic import BaseModel
 from qdrant_client.http import models as rest
 from qdrant_client.http.models import (
@@ -17,10 +19,15 @@ from qdrant_client.http.models import (
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
+from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset, DatasetCollectionBinding
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -69,7 +76,7 @@ class QdrantVector(BaseVector):
         self._group_id = group_id
 
     def get_type(self) -> str:
-        return 'qdrant'
+        return VectorType.QDRANT
 
     def to_index_struct(self) -> dict:
         return {
@@ -408,3 +415,40 @@ class QdrantVector(BaseVector):
             page_content=scored_point.payload.get(content_payload_key),
             metadata=scored_point.payload.get(metadata_payload_key) or {},
         )
+
+
+class QdrantVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
+        if dataset.collection_binding_id:
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                one_or_none()
+            if dataset_collection_binding:
+                collection_name = dataset_collection_binding.collection_name
+            else:
+                raise ValueError('Dataset Collection Binding does not exist!')
+        else:
+            if dataset.index_struct_dict:
+                class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+
+        if not dataset.index_struct_dict:
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
+
+        config = current_app.config
+        return QdrantVector(
+            collection_name=collection_name,
+            group_id=dataset.id,
+            config=QdrantConfig(
+                endpoint=config.get('QDRANT_URL'),
+                api_key=config.get('QDRANT_API_KEY'),
+                root_path=config.root_path,
+                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
+                grpc_port=config.get('QDRANT_GRPC_PORT'),
+                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+            )
+        )
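
One subtle change: the factory passes config.root_path instead of current_app.root_path. flask.Config is constructed with the application's root path, so both spellings resolve to the same directory under the default app setup; a quick standalone check:

```python
from flask import Flask

app = Flask(__name__)
# flask.Config keeps the root path it was created with, so the value handed to
# QdrantConfig(root_path=...) is unchanged by the refactor.
assert app.config.root_path == app.root_path
```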

+ 37 - 5
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -1,12 +1,19 @@
+import json
 import uuid
 from typing import Any, Optional
 
+from flask import current_app
 from pydantic import BaseModel, root_validator
 from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session
 
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from models.dataset import Dataset
+
 try:
     from sqlalchemy.orm import declarative_base
 except ImportError:
@@ -53,7 +60,7 @@ class RelytVector(BaseVector):
         self._group_id = group_id
 
     def get_type(self) -> str:
-        return 'relyt'
+        return VectorType.RELYT
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         index_params = {}
@@ -240,10 +247,10 @@ class RelytVector(BaseVector):
         return docs
 
     def similarity_search_with_score_by_vector(
-        self,
-        embedding: list[float],
-        k: int = 4,
-        filter: Optional[dict] = None,
+            self,
+            embedding: list[float],
+            k: int = 4,
+            filter: Optional[dict] = None,
     ) -> list[tuple[Document, float]]:
         # Add the filter if provided
         try:
@@ -298,3 +305,28 @@ class RelytVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz/relyt doesn't support bm25 search
         return []
+
+
+class RelytVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.RELYT, collection_name))
+
+        config = current_app.config
+        return RelytVector(
+            collection_name=collection_name,
+            config=RelytConfig(
+                host=config.get('RELYT_HOST'),
+                port=config.get('RELYT_PORT'),
+                user=config.get('RELYT_USER'),
+                password=config.get('RELYT_PASSWORD'),
+                database=config.get('RELYT_DATABASE'),
+            ),
+            group_id=dataset.id
+        )

+ 33 - 0
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -3,14 +3,19 @@ import logging
 from typing import Any
 
 import sqlalchemy
+from flask import current_app
 from pydantic import BaseModel, root_validator
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -39,6 +44,9 @@ class TiDBVectorConfig(BaseModel):
 
 class TiDBVector(BaseVector):
 
+    def get_type(self) -> str:
+        return VectorType.TIDB_VECTOR
+
     def _table(self, dim: int) -> Table:
         from tidb_vector.sqlalchemy import VectorType
         return Table(
@@ -214,3 +222,28 @@ class TiDBVector(BaseVector):
         with Session(self._engine) as session:
             session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
             session.commit()
+
+
+class TiDBVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
+
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
+
+        config = current_app.config
+        return TiDBVector(
+            collection_name=collection_name,
+            config=TiDBVectorConfig(
+                host=config.get('TIDB_VECTOR_HOST'),
+                port=config.get('TIDB_VECTOR_PORT'),
+                user=config.get('TIDB_VECTOR_USER'),
+                password=config.get('TIDB_VECTOR_PASSWORD'),
+                database=config.get('TIDB_VECTOR_DATABASE'),
+            ),
+        )

+ 4 - 0
api/core/rag/datasource/vdb/vector_base.py

@@ -11,6 +11,10 @@ class BaseVector(ABC):
     def __init__(self, collection_name: str):
         self._collection_name = collection_name
 
+    @abstractmethod
+    def get_type(self) -> str:
+        raise NotImplementedError
+
     @abstractmethod
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         raise NotImplementedError

+ 46 - 185
api/core/rag/datasource/vdb/vector_factory.py

@@ -1,4 +1,4 @@
-import json
+from abc import ABC, abstractmethod
 from typing import Any
 
 from flask import current_app
@@ -8,9 +8,23 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
+from models.dataset import Dataset
+
+
+class AbstractVectorFactory(ABC):
+    @abstractmethod
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
+        raise NotImplementedError
+
+    @staticmethod
+    def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
+        index_struct_dict = {
+            "type": vector_type,
+            "vector_store": {"class_prefix": collection_name}
+        }
+        return index_struct_dict
 
 
 class Vector:
@@ -32,188 +46,35 @@ class Vector:
         if not vector_type:
             raise ValueError("Vector store must be specified.")
 
-        if vector_type == "weaviate":
-            from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": 'weaviate',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return WeaviateVector(
-                collection_name=collection_name,
-                config=WeaviateConfig(
-                    endpoint=config.get('WEAVIATE_ENDPOINT'),
-                    api_key=config.get('WEAVIATE_API_KEY'),
-                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
-                ),
-                attributes=self._attributes
-            )
-        elif vector_type == "qdrant":
-            from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
-            if self._dataset.collection_binding_id:
-                dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                    filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
-                    one_or_none()
-                if dataset_collection_binding:
-                    collection_name = dataset_collection_binding.collection_name
-                else:
-                    raise ValueError('Dataset Collection Bindings is not exist!')
-            else:
-                if self._dataset.index_struct_dict:
-                    class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                    collection_name = class_prefix
-                else:
-                    dataset_id = self._dataset.id
-                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-
-            if not self._dataset.index_struct_dict:
-                index_struct_dict = {
-                    "type": 'qdrant',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-
-            return QdrantVector(
-                collection_name=collection_name,
-                group_id=self._dataset.id,
-                config=QdrantConfig(
-                    endpoint=config.get('QDRANT_URL'),
-                    api_key=config.get('QDRANT_API_KEY'),
-                    root_path=current_app.root_path,
-                    timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-                    grpc_port=config.get('QDRANT_GRPC_PORT'),
-                    prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
-                )
-            )
-        elif vector_type == "milvus":
-            from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": 'milvus',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return MilvusVector(
-                collection_name=collection_name,
-                config=MilvusConfig(
-                    host=config.get('MILVUS_HOST'),
-                    port=config.get('MILVUS_PORT'),
-                    user=config.get('MILVUS_USER'),
-                    password=config.get('MILVUS_PASSWORD'),
-                    secure=config.get('MILVUS_SECURE'),
-                    database=config.get('MILVUS_DATABASE'),
-                )
-            )
-        elif vector_type == "relyt":
-            from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": 'relyt',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return RelytVector(
-                collection_name=collection_name,
-                config=RelytConfig(
-                    host=config.get('RELYT_HOST'),
-                    port=config.get('RELYT_PORT'),
-                    user=config.get('RELYT_USER'),
-                    password=config.get('RELYT_PASSWORD'),
-                    database=config.get('RELYT_DATABASE'),
-                ),
-                group_id=self._dataset.id
-            )
-        elif vector_type == "pgvecto_rs":
-            from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix.lower()
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
-                index_struct_dict = {
-                    "type": 'pgvecto_rs',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            dim = len(self._embeddings.embed_query("pgvecto_rs"))
-            return PGVectoRS(
-                collection_name=collection_name,
-                config=PgvectoRSConfig(
-                    host=config.get('PGVECTO_RS_HOST'),
-                    port=config.get('PGVECTO_RS_PORT'),
-                    user=config.get('PGVECTO_RS_USER'),
-                    password=config.get('PGVECTO_RS_PASSWORD'),
-                    database=config.get('PGVECTO_RS_DATABASE'),
-                ),
-                dim=dim
-            )
-        elif vector_type == "pgvector":
-            from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
-
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": "pgvector",
-                    "vector_store": {"class_prefix": collection_name}}
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return PGVector(
-                collection_name=collection_name,
-                config=PGVectorConfig(
-                    host=config.get("PGVECTOR_HOST"),
-                    port=config.get("PGVECTOR_PORT"),
-                    user=config.get("PGVECTOR_USER"),
-                    password=config.get("PGVECTOR_PASSWORD"),
-                    database=config.get("PGVECTOR_DATABASE"),
-                ),
-            )
-        elif vector_type == "tidb_vector":
-            from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
-
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix.lower()
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
-                index_struct_dict = {
-                    "type": 'tidb_vector',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-
-            return TiDBVector(
-                collection_name=collection_name,
-                config=TiDBVectorConfig(
-                    host=config.get('TIDB_VECTOR_HOST'),
-                    port=config.get('TIDB_VECTOR_PORT'),
-                    user=config.get('TIDB_VECTOR_USER'),
-                    password=config.get('TIDB_VECTOR_PASSWORD'),
-                    database=config.get('TIDB_VECTOR_DATABASE'),
-                ),
-            )
-        else:
-            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+        vector_factory_cls = self.get_vector_factory(vector_type)
+        return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
+
+    @staticmethod
+    def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
+        match vector_type:
+            case VectorType.MILVUS:
+                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
+                return MilvusVectorFactory
+            case VectorType.PGVECTOR:
+                from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
+                return PGVectorFactory
+            case VectorType.PGVECTO_RS:
+                from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
+                return PGVectoRSFactory
+            case VectorType.QDRANT:
+                from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
+                return QdrantVectorFactory
+            case VectorType.RELYT:
+                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
+                return RelytVectorFactory
+            case VectorType.TIDB_VECTOR:
+                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
+                return TiDBVectorFactory
+            case VectorType.WEAVIATE:
+                from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
+                return WeaviateVectorFactory
+            case _:
+                raise ValueError(f"Vector store {vector_type} is not supported.")
 
     def create(self, texts: list = None, **kwargs):
         if texts:
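
With the per-store branches gone, the Vector facade resolves a factory class from the configured store and delegates construction to it; keeping the imports inside the match means only the selected vendor SDK gets loaded. A minimal usage sketch, assuming a Flask app context, a configured VECTOR_STORE, and hypothetical dataset/documents objects:

```python
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType

# Resolving a factory class imports only the matching vendor module.
factory_cls = Vector.get_vector_factory(VectorType.QDRANT)  # -> QdrantVectorFactory

# The facade reads VECTOR_STORE from current_app.config, instantiates the
# factory, and calls init_vector; dataset is a models.dataset.Dataset row and
# documents a list of core.rag.models.document.Document (both hypothetical).
vector = Vector(dataset)
vector.create(texts=documents)
```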

+ 11 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class VectorType(str, Enum):
+    MILVUS = 'milvus'
+    PGVECTOR = 'pgvector'
+    PGVECTO_RS = 'pgvecto-rs'
+    QDRANT = 'qdrant'
+    RELYT = 'relyt'
+    TIDB_VECTOR = 'tidb_vector'
+    WEAVIATE = 'weaviate'

+ 28 - 1
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -1,12 +1,17 @@
 import datetime
+import json
 from typing import Any, Optional
 
 import requests
 import weaviate
+from flask import current_app
 from pydantic import BaseModel, root_validator
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
@@ -59,7 +64,7 @@ class WeaviateVector(BaseVector):
         return client
 
     def get_type(self) -> str:
-        return 'weaviate'
+        return VectorType.WEAVIATE
 
     def get_collection_name(self, dataset: Dataset) -> str:
         if dataset.index_struct_dict:
@@ -255,3 +260,25 @@ class WeaviateVector(BaseVector):
         if isinstance(value, datetime.datetime):
             return value.isoformat()
         return value
+
+
+class WeaviateVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
+
+        return WeaviateVector(
+            collection_name=collection_name,
+            config=WeaviateConfig(
+                endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
+                api_key=current_app.config.get('WEAVIATE_API_KEY'),
+                batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
+            ),
+            attributes=attributes
+        )