Browse Source

feat: support Chroma vector store (#5015)

Bowen Liang 10 months ago
parent
commit
cdc08a434f

+ 6 - 2
.github/workflows/api-tests.yml

@@ -58,7 +58,7 @@ jobs:
       - name: Run Workflow
         run: dev/pytest/pytest_workflow.sh
 
-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
         uses: hoverkraft-tech/compose-action@v2.0.0
         with:
           compose-file: |
@@ -67,6 +67,7 @@ jobs:
             docker/docker-compose.milvus.yaml
             docker/docker-compose.pgvecto-rs.yaml
             docker/docker-compose.pgvector.yaml
+            docker/docker-compose.chroma.yaml
           services: |
             weaviate
             qdrant
@@ -75,6 +76,7 @@ jobs:
             milvus-standalone
             pgvecto-rs
             pgvector
+            chroma
 
       - name: Test Vector Stores
         run: dev/pytest/pytest_vdb.sh
@@ -131,7 +133,7 @@ jobs:
       - name: Run Workflow
         run: poetry run -C api bash dev/pytest/pytest_workflow.sh
 
-      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
         uses: hoverkraft-tech/compose-action@v2.0.0
         with:
           compose-file: |
@@ -140,6 +142,7 @@ jobs:
             docker/docker-compose.milvus.yaml
             docker/docker-compose.pgvecto-rs.yaml
             docker/docker-compose.pgvector.yaml
+            docker/docker-compose.chroma.yaml
           services: |
             weaviate
             qdrant
@@ -148,6 +151,7 @@ jobs:
             milvus-standalone
             pgvecto-rs
             pgvector
+            chroma
 
       - name: Test Vector Stores
         run: poetry run -C api bash dev/pytest/pytest_vdb.sh

+ 1 - 0
.gitignore

@@ -149,6 +149,7 @@ docker/volumes/qdrant/*
 docker/volumes/etcd/*
 docker/volumes/minio/*
 docker/volumes/milvus/*
+docker/volumes/chroma/*
 
 sdks/python-client/build
 sdks/python-client/dist

+ 8 - 0
api/.env.example

@@ -119,6 +119,14 @@ TIDB_VECTOR_USER=xxx.root
 TIDB_VECTOR_PASSWORD=xxxxxx
 TIDB_VECTOR_DATABASE=dify
 
+# Chroma configuration
+CHROMA_HOST=127.0.0.1
+CHROMA_PORT=8000
+CHROMA_TENANT=default_tenant
+CHROMA_DATABASE=default_database
+CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
+CHROMA_AUTH_CREDENTIALS=difyai123456
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5

+ 8 - 0
api/config.py

@@ -306,6 +306,14 @@ class Config:
         self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
         self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
 
+        # chroma settings
+        self.CHROMA_HOST = get_env('CHROMA_HOST')
+        self.CHROMA_PORT = get_env('CHROMA_PORT')
+        self.CHROMA_TENANT = get_env('CHROMA_TENANT')
+        self.CHROMA_DATABASE = get_env('CHROMA_DATABASE')
+        self.CHROMA_AUTH_PROVIDER = get_env('CHROMA_AUTH_PROVIDER')
+        self.CHROMA_AUTH_CREDENTIALS = get_env('CHROMA_AUTH_CREDENTIALS')
+
         # ------------------------
         # Mail Configurations.
         # ------------------------

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -479,7 +479,7 @@ class DatasetRetrievalSettingApi(Resource):
         vector_type = current_app.config['VECTOR_STORE']
 
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
                 return {
                     'retrieval_method': [
                         'semantic_search'
@@ -501,7 +501,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
         match vector_type:
-            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
                 return {
                     'retrieval_method': [
                         'semantic_search'

+ 0 - 0
api/core/rag/datasource/vdb/chroma/__init__.py


+ 147 - 0
api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -0,0 +1,147 @@
+import json
+from typing import Any, Optional
+
+import chromadb
+from chromadb import QueryResult, Settings
+from flask import current_app
+from pydantic import BaseModel
+
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class ChromaConfig(BaseModel):
+    host: str
+    port: int
+    tenant: str
+    database: str
+    auth_provider: Optional[str] = None
+    auth_credentials: Optional[str] = None
+
+    def to_chroma_params(self):
+        settings = Settings(
+            # auth
+            chroma_client_auth_provider=self.auth_provider,
+            chroma_client_auth_credentials=self.auth_credentials
+        )
+
+        return {
+            'host': self.host,
+            'port': self.port,
+            'ssl': False,
+            'tenant': self.tenant,
+            'database': self.database,
+            'settings': settings,
+        }
+
+
+class ChromaVector(BaseVector):
+
+    def __init__(self, collection_name: str, config: ChromaConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+        self._client = chromadb.HttpClient(**self._client_config.to_chroma_params())
+
+    def get_type(self) -> str:
+        return VectorType.CHROMA
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        if texts:
+            # create collection
+            self.create_collection(self._collection_name)
+
+            self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(self, collection_name: str):
+        lock_name = 'vector_indexing_lock_{}'.format(collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
+            if redis_client.get(collection_exist_cache_key):
+                return
+            self._client.get_or_create_collection(collection_name)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        uuids = self._get_uuids(documents)
+        texts = [d.page_content for d in documents]
+        metadatas = [d.metadata for d in documents]
+
+        collection = self._client.get_or_create_collection(self._collection_name)
+        collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
+
+    def delete_by_metadata_field(self, key: str, value: str):
+        collection = self._client.get_or_create_collection(self._collection_name)
+        collection.delete(where={key: {'$eq': value}})
+
+    def delete(self):
+        self._client.delete_collection(self._collection_name)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        collection = self._client.get_or_create_collection(self._collection_name)
+        collection.delete(ids=ids)
+
+    def text_exists(self, id: str) -> bool:
+        collection = self._client.get_or_create_collection(self._collection_name)
+        response = collection.get(ids=[id])
+        return len(response) > 0
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        collection = self._client.get_or_create_collection(self._collection_name)
+        results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
+        score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0
+
+        ids: list[str] = results['ids'][0]
+        documents: list[str] = results['documents'][0]
+        metadatas: dict[str, Any] = results['metadatas'][0]
+        distances: list[float] = results['distances'][0]
+
+        docs = []
+        for index in range(len(ids)):
+            distance = distances[index]
+            metadata = metadatas[index]
+            if distance >= score_threshold:
+                metadata['score'] = distance
+                doc = Document(
+                    page_content=documents[index],
+                    metadata=metadata,
+                )
+                docs.append(doc)
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # chroma does not support BM25 full text searching
+        return []
+
+
+class ChromaVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            index_struct_dict = {
+                "type": VectorType.CHROMA,
+                "vector_store": {"class_prefix": collection_name}
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+
+        config = current_app.config
+        return ChromaVector(
+            collection_name=collection_name,
+            config=ChromaConfig(
+                host=config.get('CHROMA_HOST'),
+                port=int(config.get('CHROMA_PORT')),
+                tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
+                database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
+                auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
+                auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
+            ),
+        )

+ 3 - 0
api/core/rag/datasource/vdb/vector_factory.py

@@ -52,6 +52,9 @@ class Vector:
     @staticmethod
     def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
         match vector_type:
+            case VectorType.CHROMA:
+                from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
+                return ChromaVectorFactory
             case VectorType.MILVUS:
                 from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
                 return MilvusVectorFactory

+ 1 - 0
api/core/rag/datasource/vdb/vector_type.py

@@ -2,6 +2,7 @@ from enum import Enum
 
 
 class VectorType(str, Enum):
+    CHROMA = 'chroma'
     MILVUS = 'milvus'
     PGVECTOR = 'pgvector'
     PGVECTO_RS = 'pgvecto-rs'

File diff suppressed because it is too large
+ 841 - 0
api/poetry.lock


+ 1 - 1
api/pyproject.toml

@@ -107,7 +107,6 @@ pycryptodome = "3.19.1"
 python-dotenv = "1.0.0"
 authlib = "1.2.0"
 boto3 = "1.28.17"
-tenacity = "8.2.2"
 cachetools = "~5.3.0"
 weaviate-client = "~3.21.0"
 mailchimp-transactional = "~1.0.50"
@@ -179,6 +178,7 @@ google-cloud-aiplatform = "1.49.0"
 vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
 kaleido = "0.2.1"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
+chromadb = "~0.5.0"
 
 [tool.poetry.group.dev]
 optional = true

+ 2 - 2
api/requirements.txt

@@ -16,7 +16,6 @@ pycryptodome==3.19.1
 python-dotenv==1.0.0
 Authlib==1.2.0
 boto3==1.34.123
-tenacity==8.2.2
 cachetools~=5.3.0
 weaviate-client~=3.21.0
 mailchimp-transactional~=1.0.50
@@ -85,4 +84,5 @@ pymysql==1.1.1
 tidb-vector==0.0.9
 google-cloud-aiplatform==1.49.0
 vanna[postgres,mysql,clickhouse,duckdb]==0.5.5
-tencentcloud-sdk-python-hunyuan~=3.0.1158
+tencentcloud-sdk-python-hunyuan~=3.0.1158
+chromadb~=0.5.0

+ 0 - 0
api/tests/integration_tests/vdb/chroma/__init__.py


+ 33 - 0
api/tests/integration_tests/vdb/chroma/test_chroma.py

@@ -0,0 +1,33 @@
+import chromadb
+
+from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
+from tests.integration_tests.vdb.test_vector_store import (
+    AbstractVectorTest,
+    get_example_text,
+    setup_mock_redis,
+)
+
+
+class ChromaVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = ChromaVector(
+            collection_name=self.collection_name,
+            config=ChromaConfig(
+                host='localhost',
+                port=8000,
+                tenant=chromadb.DEFAULT_TENANT,
+                database=chromadb.DEFAULT_DATABASE,
+                auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
+                auth_credentials="difyai123456",
+            )
+        )
+
+    def search_by_full_text(self):
+        # chroma dos not support full text searching
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+
+def test_chroma_vector(setup_mock_redis):
+    ChromaVectorTest().run_all_tests()

+ 14 - 0
docker/docker-compose.chroma.yaml

@@ -0,0 +1,14 @@
+version: '3'
+services:
+  # Chroma vector store.
+  chroma:
+    image: ghcr.io/chroma-core/chroma:0.5.0
+    restart: always
+    volumes:
+      - ./volumes/chroma:/chroma/chroma
+    environment:
+      CHROMA_SERVER_AUTHN_CREDENTIALS: difyai123456
+      CHROMA_SERVER_AUTHN_PROVIDER: chromadb.auth.token_authn.TokenAuthenticationServerProvider
+      IS_PERSISTENT: TRUE
+    ports:
+      - "8000:8000"

+ 14 - 0
docker/docker-compose.yaml

@@ -140,6 +140,13 @@ services:
       TIDB_VECTOR_USER: xxx.root
       TIDB_VECTOR_PASSWORD: xxxxxx
       TIDB_VECTOR_DATABASE: dify
+      # Chroma configuration
+      CHROMA_HOST: 127.0.0.1
+      CHROMA_PORT: 8000
+      CHROMA_TENANT: default_tenant
+      CHROMA_DATABASE: default_database
+      CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
+      CHROMA_AUTH_CREDENTIALS: xxxxxx
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -301,6 +308,13 @@ services:
       TIDB_VECTOR_USER: xxx.root
       TIDB_VECTOR_PASSWORD: xxxxxx
       TIDB_VECTOR_DATABASE: dify
+      # Chroma configuration
+      CHROMA_HOST: 127.0.0.1
+      CHROMA_PORT: 8000
+      CHROMA_TENANT: default_tenant
+      CHROMA_DATABASE: default_database
+      CHROMA_AUTH_PROVIDER: chromadb.auth.token_authn.TokenAuthClientProvider
+      CHROMA_AUTH_CREDENTIALS: xxxxxx
       # Notion import configuration, support public and internal
       NOTION_INTEGRATION_TYPE: public
       NOTION_CLIENT_SECRET: you-client-secret

Some files were not shown because too many files changed in this diff