Browse Source

Feat/add milvus vector db (#1302)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 year ago
parent
commit
07aab5e868

+ 7 - 0
api/.env.example

@@ -63,6 +63,13 @@ WEAVIATE_BATCH_SIZE=100
 QDRANT_URL=http://localhost:6333
 QDRANT_API_KEY=difyai123456
 
+# Milvus configuration
+MILVUS_HOST=127.0.0.1
+MILVUS_PORT=19530
+MILVUS_USER=root
+MILVUS_PASSWORD=Milvus
+MILVUS_SECURE=false
+
 # Mail configuration, support: resend
 MAIL_TYPE=
 MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>

+ 8 - 0
api/config.py

@@ -135,6 +135,14 @@ class Config:
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
 
+        # milvus setting
+        self.MILVUS_HOST = get_env('MILVUS_HOST')
+        self.MILVUS_PORT = get_env('MILVUS_PORT')
+        self.MILVUS_USER = get_env('MILVUS_USER')
+        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
+        self.MILVUS_SECURE = get_env('MILVUS_SECURE')
+
+
         # cors settings
         self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
             'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)

+ 860 - 0
api/core/index/vector_index/milvus.py

@@ -0,0 +1,860 @@
+"""Wrapper around the Milvus vector database."""
+from __future__ import annotations
+
+import logging
+from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
+from uuid import uuid4
+
+import numpy as np
+
+from langchain.docstore.document import Document
+from langchain.embeddings.base import Embeddings
+from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.utils import maximal_marginal_relevance
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_MILVUS_CONNECTION = {
+    "host": "localhost",
+    "port": "19530",
+    "user": "",
+    "password": "",
+    "secure": False,
+}
+
+
+class Milvus(VectorStore):
+    """Initialize wrapper around the milvus vector database.
+
+    In order to use this you need to have `pymilvus` installed and a
+    running Milvus
+
+    See the following documentation for how to run a Milvus instance:
+    https://milvus.io/docs/install_standalone-docker.md
+
+    If looking for a hosted Milvus, take a look at this documentation:
+    https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
+    this project,
+
+    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
+
+    Args:
+        embedding_function (Embeddings): Function used to embed the text.
+        collection_name (str): Which Milvus collection to use. Defaults to
+            "LangChainCollection".
+        connection_args (Optional[dict[str, any]]): The connection args used for
+            this class comes in the form of a dict.
+        consistency_level (str): The consistency level to use for a collection.
+            Defaults to "Session".
+        index_params (Optional[dict]): Which index params to use. Defaults to
+            HNSW/AUTOINDEX depending on service.
+        search_params (Optional[dict]): Which search params to use. Defaults to
+            default of index.
+        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
+            to False.
+
+    The connection args used for this class comes in the form of a dict,
+    here are a few of the options:
+        address (str): The actual address of Milvus
+            instance. Example address: "localhost:19530"
+        uri (str): The uri of Milvus instance. Example uri:
+            "http://randomwebsite:19530",
+            "tcp:foobarsite:19530",
+            "https://ok.s3.south.com:19530".
+        host (str): The host of Milvus instance. Default at "localhost",
+            PyMilvus will fill in the default host if only port is provided.
+        port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
+            will fill in the default port if only host is provided.
+        user (str): Use which user to connect to Milvus instance. If user and
+            password are provided, we will add related header in every RPC call.
+        password (str): Required when user is provided. The password
+            corresponding to the user.
+        secure (bool): Default is false. If set to true, tls will be enabled.
+        client_key_path (str): If use tls two-way authentication, need to
+            write the client.key path.
+        client_pem_path (str): If use tls two-way authentication, need to
+            write the client.pem path.
+        ca_pem_path (str): If use tls two-way authentication, need to write
+            the ca.pem path.
+        server_pem_path (str): If use tls one-way authentication, need to
+            write the server.pem path.
+        server_name (str): If use tls, need to write the common name.
+
+    Example:
+        .. code-block:: python
+
+        from langchain import Milvus
+        from langchain.embeddings import OpenAIEmbeddings
+
+        embedding = OpenAIEmbeddings()
+        # Connect to a milvus instance on localhost
+        milvus_store = Milvus(
+            embedding_function = Embeddings,
+            collection_name = "LangChainCollection",
+            drop_old = True,
+        )
+
+    Raises:
+        ValueError: If the pymilvus python package is not installed.
+    """
+
+    def __init__(
+        self,
+        embedding_function: Embeddings,
+        collection_name: str = "LangChainCollection",
+        connection_args: Optional[dict[str, Any]] = None,
+        consistency_level: str = "Session",
+        index_params: Optional[dict] = None,
+        search_params: Optional[dict] = None,
+        drop_old: Optional[bool] = False,
+    ):
+        """Initialize the Milvus vector store."""
+        try:
+            from pymilvus import Collection, utility
+        except ImportError:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            )
+
+        # Default search params when one is not provided.
+        self.default_search_params = {
+            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
+            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
+            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
+            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
+            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
+            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
+            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
+            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
+            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
+            "AUTOINDEX": {"metric_type": "L2", "params": {}},
+        }
+
+        self.embedding_func = embedding_function
+        self.collection_name = collection_name
+        self.index_params = index_params
+        self.search_params = search_params
+        self.consistency_level = consistency_level
+
+        # In order for a collection to be compatible, pk needs to be auto'id and int
+        self._primary_field = "id"
+        # In order for compatibility, the text field will need to be called "text"
+        self._text_field = "page_content"
+        # In order for compatibility, the vector field needs to be called "vector"
+        self._vector_field = "vectors"
+        # In order for compatibility, the metadata field will need to be called "metadata"
+        self._metadata_field = "metadata"
+        self.fields: list[str] = []
+        # Create the connection to the server
+        if connection_args is None:
+            connection_args = DEFAULT_MILVUS_CONNECTION
+        self.alias = self._create_connection_alias(connection_args)
+        self.col: Optional[Collection] = None
+
+        # Grab the existing collection if it exists
+        if utility.has_collection(self.collection_name, using=self.alias):
+            self.col = Collection(
+                self.collection_name,
+                using=self.alias,
+            )
+        # If need to drop old, drop it
+        if drop_old and isinstance(self.col, Collection):
+            self.col.drop()
+            self.col = None
+
+        # Initialize the vector store
+        self._init()
+
+    @property
+
+
+    def embeddings(self) -> Embeddings:
+        return self.embedding_func
+
+    def _create_connection_alias(self, connection_args: dict) -> str:
+        """Create the connection to the Milvus server."""
+        from pymilvus import MilvusException, connections
+
+        # Grab the connection arguments that are used for checking existing connection
+        host: str = connection_args.get("host", None)
+        port: Union[str, int] = connection_args.get("port", None)
+        address: str = connection_args.get("address", None)
+        uri: str = connection_args.get("uri", None)
+        user = connection_args.get("user", None)
+
+        # Order of use is host/port, uri, address
+        if host is not None and port is not None:
+            given_address = str(host) + ":" + str(port)
+        elif uri is not None:
+            given_address = uri.split("https://")[1]
+        elif address is not None:
+            given_address = address
+        else:
+            given_address = None
+            logger.debug("Missing standard address type for reuse atttempt")
+
+        # User defaults to empty string when getting connection info
+        if user is not None:
+            tmp_user = user
+        else:
+            tmp_user = ""
+
+        # If a valid address was given, then check if a connection exists
+        if given_address is not None:
+            for con in connections.list_connections():
+                addr = connections.get_connection_addr(con[0])
+                if (
+                    con[1]
+                    and ("address" in addr)
+                    and (addr["address"] == given_address)
+                    and ("user" in addr)
+                    and (addr["user"] == tmp_user)
+                ):
+                    logger.debug("Using previous connection: %s", con[0])
+                    return con[0]
+
+        # Generate a new connection if one doesn't exist
+        alias = uuid4().hex
+        try:
+            connections.connect(alias=alias, **connection_args)
+            logger.debug("Created new connection using: %s", alias)
+            return alias
+        except MilvusException as e:
+            logger.error("Failed to create new connection using: %s", alias)
+            raise e
+
+    def _init(
+        self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
+    ) -> None:
+        if embeddings is not None:
+            self._create_collection(embeddings, metadatas)
+        self._extract_fields()
+        self._create_index()
+        self._create_search_params()
+        self._load()
+
+    def _create_collection(
+        self, embeddings: list, metadatas: Optional[list[dict]] = None
+    ) -> None:
+        from pymilvus import (
+            Collection,
+            CollectionSchema,
+            DataType,
+            FieldSchema,
+            MilvusException,
+        )
+        from pymilvus.orm.types import infer_dtype_bydata
+
+        # Determine embedding dim
+        dim = len(embeddings[0])
+        fields = []
+        # Determine metadata schema
+        # if metadatas:
+        #     # Create FieldSchema for each entry in metadata.
+        #     for key, value in metadatas[0].items():
+        #         # Infer the corresponding datatype of the metadata
+        #         dtype = infer_dtype_bydata(value)
+        #         # Datatype isn't compatible
+        #         if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
+        #             logger.error(
+        #                 "Failure to create collection, unrecognized dtype for key: %s",
+        #                 key,
+        #             )
+        #             raise ValueError(f"Unrecognized datatype for {key}.")
+        #         # Dataype is a string/varchar equivalent
+        #         elif dtype == DataType.VARCHAR:
+        #             fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
+        #         else:
+        #             fields.append(FieldSchema(key, dtype))
+        if metadatas:
+            fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
+
+        # Create the text field
+        fields.append(
+            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
+        )
+        # Create the primary key field
+        fields.append(
+            FieldSchema(
+                self._primary_field, DataType.INT64, is_primary=True, auto_id=True
+            )
+        )
+        # Create the vector field, supports binary or float vectors
+        fields.append(
+            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
+        )
+
+        # Create the schema for the collection
+        schema = CollectionSchema(fields)
+
+        # Create the collection
+        try:
+            self.col = Collection(
+                name=self.collection_name,
+                schema=schema,
+                consistency_level=self.consistency_level,
+                using=self.alias,
+            )
+        except MilvusException as e:
+            logger.error(
+                "Failed to create collection: %s error: %s", self.collection_name, e
+            )
+            raise e
+
+    def _extract_fields(self) -> None:
+        """Grab the existing fields from the Collection"""
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection):
+            schema = self.col.schema
+            for x in schema.fields:
+                self.fields.append(x.name)
+            # Since primary field is auto-id, no need to track it
+            self.fields.remove(self._primary_field)
+
+    def _get_index(self) -> Optional[dict[str, Any]]:
+        """Return the vector index information if it exists"""
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection):
+            for x in self.col.indexes:
+                if x.field_name == self._vector_field:
+                    return x.to_dict()
+        return None
+
+    def _create_index(self) -> None:
+        """Create a index on the collection"""
+        from pymilvus import Collection, MilvusException
+
+        if isinstance(self.col, Collection) and self._get_index() is None:
+            try:
+                # If no index params, use a default HNSW based one
+                if self.index_params is None:
+                    self.index_params = {
+                        "metric_type": "IP",
+                        "index_type": "HNSW",
+                        "params": {"M": 8, "efConstruction": 64},
+                    }
+
+                try:
+                    self.col.create_index(
+                        self._vector_field,
+                        index_params=self.index_params,
+                        using=self.alias,
+                    )
+
+                # If default did not work, most likely on Zilliz Cloud
+                except MilvusException:
+                    # Use AUTOINDEX based index
+                    self.index_params = {
+                        "metric_type": "L2",
+                        "index_type": "AUTOINDEX",
+                        "params": {},
+                    }
+                    self.col.create_index(
+                        self._vector_field,
+                        index_params=self.index_params,
+                        using=self.alias,
+                    )
+                logger.debug(
+                    "Successfully created an index on collection: %s",
+                    self.collection_name,
+                )
+
+            except MilvusException as e:
+                logger.error(
+                    "Failed to create an index on collection: %s", self.collection_name
+                )
+                raise e
+
+    def _create_search_params(self) -> None:
+        """Generate search params based on the current index type"""
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection) and self.search_params is None:
+            index = self._get_index()
+            if index is not None:
+                index_type: str = index["index_param"]["index_type"]
+                metric_type: str = index["index_param"]["metric_type"]
+                self.search_params = self.default_search_params[index_type]
+                self.search_params["metric_type"] = metric_type
+
+    def _load(self) -> None:
+        """Load the collection if available."""
+        from pymilvus import Collection
+
+        if isinstance(self.col, Collection) and self._get_index() is not None:
+            self.col.load()
+
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        timeout: Optional[int] = None,
+        batch_size: int = 1000,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Insert text data into Milvus.
+
+        Inserting data when the collection has not be made yet will result
+        in creating a new Collection. The data of the first entity decides
+        the schema of the new collection, the dim is extracted from the first
+        embedding and the columns are decided by the first metadata dict.
+        Metada keys will need to be present for all inserted values. At
+        the moment there is no None equivalent in Milvus.
+
+        Args:
+            texts (Iterable[str]): The texts to embed, it is assumed
+                that they all fit in memory.
+            metadatas (Optional[List[dict]]): Metadata dicts attached to each of
+                the texts. Defaults to None.
+            timeout (Optional[int]): Timeout for each batch insert. Defaults
+                to None.
+            batch_size (int, optional): Batch size to use for insertion.
+                Defaults to 1000.
+
+        Raises:
+            MilvusException: Failure to add texts
+
+        Returns:
+            List[str]: The resulting keys for each inserted element.
+        """
+        from pymilvus import Collection, MilvusException
+
+        texts = list(texts)
+
+        try:
+            embeddings = self.embedding_func.embed_documents(texts)
+        except NotImplementedError:
+            embeddings = [self.embedding_func.embed_query(x) for x in texts]
+
+        if len(embeddings) == 0:
+            logger.debug("Nothing to insert, skipping.")
+            return []
+
+        # If the collection hasn't been initialized yet, perform all steps to do so
+        if not isinstance(self.col, Collection):
+            self._init(embeddings, metadatas)
+
+        # Dict to hold all insert columns
+        insert_dict: dict[str, list] = {
+            self._text_field: texts,
+            self._vector_field: embeddings,
+        }
+
+        # Collect the metadata into the insert dict.
+        # if metadatas is not None:
+        #     for d in metadatas:
+        #         for key, value in d.items():
+        #             if key in self.fields:
+        #                 insert_dict.setdefault(key, []).append(value)
+        if metadatas is not None:
+            for d in metadatas:
+                insert_dict.setdefault(self._metadata_field, []).append(d)
+
+        # Total insert count
+        vectors: list = insert_dict[self._vector_field]
+        total_count = len(vectors)
+
+        pks: list[str] = []
+
+        assert isinstance(self.col, Collection)
+        for i in range(0, total_count, batch_size):
+            # Grab end index
+            end = min(i + batch_size, total_count)
+            # Convert dict to list of lists batch for insertion
+            insert_list = [insert_dict[x][i:end] for x in self.fields]
+            # Insert into the collection.
+            try:
+                res: Collection
+                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
+                pks.extend(res.primary_keys)
+            except MilvusException as e:
+                logger.error(
+                    "Failed to insert batch starting at entity: %s/%s", i, total_count
+                )
+                raise e
+        return pks
+
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a similarity search against the query string.
+
+        Args:
+            query (str): The text to search.
+            k (int, optional): How many results to return. Defaults to 4.
+            param (dict, optional): The search params for the index type.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+        res = self.similarity_search_with_score(
+            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+        )
+        return [doc for doc, _ in res]
+
+    def similarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a similarity search against the query string.
+
+        Args:
+            embedding (List[float]): The embedding vector to search.
+            k (int, optional): How many results to return. Defaults to 4.
+            param (dict, optional): The search params for the index type.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+        res = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+        )
+        return [doc for doc, _ in res]
+
+    def similarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Perform a search on a query string and return results with score.
+
+        For more information about the search parameters, take a look at the pymilvus
+        documentation found here:
+        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
+
+        Args:
+            query (str): The text being searched.
+            k (int, optional): The amount of results to return. Defaults to 4.
+            param (dict): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[float], List[Tuple[Document, any, any]]:
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        # Embed the query text.
+        embedding = self.embedding_func.embed_query(query)
+
+        res = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
+        )
+        return res
+
+    def _similarity_search_with_relevance_scores(
+        self,
+        query: str,
+        k: int = 4,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and relevance scores in the range [0, 1].
+
+        0 is dissimilar, 1 is most similar.
+
+        Args:
+            query: input text
+            k: Number of Documents to return. Defaults to 4.
+            **kwargs: kwargs to be passed to similarity search. Should include:
+                score_threshold: Optional, a floating point value between 0 to 1 to
+                    filter the resulting set of retrieved docs
+
+        Returns:
+            List of Tuples of (doc, similarity_score)
+        """
+        return self.similarity_search_with_score(query, k, **kwargs)
+
+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Perform a search on a query string and return results with score.
+
+        For more information about the search parameters, take a look at the pymilvus
+        documentation found here:
+        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
+
+        Args:
+            embedding (List[float]): The embedding vector being searched.
+            k (int, optional): The amount of results to return. Defaults to 4.
+            param (dict): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Tuple[Document, float]]: Result doc and score.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        if param is None:
+            param = self.search_params
+
+        # Determine result metadata fields.
+        output_fields = self.fields[:]
+        output_fields.remove(self._vector_field)
+
+        # Perform the search.
+        res = self.col.search(
+            data=[embedding],
+            anns_field=self._vector_field,
+            param=param,
+            limit=k,
+            expr=expr,
+            output_fields=output_fields,
+            timeout=timeout,
+            **kwargs,
+        )
+        # Organize results.
+        ret = []
+        for result in res[0]:
+            meta = {x: result.entity.get(x) for x in output_fields}
+            doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
+            pair = (doc, result.score)
+            ret.append(pair)
+
+        return ret
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a search and return results that are reordered by MMR.
+
+        Args:
+            query (str): The text being searched.
+            k (int, optional): How many results to give. Defaults to 4.
+            fetch_k (int, optional): Total results to select k from.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5
+            param (dict, optional): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        embedding = self.embedding_func.embed_query(query)
+
+        return self.max_marginal_relevance_search_by_vector(
+            embedding=embedding,
+            k=k,
+            fetch_k=fetch_k,
+            lambda_mult=lambda_mult,
+            param=param,
+            expr=expr,
+            timeout=timeout,
+            **kwargs,
+        )
+
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: list[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        param: Optional[dict] = None,
+        expr: Optional[str] = None,
+        timeout: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Perform a search and return results that are reordered by MMR.
+
+        Args:
+            embedding (str): The embedding vector being searched.
+            k (int, optional): How many results to give. Defaults to 4.
+            fetch_k (int, optional): Total results to select k from.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5
+            param (dict, optional): The search params for the specified index.
+                Defaults to None.
+            expr (str, optional): Filtering expression. Defaults to None.
+            timeout (int, optional): How long to wait before timeout error.
+                Defaults to None.
+            kwargs: Collection.search() keyword arguments.
+
+        Returns:
+            List[Document]: Document results for search.
+        """
+        if self.col is None:
+            logger.debug("No existing collection to search.")
+            return []
+
+        if param is None:
+            param = self.search_params
+
+        # Determine result metadata fields.
+        output_fields = self.fields[:]
+        output_fields.remove(self._vector_field)
+
+        # Perform the search.
+        res = self.col.search(
+            data=[embedding],
+            anns_field=self._vector_field,
+            param=param,
+            limit=fetch_k,
+            expr=expr,
+            output_fields=output_fields,
+            timeout=timeout,
+            **kwargs,
+        )
+        # Organize results.
+        ids = []
+        documents = []
+        scores = []
+        for result in res[0]:
+            meta = {x: result.entity.get(x) for x in output_fields}
+            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
+            documents.append(doc)
+            scores.append(result.score)
+            ids.append(result.id)
+
+        vectors = self.col.query(
+            expr=f"{self._primary_field} in {ids}",
+            output_fields=[self._primary_field, self._vector_field],
+            timeout=timeout,
+        )
+        # Reorganize the results from query to match search order.
+        vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
+
+        ordered_result_embeddings = [vectors[x] for x in ids]
+
+        # Get the new order of results.
+        new_ordering = maximal_marginal_relevance(
+            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
+        )
+
+        # Reorder the values and return.
+        ret = []
+        for x in new_ordering:
+            # Function can return -1 index
+            if x == -1:
+                break
+            else:
+                ret.append(documents[x])
+        return ret
+
+    @classmethod
+    def from_texts(
+        cls,
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        collection_name: str = "LangChainCollection",
+        connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
+        consistency_level: str = "Session",
+        index_params: Optional[dict] = None,
+        search_params: Optional[dict] = None,
+        drop_old: bool = False,
+        batch_size: int = 100,
+        ids: Optional[Sequence[str]] = None,
+        **kwargs: Any,
+    ) -> Milvus:
+        """Create a Milvus collection, indexes it with HNSW, and insert data.
+
+        Args:
+            texts (List[str]): Text data.
+            embedding (Embeddings): Embedding function.
+            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
+                Defaults to None.
+            collection_name (str, optional): Collection name to use. Defaults to
+                "LangChainCollection".
+            connection_args (dict[str, Any], optional): Connection args to use. Defaults
+                to DEFAULT_MILVUS_CONNECTION.
+            consistency_level (str, optional): Which consistency level to use. Defaults
+                to "Session".
+            index_params (Optional[dict], optional): Which index_params to use. Defaults
+                to None.
+            search_params (Optional[dict], optional): Which search params to use.
+                Defaults to None.
+            drop_old (Optional[bool], optional): Whether to drop the collection with
+                that name if it exists. Defaults to False.
+            batch_size:
+                How many vectors upload per-request.
+                Default: 100
+            ids: Optional[Sequence[str]] = None,
+
+        Returns:
+            Milvus: Milvus Vector Store
+        """
+        vector_db = cls(
+            embedding_function=embedding,
+            collection_name=collection_name,
+            connection_args=connection_args,
+            consistency_level=consistency_level,
+            index_params=index_params,
+            search_params=search_params,
+            drop_old=drop_old,
+            **kwargs,
+        )
+        vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
+        return vector_db

+ 69 - 40
api/core/index/vector_index/milvus_vector_index.py

@@ -9,30 +9,46 @@ from core.index.base import BaseIndex
 from core.index.vector_index.base import BaseVectorIndex
 from core.vector_store.milvus_vector_store import MilvusVectorStore
 from core.vector_store.weaviate_vector_store import WeaviateVectorStore
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding
 
 
 class MilvusConfig(BaseModel):
-    endpoint: str
+    host: str
+    port: int
     user: str
     password: str
+    secure: bool
     batch_size: int = 100
 
     @root_validator()
     def validate_config(cls, values: dict) -> dict:
-        if not values['endpoint']:
-            raise ValueError("config MILVUS_ENDPOINT is required")
+        if not values['host']:
+            raise ValueError("config MILVUS_HOST is required")
+        if not values['port']:
+            raise ValueError("config MILVUS_PORT is required")
+        if not values['secure']:
+            raise ValueError("config MILVUS_SECURE is required")
         if not values['user']:
             raise ValueError("config MILVUS_USER is required")
         if not values['password']:
             raise ValueError("config MILVUS_PASSWORD is required")
         return values
 
+    def to_milvus_params(self):
+        return {
+            'host': self.host,
+            'port': self.port,
+            'user': self.user,
+            'password': self.password,
+            'secure': self.secure
+        }
+
 
 class MilvusVectorIndex(BaseVectorIndex):
     def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
         super().__init__(dataset, embeddings)
-        self._client = self._init_client(config)
+        self._client_config = config
 
     def get_type(self) -> str:
         return 'milvus'
@@ -49,7 +65,6 @@ class MilvusVectorIndex(BaseVectorIndex):
         dataset_id = dataset.id
         return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
 
-
     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
@@ -58,26 +73,29 @@ class MilvusVectorIndex(BaseVectorIndex):
 
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
-        self._vector_store = WeaviateVectorStore.from_documents(
+        index_params = {
+            'metric_type': 'IP',
+            'index_type': "HNSW",
+            'params':  {"M": 8, "efConstruction": 64}
+        }
+        self._vector_store = MilvusVectorStore.from_documents(
             texts,
             self._embeddings,
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            uuids=uuids,
-            by_text=False
+            collection_name=self.get_index_name(self.dataset),
+            connection_args=self._client_config.to_milvus_params(),
+            index_params=index_params
         )
 
         return self
 
     def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
         uuids = self._get_uuids(texts)
-        self._vector_store = WeaviateVectorStore.from_documents(
+        self._vector_store = MilvusVectorStore.from_documents(
             texts,
             self._embeddings,
-            client=self._client,
-            index_name=collection_name,
-            uuids=uuids,
-            by_text=False
+            collection_name=collection_name,
+            ids=uuids,
+            content_payload_key='page_content'
         )
 
         return self
@@ -86,42 +104,53 @@ class MilvusVectorIndex(BaseVectorIndex):
         """Only for created index."""
         if self._vector_store:
             return self._vector_store
-
         attributes = ['doc_id', 'dataset_id', 'document_id']
-        if self._is_origin():
-            attributes = ['doc_id']
-
-        return WeaviateVectorStore(
-            client=self._client,
-            index_name=self.get_index_name(self.dataset),
-            text_key='text',
-            embedding=self._embeddings,
-            attributes=attributes,
-            by_text=False
+
+        return MilvusVectorStore(
+            collection_name=self.get_index_name(self.dataset),
+            embedding_function=self._embeddings,
+            connection_args=self._client_config.to_milvus_params()
         )
 
     def _get_vector_store_class(self) -> type:
         return MilvusVectorStore
 
     def delete_by_document_id(self, document_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
 
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
+        ids = vector_store.get_ids_by_document_id(document_id)
+        if ids:
+            vector_store.del_texts({
+                'filter': f'id in {ids}'
+            })
+
+    def delete_by_ids(self, doc_ids: list[str]) -> None:
 
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        ids = vector_store.get_ids_by_doc_ids(doc_ids)
         vector_store.del_texts({
-            "operator": "Equal",
-            "path": ["document_id"],
-            "valueText": document_id
+            'filter': f' id in {ids}'
         })
 
-    def _is_origin(self):
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                return True
+    def delete_by_group_id(self, group_id: str) -> None:
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
 
-        return False
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))

+ 14 - 0
api/core/index/vector_index/vector_index.py

@@ -47,6 +47,20 @@ class VectorIndex:
                 ),
                 embeddings=embeddings
             )
+        elif vector_type == "milvus":
+            from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
+
+            return MilvusVectorIndex(
+                dataset=dataset,
+                config=MilvusConfig(
+                    host=config.get('MILVUS_HOST'),
+                    port=config.get('MILVUS_PORT'),
+                    user=config.get('MILVUS_USER'),
+                    password=config.get('MILVUS_PASSWORD'),
+                    secure=config.get('MILVUS_SECURE'),
+                ),
+                embeddings=embeddings
+            )
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
 

+ 31 - 23
api/core/vector_store/milvus_vector_store.py

@@ -1,4 +1,4 @@
-from langchain.vectorstores import  Milvus
+from core.index.vector_index.milvus import Milvus
 
 
 class MilvusVectorStore(Milvus):
@@ -6,33 +6,41 @@ class MilvusVectorStore(Milvus):
         if not where_filter:
             raise ValueError('where_filter must not be empty')
 
-        self._client.batch.delete_objects(
-            class_name=self._index_name,
-            where=where_filter,
-            output='minimal'
-        )
+        self.col.delete(where_filter.get('filter'))
 
     def del_text(self, uuid: str) -> None:
-        self._client.data_object.delete(
-            uuid,
-            class_name=self._index_name
-        )
+        expr = f"id == {uuid}"
+        self.col.delete(expr)
 
     def text_exists(self, uuid: str) -> bool:
-        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
-            "path": ["doc_id"],
-            "operator": "Equal",
-            "valueText": uuid,
-        }).with_limit(1).do()
-
-        if "errors" in result:
-            raise ValueError(f"Error during query: {result['errors']}")
+        result = self.col.query(
+            expr=f'metadata["doc_id"] == "{uuid}"',
+            output_fields=["id"]
+        )
 
-        entries = result["data"]["Get"][self._index_name]
-        if len(entries) == 0:
-            return False
+        return len(result) > 0
 
-        return True
+    def get_ids_by_document_id(self, document_id: str):
+        result = self.col.query(
+            expr=f'metadata["document_id"] == "{document_id}"',
+            output_fields=["id"]
+        )
+        if result:
+            return [item["id"] for item in result]
+        else:
+            return None
+
+    def get_ids_by_doc_ids(self, doc_ids: list):
+        result = self.col.query(
+            expr=f'metadata["doc_id"] in {doc_ids}',
+            output_fields=["id"]
+        )
+        if result:
+            return [item["id"] for item in result]
+        else:
+            return None
 
     def delete(self):
-        self._client.schema.delete_class(self._index_name)
+        from pymilvus import utility
+        utility.drop_collection(self.collection_name, None, self.alias)
+

+ 2 - 1
api/requirements.txt

@@ -52,4 +52,5 @@ pandas==1.5.3
 xinference==0.5.2
 safetensors==0.3.2
 zhipuai==1.0.7
-werkzeug==2.3.7
+werkzeug==2.3.7
+pymilvus==2.3.0

+ 64 - 0
docker/milvus-standalone-docker-compose.yml

@@ -0,0 +1,64 @@
+version: '3.5'
+
+services:
+  etcd:
+    container_name: milvus-etcd
+    image: quay.io/coreos/etcd:v3.5.5
+    environment:
+      - ETCD_AUTO_COMPACTION_MODE=revision
+      - ETCD_AUTO_COMPACTION_RETENTION=1000
+      - ETCD_QUOTA_BACKEND_BYTES=4294967296
+      - ETCD_SNAPSHOT_COUNT=50000
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
+    command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
+    healthcheck:
+      test: ["CMD", "etcdctl", "endpoint", "health"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+
+  minio:
+    container_name: milvus-minio
+    image: minio/minio:RELEASE.2023-03-20T20-16-18Z
+    environment:
+      MINIO_ACCESS_KEY: minioadmin
+      MINIO_SECRET_KEY: minioadmin
+    ports:
+      - "9001:9001"
+      - "9000:9000"
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
+    command: minio server /minio_data --console-address ":9001"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+      interval: 30s
+      timeout: 20s
+      retries: 3
+
+  standalone:
+    container_name: milvus-standalone
+    image: milvusdb/milvus:v2.3.1
+    command: ["milvus", "run", "standalone"]
+    environment:
+      ETCD_ENDPOINTS: etcd:2379
+      MINIO_ADDRESS: minio:9000
+      common.security.authorizationEnabled: true
+    volumes:
+      - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
+      interval: 30s
+      start_period: 90s
+      timeout: 20s
+      retries: 3
+    ports:
+      - "19530:19530"
+      - "9091:9091"
+    depends_on:
+      - "etcd"
+      - "minio"
+
+networks:
+  default:
+    name: milvus