Browse Source

fix: using api can not execute relyt vector database (#3766)

Co-authored-by: jingsi <jingsi@leadincloud.com>
Jingpan Xiong 11 months ago
parent
commit
1be222af2e

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
@@ -498,7 +498,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
 
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'

+ 175 - 44
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -1,16 +1,23 @@
-import logging
-from typing import Any
+import uuid
+from typing import Any, Optional
 
-from pgvecto_rs.sdk import PGVectoRs, Record
 from pydantic import BaseModel, root_validator
+from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
+from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session
 
+try:
+    from sqlalchemy.orm import declarative_base
+except ImportError:
+    from sqlalchemy.ext.declarative import declarative_base
+
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 
-logger = logging.getLogger(__name__)
+Base = declarative_base()  # type: Any
+
 
 class RelytConfig(BaseModel):
     host: str
@@ -36,16 +43,14 @@ class RelytConfig(BaseModel):
 
 class RelytVector(BaseVector):
 
-    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
+    def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
         super().__init__(collection_name)
+        self.embedding_dimension = 1536
         self._client_config = config
         self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
-        self._client = PGVectoRs(
-            db_url=self._url,
-            collection_name=self._collection_name,
-            dimension=dim
-        )
+        self.client = create_engine(self._url)
         self._fields = []
+        self._group_id = group_id
 
     def get_type(self) -> str:
         return 'relyt'
@@ -54,6 +59,7 @@ class RelytVector(BaseVector):
         index_params = {}
         metadatas = [d.metadata for d in texts]
         self.create_collection(len(embeddings[0]))
+        self.embedding_dimension = len(embeddings[0])
         self.add_texts(texts, embeddings)
 
     def create_collection(self, dimension: int):
@@ -63,21 +69,21 @@ class RelytVector(BaseVector):
             if redis_client.get(collection_exist_cache_key):
                 return
             index_name = f"{self._collection_name}_embedding_index"
-            with Session(self._client._engine) as session:
-                drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
+            with Session(self.client) as session:
+                drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
                 session.execute(drop_statement)
                 create_statement = sql_text(f"""
-                    CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
-                        id UUID PRIMARY KEY,
-                        text TEXT NOT NULL,
-                        meta JSONB NOT NULL,
+                    CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
+                        id TEXT PRIMARY KEY,
+                        document TEXT NOT NULL,
+                        metadata JSON NOT NULL,
                         embedding vector({dimension}) NOT NULL
                     ) using heap; 
                 """)
                 session.execute(create_statement)
                 index_statement = sql_text(f"""
                         CREATE INDEX {index_name}
-                        ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
+                        ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
                         WITH (options = $$
                                 optimizing.optimizing_threads = 30
                                 segment.max_growing_segment_size = 2000
@@ -92,21 +98,62 @@ class RelytVector(BaseVector):
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
-        pks = [str(r.id) for r in records]
-        self._client.insert(records)
-        return pks
+        from pgvecto_rs.sqlalchemy import Vector
+
+        ids = [str(uuid.uuid1()) for _ in documents]
+        metadatas = [d.metadata for d in documents]
+        for metadata in metadatas:
+            metadata['group_id'] = self._group_id
+        texts = [d.page_content for d in documents]
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(len(embeddings[0]))),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        chunks_table_data = []
+        with self.client.connect() as conn:
+            with conn.begin():
+                for document, metadata, chunk_id, embedding in zip(
+                        texts, metadatas, ids, embeddings
+                ):
+                    chunks_table_data.append(
+                        {
+                            "id": chunk_id,
+                            "embedding": embedding,
+                            "document": document,
+                            "metadata": metadata,
+                        }
+                    )
+
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == 500:
+                        conn.execute(insert(chunks_table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+
+        return ids
 
     def delete_by_document_id(self, document_id: str):
         ids = self.get_ids_by_metadata_field('document_id', document_id)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def get_ids_by_metadata_field(self, key: str, value: str):
         result = None
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
             )
             result = session.execute(select_statement).fetchall()
         if result:
@@ -114,56 +161,140 @@ class RelytVector(BaseVector):
         else:
             return None
 
+    def delete_by_uuids(self, ids: list[str] = None):
+        """Delete by vector IDs.
+
+        Args:
+            ids: List of ids to delete.
+        """
+        from pgvecto_rs.sqlalchemy import Vector
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(self.embedding_dimension)),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        try:
+            with self.client.connect() as conn:
+                with conn.begin():
+                    delete_condition = chunks_table.c.id.in_(ids)
+                    conn.execute(chunks_table.delete().where(delete_condition))
+                    return True
+        except Exception as e:
+            print("Delete operation failed:", str(e))  # noqa: T201
+            return False
+
     def delete_by_metadata_field(self, key: str, value: str):
 
         ids = self.get_ids_by_metadata_field(key, value)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def delete_by_ids(self, doc_ids: list[str]) -> None:
-        with Session(self._client._engine) as session:
+
+        with Session(self.client) as session:
+            ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids)
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
             )
             result = session.execute(select_statement).fetchall()
         if result:
             ids = [item[0] for item in result]
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def delete(self) -> None:
-        with Session(self._client._engine) as session:
-            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
+        with Session(self.client) as session:
+            session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
             session.commit()
 
     def text_exists(self, id: str) -> bool:
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
             )
             result = session.execute(select_statement).fetchall()
         return len(result) > 0
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        from pgvecto_rs.sdk import filters
-        filter_condition = filters.meta_contains(kwargs.get('filter'))
-        results = self._client.search(
-            top_k=int(kwargs.get('top_k')),
+        results = self.similarity_search_with_score_by_vector(
+            k=int(kwargs.get('top_k')),
             embedding=query_vector,
-            filter=filter_condition
+            filter=kwargs.get('filter')
         )
 
         # Organize results.
         docs = []
-        for record, dis in results:
-            metadata = record.meta
-            metadata['score'] = dis
+        for document, score in results:
             score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
-            if dis > score_threshold:
-                doc = Document(page_content=record.text,
-                               metadata=metadata)
-                docs.append(doc)
+            if score > score_threshold:
+                docs.append(document)
         return docs
 
+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: list[float],
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> list[tuple[Document, float]]:
+        # Add the filter if provided
+        try:
+            from sqlalchemy.engine import Row
+        except ImportError:
+            raise ImportError(
+                "Could not import Row from sqlalchemy.engine. "
+                "Please 'pip install sqlalchemy>=1.4'."
+            )
+
+        filter_condition = ""
+        if filter is not None:
+            conditions = [
+                f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
+                else f"metadata->>{key!r} = {value[0]!r}"
+                for key, value in filter.items()
+            ]
+            filter_condition = f"WHERE {' AND '.join(conditions)}"
+
+        # Define the base query
+        sql_query = f"""
+            set vectors.enable_search_growing = on;
+            set vectors.enable_search_write = on;
+            SELECT document, metadata, embedding <-> :embedding as distance
+            FROM "{self._collection_name}"
+            {filter_condition}
+            ORDER BY embedding <-> :embedding
+            LIMIT :k
+        """
+
+        # Set up the query parameters
+        embedding_str = ", ".join(format(x) for x in embedding)
+        embedding_str = "[" + embedding_str + "]"
+        params = {"embedding": embedding_str, "k": k}
+
+        # Execute the query and fetch the results
+        with self.client.connect() as conn:
+            results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
+
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=result.document,
+                    metadata=result.metadata,
+                ),
+                result.distance,
+            )
+            for result in results
+        ]
+        return documents_with_scores
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz/relyt doesn't support bm25 search
         return []

+ 1 - 2
api/core/rag/datasource/vdb/vector_factory.py

@@ -126,7 +126,6 @@ class Vector:
                     "vector_store": {"class_prefix": collection_name}
                 }
                 self._dataset.index_struct = json.dumps(index_struct_dict)
-            dim = len(self._embeddings.embed_query("hello relyt"))
             return RelytVector(
                 collection_name=collection_name,
                 config=RelytConfig(
@@ -136,7 +135,7 @@ class Vector:
                     password=config.get('RELYT_PASSWORD'),
                     database=config.get('RELYT_DATABASE'),
                 ),
-                dim=dim
+                group_id=self._dataset.id
             )
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

+ 8 - 2
docker/docker-compose.yaml

@@ -86,7 +86,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -109,6 +109,12 @@ services:
       MILVUS_PASSWORD: Milvus
       # The milvus tls switch.
       MILVUS_SECURE: 'false'
+      # relyt configurations
+      RELYT_HOST: db
+      RELYT_PORT: 5432
+      RELYT_USER: postgres
+      RELYT_PASSWORD: difyai123456
+      RELYT_DATABASE: postgres
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -193,7 +199,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080