|  | @@ -1,10 +1,9 @@
 | 
	
		
			
				|  |  |  import json
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  |  from typing import Any, Optional
 | 
	
		
			
				|  |  | -from uuid import uuid4
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from pydantic import BaseModel, model_validator
 | 
	
		
			
				|  |  | -from pymilvus import MilvusClient, MilvusException, connections
 | 
	
		
			
				|  |  | +from pymilvus import MilvusClient, MilvusException
 | 
	
		
			
				|  |  |  from pymilvus.milvus_client import IndexParams
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from configs import dify_config
 | 
	
	
		
			
				|  | @@ -21,20 +20,17 @@ logger = logging.getLogger(__name__)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class MilvusConfig(BaseModel):
 | 
	
		
			
				|  |  | -    host: str
 | 
	
		
			
				|  |  | -    port: int
 | 
	
		
			
				|  |  | +    uri: str
 | 
	
		
			
				|  |  | +    token: Optional[str] = None
 | 
	
		
			
				|  |  |      user: str
 | 
	
		
			
				|  |  |      password: str
 | 
	
		
			
				|  |  | -    secure: bool = False
 | 
	
		
			
				|  |  |      batch_size: int = 100
 | 
	
		
			
				|  |  |      database: str = "default"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @model_validator(mode='before')
 | 
	
		
			
				|  |  |      def validate_config(cls, values: dict) -> dict:
 | 
	
		
			
				|  |  | -        if not values.get('host'):
 | 
	
		
			
				|  |  | -            raise ValueError("config MILVUS_HOST is required")
 | 
	
		
			
				|  |  | -        if not values.get('port'):
 | 
	
		
			
				|  |  | -            raise ValueError("config MILVUS_PORT is required")
 | 
	
		
			
				|  |  | +        if not values.get('uri'):
 | 
	
		
			
				|  |  | +            raise ValueError("config MILVUS_URI is required")
 | 
	
		
			
				|  |  |          if not values.get('user'):
 | 
	
		
			
				|  |  |              raise ValueError("config MILVUS_USER is required")
 | 
	
		
			
				|  |  |          if not values.get('password'):
 | 
	
	
		
			
				|  | @@ -43,11 +39,10 @@ class MilvusConfig(BaseModel):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def to_milvus_params(self):
 | 
	
		
			
				|  |  |          return {
 | 
	
		
			
				|  |  | -            'host': self.host,
 | 
	
		
			
				|  |  | -            'port': self.port,
 | 
	
		
			
				|  |  | +            'uri': self.uri,
 | 
	
		
			
				|  |  | +            'token': self.token,
 | 
	
		
			
				|  |  |              'user': self.user,
 | 
	
		
			
				|  |  |              'password': self.password,
 | 
	
		
			
				|  |  | -            'secure': self.secure,
 | 
	
		
			
				|  |  |              'db_name': self.database,
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -111,32 +106,14 @@ class MilvusVector(BaseVector):
 | 
	
		
			
				|  |  |              return None
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def delete_by_metadata_field(self, key: str, value: str):
 | 
	
		
			
				|  |  | -        alias = uuid4().hex
 | 
	
		
			
				|  |  | -        if self._client_config.secure:
 | 
	
		
			
				|  |  | -            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
 | 
	
		
			
				|  |  | -                            db_name=self._client_config.database)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        from pymilvus import utility
 | 
	
		
			
				|  |  | -        if utility.has_collection(self._collection_name, using=alias):
 | 
	
		
			
				|  |  | +        if self._client.has_collection(self._collection_name):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              ids = self.get_ids_by_metadata_field(key, value)
 | 
	
		
			
				|  |  |              if ids:
 | 
	
		
			
				|  |  |                  self._client.delete(collection_name=self._collection_name, pks=ids)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def delete_by_ids(self, ids: list[str]) -> None:
 | 
	
		
			
				|  |  | -        alias = uuid4().hex
 | 
	
		
			
				|  |  | -        if self._client_config.secure:
 | 
	
		
			
				|  |  | -            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
 | 
	
		
			
				|  |  | -                            db_name=self._client_config.database)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        from pymilvus import utility
 | 
	
		
			
				|  |  | -        if utility.has_collection(self._collection_name, using=alias):
 | 
	
		
			
				|  |  | +        if self._client.has_collection(self._collection_name):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              result = self._client.query(collection_name=self._collection_name,
 | 
	
		
			
				|  |  |                                          filter=f'metadata["doc_id"] in {ids}',
 | 
	
	
		
			
				|  | @@ -146,29 +123,11 @@ class MilvusVector(BaseVector):
 | 
	
		
			
				|  |  |                  self._client.delete(collection_name=self._collection_name, pks=ids)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def delete(self) -> None:
 | 
	
		
			
				|  |  | -        alias = uuid4().hex
 | 
	
		
			
				|  |  | -        if self._client_config.secure:
 | 
	
		
			
				|  |  | -            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
 | 
	
		
			
				|  |  | -                            db_name=self._client_config.database)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        from pymilvus import utility
 | 
	
		
			
				|  |  | -        if utility.has_collection(self._collection_name, using=alias):
 | 
	
		
			
				|  |  | -            utility.drop_collection(self._collection_name, None, using=alias)
 | 
	
		
			
				|  |  | +        if self._client.has_collection(self._collection_name):
 | 
	
		
			
				|  |  | +            self._client.drop_collection(self._collection_name, None)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def text_exists(self, id: str) -> bool:
 | 
	
		
			
				|  |  | -        alias = uuid4().hex
 | 
	
		
			
				|  |  | -        if self._client_config.secure:
 | 
	
		
			
				|  |  | -            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password,
 | 
	
		
			
				|  |  | -                            db_name=self._client_config.database)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        from pymilvus import utility
 | 
	
		
			
				|  |  | -        if not utility.has_collection(self._collection_name, using=alias):
 | 
	
		
			
				|  |  | +        if not self._client.has_collection(self._collection_name):
 | 
	
		
			
				|  |  |              return False
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          result = self._client.query(collection_name=self._collection_name,
 | 
	
	
		
			
				|  | @@ -210,15 +169,7 @@ class MilvusVector(BaseVector):
 | 
	
		
			
				|  |  |              if redis_client.get(collection_exist_cache_key):
 | 
	
		
			
				|  |  |                  return
 | 
	
		
			
				|  |  |              # Grab the existing collection if it exists
 | 
	
		
			
				|  |  | -            from pymilvus import utility
 | 
	
		
			
				|  |  | -            alias = uuid4().hex
 | 
	
		
			
				|  |  | -            if self._client_config.secure:
 | 
	
		
			
				|  |  | -                uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -            else:
 | 
	
		
			
				|  |  | -                uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
 | 
	
		
			
				|  |  | -            connections.connect(alias=alias, uri=uri, user=self._client_config.user,
 | 
	
		
			
				|  |  | -                                password=self._client_config.password, db_name=self._client_config.database)
 | 
	
		
			
				|  |  | -            if not utility.has_collection(self._collection_name, using=alias):
 | 
	
		
			
				|  |  | +            if not self._client.has_collection(self._collection_name):
 | 
	
		
			
				|  |  |                  from pymilvus import CollectionSchema, DataType, FieldSchema
 | 
	
		
			
				|  |  |                  from pymilvus.orm.types import infer_dtype_bydata
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -263,11 +214,7 @@ class MilvusVector(BaseVector):
 | 
	
		
			
				|  |  |              redis_client.set(collection_exist_cache_key, 1, ex=3600)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _init_client(self, config) -> MilvusClient:
 | 
	
		
			
				|  |  | -        if config.secure:
 | 
	
		
			
				|  |  | -            uri = "https://" + str(config.host) + ":" + str(config.port)
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            uri = "http://" + str(config.host) + ":" + str(config.port)
 | 
	
		
			
				|  |  | -        client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
 | 
	
		
			
				|  |  | +        client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
 | 
	
		
			
				|  |  |          return client
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -285,11 +232,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
 | 
	
		
			
				|  |  |          return MilvusVector(
 | 
	
		
			
				|  |  |              collection_name=collection_name,
 | 
	
		
			
				|  |  |              config=MilvusConfig(
 | 
	
		
			
				|  |  | -                host=dify_config.MILVUS_HOST,
 | 
	
		
			
				|  |  | -                port=dify_config.MILVUS_PORT,
 | 
	
		
			
				|  |  | +                uri=dify_config.MILVUS_URI,
 | 
	
		
			
				|  |  | +                token=dify_config.MILVUS_TOKEN,
 | 
	
		
			
				|  |  |                  user=dify_config.MILVUS_USER,
 | 
	
		
			
				|  |  |                  password=dify_config.MILVUS_PASSWORD,
 | 
	
		
			
				|  |  | -                secure=dify_config.MILVUS_SECURE,
 | 
	
		
			
				|  |  |                  database=dify_config.MILVUS_DATABASE,
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  |          )
 |