import os
from typing import Union
from unittest.mock import MagicMock

import pytest
from _pytest.monkeypatch import MonkeyPatch
from volcengine.viking_db import (
    Collection,
    Data,
    DistanceType,
    Field,
    FieldType,
    Index,
    IndexType,
    QuantType,
    VectorIndexParams,
    VikingDBService,
)

from core.rag.datasource.vdb.field import Field as vdb_Field


class MockVikingDBClass:
    """Canned stand-ins for the VikingDB SDK methods used by the tests.

    The methods defined here are never called on a ``MockVikingDBClass``
    instance by this module. The ``setup_vikingdb_mock`` fixture copies them
    onto ``VikingDBService``, ``Collection`` and ``Index``, so ``self`` is an
    instance of whichever SDK class the method was patched onto. For that
    reason the methods must not rely on sibling helpers being reachable
    through ``self``.
    """

    def __init__(
        self,
        host="api-vikingdb.volces.com",
        region="cn-north-1",
        ak="",
        sk="",
        scheme="http",
        connection_timeout=30,
        socket_timeout=30,
        proxy=None,
    ):
        """Ignore all connection arguments and install a mocked service handle.

        The parameters are accepted only so existing constructor calls keep
        working; none of them is read.
        """
        self._viking_db_service = MagicMock()
        # ``get_exception`` always answers with this fixed JSON string.
        self._viking_db_service.get_exception = MagicMock(return_value='{"data": {"primary_key": "test_id"}}')

    def get_collection(self, collection_name) -> Collection:
        """Return a ``Collection`` with a fixed schema and one HNSW/L2 index.

        The index is named ``f"{collection_name}_idx"`` and the vector field is
        declared with ``dim=768``.
        """
        return Collection(
            collection_name=collection_name,
            description="Collection For Dify",
            viking_db_service=self._viking_db_service,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            fields=[
                Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
                Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
                Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
                Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
                Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=768),
            ],
            indexes=[
                Index(
                    collection_name=collection_name,
                    index_name=f"{collection_name}_idx",
                    vector_index=VectorIndexParams(
                        distance=DistanceType.L2,
                        index_type=IndexType.HNSW,
                        quant=QuantType.Float,
                    ),
                    scalar_index=None,
                    stat=None,
                    viking_db_service=self._viking_db_service,
                )
            ],
        )

    def drop_collection(self, collection_name):
        """Check only that a non-empty collection name was supplied."""
        assert collection_name != ""

    def create_collection(self, collection_name, fields, description="") -> Collection:
        """Return a ``Collection`` built from the caller's name, fields and description."""
        return Collection(
            collection_name=collection_name,
            description=description,
            primary_key=vdb_Field.PRIMARY_KEY.value,
            viking_db_service=self._viking_db_service,
            fields=fields,
        )

    def get_index(self, collection_name, index_name) -> Index:
        """Return an ``Index`` with fixed HNSW/L2/Float vector parameters."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            viking_db_service=self._viking_db_service,
            stat=None,
            scalar_index=None,
            vector_index=VectorIndexParams(
                distance=DistanceType.L2,
                index_type=IndexType.HNSW,
                quant=QuantType.Float,
            ),
        )

    def create_index(
        self,
        collection_name,
        index_name,
        vector_index=None,
        cpu_quota=2,
        description="",
        partition_by="",
        scalar_index=None,
        shard_count=None,
        shard_policy=None,
    ):
        """Return an ``Index`` that simply echoes the supplied arguments."""
        return Index(
            collection_name=collection_name,
            index_name=index_name,
            vector_index=vector_index,
            cpu_quota=cpu_quota,
            description=description,
            partition_by=partition_by,
            scalar_index=scalar_index,
            shard_count=shard_count,
            shard_policy=shard_policy,
            viking_db_service=self._viking_db_service,
            stat=None,
        )

    def drop_index(self, collection_name, index_name):
        """Check only that both names are non-empty."""
        assert collection_name != ""
        assert index_name != ""

    def upsert_data(self, data: Union[Data, list[Data]]):
        """Check only that some data was passed; nothing is stored."""
        assert data is not None

    def fetch_data(self, id: Union[str, list[str], int, list[int]]):
        """Return one canned ``Data`` record that echoes ``id``.

        A single record is returned even when ``id`` is a list.
        """
        return Data(
            fields={
                vdb_Field.GROUP_KEY.value: "test_group",
                vdb_Field.METADATA_KEY.value: "{}",
                vdb_Field.CONTENT_KEY.value: "content",
                vdb_Field.PRIMARY_KEY.value: id,
                vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
            },
            id=id,
        )

    def delete_data(self, id: Union[str, list[str], int, list[int]]):
        """Check only that an id was passed; nothing is deleted."""
        assert id is not None

    def search_by_vector(
        self,
        vector,
        sparse_vectors=None,
        filter=None,
        limit=10,
        output_fields=None,
        partition="default",
        dense_weight=None,
    ) -> list[Data]:
        """Return a single canned hit whose vector field echoes ``vector``.

        Every other argument is ignored.
        """
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY.value: "test_group",
                    # The backslash continuations keep each line's leading
                    # spaces inside the runtime string; leave the layout as is.
                    vdb_Field.METADATA_KEY.value: '\
                    {"source": "/var/folders/ml/xxx/xxx.txt", \
                    "document_id": "test_document_id", \
                    "dataset_id": "test_dataset_id", \
                    "doc_id": "test_id", \
                    "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY.value: "content",
                    vdb_Field.PRIMARY_KEY.value: "test_id",
                    vdb_Field.VECTOR.value: vector,
                },
                id="test_id",
                score=0.10,
            )
        ]

    def search(
        self, order=None, filter=None, limit=10, output_fields=None, partition="default", dense_weight=None
    ) -> list[Data]:
        """Return a single canned hit with a fixed three-element vector.

        All arguments are ignored.
        """
        return [
            Data(
                fields={
                    vdb_Field.GROUP_KEY.value: "test_group",
                    # The backslash continuations keep each line's leading
                    # spaces inside the runtime string; leave the layout as is.
                    vdb_Field.METADATA_KEY.value: '\
                    {"source": "/var/folders/ml/xxx/xxx.txt", \
                    "document_id": "test_document_id", \
                    "dataset_id": "test_dataset_id", \
                    "doc_id": "test_id", \
                    "doc_hash": "test_hash"}',
                    vdb_Field.CONTENT_KEY.value: "content",
                    vdb_Field.PRIMARY_KEY.value: "test_id",
                    vdb_Field.VECTOR.value: [-0.00762577411336441, -0.01949881482151406, 0.008832383941428398],
                },
                id="test_id",
                score=0.10,
            )
        ]


# Mocking is opt-in: enabled only when the MOCK_SWITCH environment variable is
# "true" (compared case-insensitively). Evaluated once, at import time.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_vikingdb_mock(monkeypatch: MonkeyPatch):
    """Patch the VikingDB SDK classes with ``MockVikingDBClass`` methods.

    When ``MOCK`` is false the fixture does nothing and simply yields.

    No explicit ``monkeypatch.undo()`` is issued: pytest reverts the
    ``monkeypatch`` fixture automatically at teardown, and an explicit call
    would also revert patches installed by other users of the same
    ``monkeypatch`` instance.
    """
    if MOCK:
        # (SDK class to patch, names of the MockVikingDBClass methods to copy onto it)
        patch_table = (
            (
                VikingDBService,
                (
                    "__init__",
                    "get_collection",
                    "create_collection",
                    "drop_collection",
                    "get_index",
                    "create_index",
                    "drop_index",
                ),
            ),
            (Collection, ("upsert_data", "fetch_data", "delete_data")),
            (Index, ("search_by_vector", "search")),
        )
        for target, method_names in patch_table:
            for method_name in method_names:
                monkeypatch.setattr(target, method_name, getattr(MockVikingDBClass, method_name))

    yield