test_vector_store.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import random
  2. import uuid
  3. from unittest.mock import MagicMock
  4. import pytest
  5. from core.rag.models.document import Document
  6. from extensions import ext_redis
  7. from models.dataset import Dataset
  def get_example_text() -> str:
      """Return the fixed page content shared by every example document."""
      return "test_text"
  def get_example_document(doc_id: str) -> Document:
      """Build an example ``Document`` whose metadata ids all equal ``doc_id``.

      Args:
          doc_id: Value stored under the ``doc_id``, ``doc_hash``,
              ``document_id`` and ``dataset_id`` metadata keys.

      Returns:
          A ``Document`` whose page content is ``get_example_text()``.
      """
      return Document(
          page_content=get_example_text(),
          metadata={
              "doc_id": doc_id,
              "doc_hash": doc_id,
              "document_id": doc_id,
              "dataset_id": doc_id,
          },
      )
  @pytest.fixture
  def setup_mock_redis() -> None:
      """Replace the Redis client's ``get``, ``set`` and ``lock`` with mocks.

      The attributes are overwritten on ``ext_redis.redis_client`` directly and
      are not restored by this fixture.
      """
      client = ext_redis.redis_client

      # get / set: both mocks return None when called.
      client.get = MagicMock(return_value=None)
      client.set = MagicMock(return_value=None)

      # lock: give the mock explicit __enter__/__exit__ so it can be used in a
      # ``with`` statement.
      lock_mock = MagicMock()
      lock_mock.__enter__ = MagicMock()
      lock_mock.__exit__ = MagicMock()
      client.lock = lock_mock
  class AbstractVectorTest:
      """Reusable end-to-end scenario for one vector store implementation.

      ``__init__`` leaves ``self.vector`` as ``None``; it must be assigned
      (for example by a subclass) before any step is run, because every step
      calls through to it.
      """

      def __init__(self):
          """Generate fresh ids, a collection name and a 128-value embedding."""
          self.vector = None
          self.dataset_id = str(uuid.uuid4())
          self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
          self.example_doc_id = str(uuid.uuid4())
          self.example_embedding = [1.001 * i for i in range(128)]

      def create_vector(self) -> None:
          """Create the store with the single example document and embedding."""
          self.vector.create(
              texts=[get_example_document(doc_id=self.example_doc_id)],
              embeddings=[self.example_embedding],
          )

      def search_by_vector(self):
          """Assert a vector search returns exactly the example document."""
          hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
          assert len(hits_by_vector) == 1
          assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id

      def search_by_full_text(self):
          """Assert a full-text search returns exactly the example document."""
          hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
          assert len(hits_by_full_text) == 1
          assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id

      def delete_vector(self):
          """Call ``delete()`` on the store under test."""
          self.vector.delete()

      def delete_by_ids(self, ids: list[str]):
          """Forward ``ids`` to the store's ``delete_by_ids``."""
          self.vector.delete_by_ids(ids=ids)

      def add_texts(self) -> list[str]:
          """Add 100 example documents and return their ``doc_id`` values.

          Every document gets a fresh UUID as its id and shares the same
          example embedding.
          """
          batch_size = 100
          documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
          embeddings = [self.example_embedding] * batch_size
          self.vector.add_texts(documents=documents, embeddings=embeddings)
          return [doc.metadata["doc_id"] for doc in documents]

      def text_exists(self):
          """Assert the store reports the example document id as present."""
          assert self.vector.text_exists(self.example_doc_id)

      def get_ids_by_metadata_field(self):
          """Assert the metadata lookup raises ``NotImplementedError``."""
          with pytest.raises(NotImplementedError):
              self.vector.get_ids_by_metadata_field(key="key", value="value")

      def run_all_tests(self):
          """Run the full scenario, always deleting the store afterwards.

          Creation happens outside the ``try`` so that a failed create does not
          trigger a delete. Once creation has succeeded, ``delete_vector`` runs
          even if a later step raises, so a failing check does not leave the
          test collection behind.
          """
          self.create_vector()
          try:
              self.search_by_vector()
              self.search_by_full_text()
              self.text_exists()
              self.get_ids_by_metadata_field()
              added_doc_ids = self.add_texts()
              self.delete_by_ids(added_doc_ids)
          finally:
              self.delete_vector()