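"""TiDB Vector implementation of the `BaseVector` vector-store interface.

Document chunks are stored in a TiDB table with columns (id, text, meta, vector),
and similarity search is performed with TiDB's vector distance functions.
"""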
import json
import logging
from typing import Any

import sqlalchemy
from pydantic import BaseModel, root_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)


class TiDBVectorConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['host']:
            raise ValueError("config TIDB_VECTOR_HOST is required")
        if not values['port']:
            raise ValueError("config TIDB_VECTOR_PORT is required")
        if not values['user']:
            raise ValueError("config TIDB_VECTOR_USER is required")
        if not values['password']:
            raise ValueError("config TIDB_VECTOR_PASSWORD is required")
        if not values['database']:
            raise ValueError("config TIDB_VECTOR_DATABASE is required")
        return values


class TiDBVector(BaseVector):

    def _table(self, dim: int) -> Table:
        from tidb_vector.sqlalchemy import VectorType
        return Table(
            self._collection_name,
            self._orm_base.metadata,
            Column('id', String(36), primary_key=True, nullable=False),
            Column(
                "vector",
                VectorType(dim),
                nullable=False,
                comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
            ),
            Column("text", TEXT, nullable=False),
            Column("meta", JSON, nullable=False),
            Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
            Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
            extend_existing=True
        )

    def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
        super().__init__(collection_name)
        self._client_config = config
        self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
                     f"ssl_verify_cert=true&ssl_verify_identity=true")
        self._distance_func = distance_func.lower()
        self._engine = create_engine(self._url)
        self._orm_base = declarative_base()
        self._dimension = 1536

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        logger.info("create collection and add texts, collection_name: " + self._collection_name)
        self._dimension = len(embeddings[0])
        self._create_collection(self._dimension)
        self.add_texts(texts, embeddings)

    def _create_collection(self, dimension: int):
        logger.info("_create_collection, collection_name " + self._collection_name)
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            with Session(self._engine) as session:
                session.begin()
                drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")
                session.execute(drop_statement)
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
                        id CHAR(36) PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSON NOT NULL,
                        vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
                        create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    );
                """)
                session.execute(create_statement)
                # TiDB Vector does not support 'CREATE/ADD INDEX' yet; the HNSW index
                # is declared via the column COMMENT above instead.
                session.commit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        table = self._table(len(embeddings[0]))
        ids = self._get_uuids(documents)
        metas = [d.metadata for d in documents]
        texts = [d.page_content for d in documents]

        chunks_table_data = []
        with self._engine.connect() as conn:
            with conn.begin():
                for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
                    chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})

                    # Execute the batch insert when the batch size is reached
                    if len(chunks_table_data) == 500:
                        conn.execute(insert(table).values(chunks_table_data))
                        # Clear the chunks_table_data list for the next batch
                        chunks_table_data.clear()

                # Insert any remaining records that didn't make up a full batch
                if chunks_table_data:
                    conn.execute(insert(table).values(chunks_table_data))
        return ids

    def text_exists(self, id: str) -> bool:
        result = self.get_ids_by_metadata_field('doc_id', id)
        return len(result) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        with Session(self._engine) as session:
            ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
            select_statement = sql_text(
                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str});"""
            )
            result = session.execute(select_statement).fetchall()
        if result:
            ids = [item[0] for item in result]
            self._delete_by_ids(ids)

    def _delete_by_ids(self, ids: list[str]) -> bool:
        if ids is None:
            raise ValueError("No ids provided to delete.")
        table = self._table(self._dimension)
        try:
            with self._engine.connect() as conn:
                with conn.begin():
                    delete_condition = table.c.id.in_(ids)
                    conn.execute(table.delete().where(delete_condition))
                    return True
        except Exception as e:
            logger.exception("Delete operation failed: %s", e)
            return False

    def delete_by_document_id(self, document_id: str):
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._delete_by_ids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        with Session(self._engine) as session:
            select_statement = sql_text(
                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = '{value}';"""
            )
            result = session.execute(select_statement).fetchall()
        # Return an empty list (rather than None) when nothing matches, so callers
        # such as text_exists can safely take len() of the result.
        return [item[0] for item in result] if result else []

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 5)
        score_threshold = kwargs.get("score_threshold") or 0.0
        filter = kwargs.get('filter')
        # The similarity score threshold is converted into a distance cutoff
        # (for cosine, distance = 1 - similarity).
        distance = 1 - score_threshold

        query_vector_str = ", ".join(format(x) for x in query_vector)
        query_vector_str = "[" + query_vector_str + "]"
        logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")

        docs = []
        if self._distance_func == 'l2':
            tidb_func = 'Vec_l2_distance'
        elif self._distance_func == 'cosine':
            tidb_func = 'Vec_Cosine_distance'
        else:
            tidb_func = 'Vec_Cosine_distance'

        with Session(self._engine) as session:
            select_statement = sql_text(
                f"""SELECT meta, text FROM (
                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
                        FROM {self._collection_name}
                        ORDER BY distance
                        LIMIT {top_k}
                    ) t WHERE distance < {distance};"""
            )
            res = session.execute(select_statement)
            results = [(row[0], row[1]) for row in res]
            for meta, text in results:
                docs.append(Document(page_content=text, metadata=json.loads(meta)))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # TiDB does not support BM25 full-text search, so this is a no-op.
        return []

    def delete(self) -> None:
        with Session(self._engine) as session:
            session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
            session.commit()
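

# --- Usage sketch (illustrative only) ---
# A minimal, hypothetical example of wiring this backend up by hand. The host,
# credentials, collection name, and toy 4-dimensional embeddings below are
# placeholders rather than values used by the application; running it requires
# a reachable TiDB instance with vector support and a configured redis_client.
if __name__ == "__main__":
    example_config = TiDBVectorConfig(
        host="127.0.0.1",      # placeholder TiDB host
        port=4000,             # placeholder TiDB SQL port
        user="root",           # placeholder credentials
        password="example",    # placeholder credentials
        database="test",       # placeholder database name
    )
    store = TiDBVector("example_collection", example_config, distance_func="cosine")

    documents = [
        Document(page_content="hello tidb", metadata={"doc_id": "doc-1", "document_id": "d-1"}),
        Document(page_content="hello vector", metadata={"doc_id": "doc-2", "document_id": "d-1"}),
    ]
    embeddings = [[0.1, 0.2, 0.3, 0.4], [0.2, 0.1, 0.4, 0.3]]  # toy embeddings

    store.create(documents, embeddings)
    for doc in store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=2):
        print(doc.page_content, doc.metadata)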