|
@@ -1,4 +1,4 @@
|
|
|
-import json
|
|
|
+from abc import ABC, abstractmethod
|
|
|
from typing import Any
|
|
|
|
|
|
from flask import current_app
|
|
@@ -8,9 +8,23 @@ from core.model_manager import ModelManager
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
from core.rag.datasource.entity.embedding import Embeddings
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
+from core.rag.datasource.vdb.vector_type import VectorType
|
|
|
from core.rag.models.document import Document
|
|
|
-from extensions.ext_database import db
|
|
|
-from models.dataset import Dataset, DatasetCollectionBinding
|
|
|
+from models.dataset import Dataset
|
|
|
+
|
|
|
+
|
|
|
+class AbstractVectorFactory(ABC):
|
|
|
+ @abstractmethod
|
|
|
+ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
|
|
|
+ index_struct_dict = {
|
|
|
+ "type": vector_type,
|
|
|
+ "vector_store": {"class_prefix": collection_name}
|
|
|
+ }
|
|
|
+ return index_struct_dict
|
|
|
|
|
|
|
|
|
class Vector:
|
|
@@ -32,188 +46,35 @@ class Vector:
|
|
|
if not vector_type:
|
|
|
raise ValueError("Vector store must be specified.")
|
|
|
|
|
|
- if vector_type == "weaviate":
|
|
|
- from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
- collection_name = class_prefix
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
- index_struct_dict = {
|
|
|
- "type": 'weaviate',
|
|
|
- "vector_store": {"class_prefix": collection_name}
|
|
|
- }
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
- return WeaviateVector(
|
|
|
- collection_name=collection_name,
|
|
|
- config=WeaviateConfig(
|
|
|
- endpoint=config.get('WEAVIATE_ENDPOINT'),
|
|
|
- api_key=config.get('WEAVIATE_API_KEY'),
|
|
|
- batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
|
|
- ),
|
|
|
- attributes=self._attributes
|
|
|
- )
|
|
|
- elif vector_type == "qdrant":
|
|
|
- from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
|
|
|
- if self._dataset.collection_binding_id:
|
|
|
- dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
|
|
- filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
|
|
|
- one_or_none()
|
|
|
- if dataset_collection_binding:
|
|
|
- collection_name = dataset_collection_binding.collection_name
|
|
|
- else:
|
|
|
- raise ValueError('Dataset Collection Bindings is not exist!')
|
|
|
- else:
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
- collection_name = class_prefix
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
-
|
|
|
- if not self._dataset.index_struct_dict:
|
|
|
- index_struct_dict = {
|
|
|
- "type": 'qdrant',
|
|
|
- "vector_store": {"class_prefix": collection_name}
|
|
|
- }
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
-
|
|
|
- return QdrantVector(
|
|
|
- collection_name=collection_name,
|
|
|
- group_id=self._dataset.id,
|
|
|
- config=QdrantConfig(
|
|
|
- endpoint=config.get('QDRANT_URL'),
|
|
|
- api_key=config.get('QDRANT_API_KEY'),
|
|
|
- root_path=current_app.root_path,
|
|
|
- timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
|
|
|
- grpc_port=config.get('QDRANT_GRPC_PORT'),
|
|
|
- prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
|
|
|
- )
|
|
|
- )
|
|
|
- elif vector_type == "milvus":
|
|
|
- from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
- collection_name = class_prefix
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
- index_struct_dict = {
|
|
|
- "type": 'milvus',
|
|
|
- "vector_store": {"class_prefix": collection_name}
|
|
|
- }
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
- return MilvusVector(
|
|
|
- collection_name=collection_name,
|
|
|
- config=MilvusConfig(
|
|
|
- host=config.get('MILVUS_HOST'),
|
|
|
- port=config.get('MILVUS_PORT'),
|
|
|
- user=config.get('MILVUS_USER'),
|
|
|
- password=config.get('MILVUS_PASSWORD'),
|
|
|
- secure=config.get('MILVUS_SECURE'),
|
|
|
- database=config.get('MILVUS_DATABASE'),
|
|
|
- )
|
|
|
- )
|
|
|
- elif vector_type == "relyt":
|
|
|
- from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
- collection_name = class_prefix
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
- index_struct_dict = {
|
|
|
- "type": 'relyt',
|
|
|
- "vector_store": {"class_prefix": collection_name}
|
|
|
- }
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
- return RelytVector(
|
|
|
- collection_name=collection_name,
|
|
|
- config=RelytConfig(
|
|
|
- host=config.get('RELYT_HOST'),
|
|
|
- port=config.get('RELYT_PORT'),
|
|
|
- user=config.get('RELYT_USER'),
|
|
|
- password=config.get('RELYT_PASSWORD'),
|
|
|
- database=config.get('RELYT_DATABASE'),
|
|
|
- ),
|
|
|
- group_id=self._dataset.id
|
|
|
- )
|
|
|
- elif vector_type == "pgvecto_rs":
|
|
|
- from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
- collection_name = class_prefix.lower()
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
|
|
- index_struct_dict = {
|
|
|
- "type": 'pgvecto_rs',
|
|
|
- "vector_store": {"class_prefix": collection_name}
|
|
|
- }
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
- dim = len(self._embeddings.embed_query("pgvecto_rs"))
|
|
|
- return PGVectoRS(
|
|
|
- collection_name=collection_name,
|
|
|
- config=PgvectoRSConfig(
|
|
|
- host=config.get('PGVECTO_RS_HOST'),
|
|
|
- port=config.get('PGVECTO_RS_PORT'),
|
|
|
- user=config.get('PGVECTO_RS_USER'),
|
|
|
- password=config.get('PGVECTO_RS_PASSWORD'),
|
|
|
- database=config.get('PGVECTO_RS_DATABASE'),
|
|
|
- ),
|
|
|
- dim=dim
|
|
|
- )
|
|
|
- elif vector_type == "pgvector":
|
|
|
- from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
|
|
|
-
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
|
|
|
- collection_name = class_prefix
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
|
- index_struct_dict = {
|
|
|
- "type": "pgvector",
|
|
|
- "vector_store": {"class_prefix": collection_name}}
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
- return PGVector(
|
|
|
- collection_name=collection_name,
|
|
|
- config=PGVectorConfig(
|
|
|
- host=config.get("PGVECTOR_HOST"),
|
|
|
- port=config.get("PGVECTOR_PORT"),
|
|
|
- user=config.get("PGVECTOR_USER"),
|
|
|
- password=config.get("PGVECTOR_PASSWORD"),
|
|
|
- database=config.get("PGVECTOR_DATABASE"),
|
|
|
- ),
|
|
|
- )
|
|
|
- elif vector_type == "tidb_vector":
|
|
|
- from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
|
|
|
-
|
|
|
- if self._dataset.index_struct_dict:
|
|
|
- class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
|
|
|
- collection_name = class_prefix.lower()
|
|
|
- else:
|
|
|
- dataset_id = self._dataset.id
|
|
|
- collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
|
|
- index_struct_dict = {
|
|
|
- "type": 'tidb_vector',
|
|
|
- "vector_store": {"class_prefix": collection_name}
|
|
|
- }
|
|
|
- self._dataset.index_struct = json.dumps(index_struct_dict)
|
|
|
-
|
|
|
- return TiDBVector(
|
|
|
- collection_name=collection_name,
|
|
|
- config=TiDBVectorConfig(
|
|
|
- host=config.get('TIDB_VECTOR_HOST'),
|
|
|
- port=config.get('TIDB_VECTOR_PORT'),
|
|
|
- user=config.get('TIDB_VECTOR_USER'),
|
|
|
- password=config.get('TIDB_VECTOR_PASSWORD'),
|
|
|
- database=config.get('TIDB_VECTOR_DATABASE'),
|
|
|
- ),
|
|
|
- )
|
|
|
- else:
|
|
|
- raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
|
|
+ vector_factory_cls = self.get_vector_factory(vector_type)
|
|
|
+ return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
|
|
+ match vector_type:
|
|
|
+ case VectorType.MILVUS:
|
|
|
+ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
|
|
+ return MilvusVectorFactory
|
|
|
+ case VectorType.PGVECTOR:
|
|
|
+ from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
|
|
|
+ return PGVectorFactory
|
|
|
+ case VectorType.PGVECTO_RS:
|
|
|
+ from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
|
|
|
+ return PGVectoRSFactory
|
|
|
+ case VectorType.QDRANT:
|
|
|
+ from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
|
|
|
+ return QdrantVectorFactory
|
|
|
+ case VectorType.RELYT:
|
|
|
+ from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
|
|
|
+ return RelytVectorFactory
|
|
|
+ case VectorType.TIDB_VECTOR:
|
|
|
+ from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
|
|
+ return TiDBVectorFactory
|
|
|
+ case VectorType.WEAVIATE:
|
|
|
+ from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
|
|
|
+ return WeaviateVectorFactory
|
|
|
+ case _:
|
|
|
+ raise ValueError(f"Vector store {vector_type} is not supported.")
|
|
|
|
|
|
def create(self, texts: list = None, **kwargs):
|
|
|
if texts:
|