test_vector_store.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import uuid
  2. from unittest.mock import MagicMock
  3. import pytest
  4. from core.rag.models.document import Document
  5. from extensions import ext_redis
  6. from models.dataset import Dataset
  7. def get_sample_text() -> str:
  8. return 'test_text'
  9. def get_sample_embedding() -> list[float]:
  10. return [1.1, 2.2, 3.3]
  11. def get_sample_query_vector() -> list[float]:
  12. return get_sample_embedding()
  13. def get_sample_document(sample_dataset_id: str) -> Document:
  14. doc = Document(
  15. page_content=get_sample_text(),
  16. metadata={
  17. "doc_id": sample_dataset_id,
  18. "doc_hash": sample_dataset_id,
  19. "document_id": sample_dataset_id,
  20. "dataset_id": sample_dataset_id,
  21. }
  22. )
  23. return doc
  24. @pytest.fixture
  25. def setup_mock_redis() -> None:
  26. # get
  27. ext_redis.redis_client.get = MagicMock(return_value=None)
  28. # set
  29. ext_redis.redis_client.set = MagicMock(return_value=None)
  30. # lock
  31. mock_redis_lock = MagicMock()
  32. mock_redis_lock.__enter__ = MagicMock()
  33. mock_redis_lock.__exit__ = MagicMock()
  34. ext_redis.redis_client.lock = mock_redis_lock
  35. class AbstractTestVector:
  36. def __init__(self):
  37. self.vector = None
  38. self.dataset_id = str(uuid.uuid4())
  39. self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
  40. def create_vector(self) -> None:
  41. self.vector.create(
  42. texts=[get_sample_document(self.dataset_id)],
  43. embeddings=[get_sample_embedding()],
  44. )
  45. def search_by_vector(self):
  46. hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector())
  47. assert len(hits_by_vector) >= 1
  48. def search_by_full_text(self):
  49. hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
  50. assert len(hits_by_full_text) >= 1
  51. def delete_vector(self):
  52. self.vector.delete()
  53. def run_all_test(self):
  54. self.create_vector()
  55. self.search_by_vector()
  56. self.search_by_full_text()
  57. self.delete_vector()