# test_vector_store.py (2.9 KB)

  1. import uuid
  2. from unittest.mock import MagicMock
  3. import pytest
  4. from core.rag.models.document import Document
  5. from extensions import ext_redis
  6. from models.dataset import Dataset
  7. def get_example_text() -> str:
  8. return "test_text"
  9. def get_example_document(doc_id: str) -> Document:
  10. doc = Document(
  11. page_content=get_example_text(),
  12. metadata={
  13. "doc_id": doc_id,
  14. "doc_hash": doc_id,
  15. "document_id": doc_id,
  16. "dataset_id": doc_id,
  17. },
  18. )
  19. return doc
  20. @pytest.fixture
  21. def setup_mock_redis() -> None:
  22. # get
  23. ext_redis.redis_client.get = MagicMock(return_value=None)
  24. # set
  25. ext_redis.redis_client.set = MagicMock(return_value=None)
  26. # lock
  27. mock_redis_lock = MagicMock()
  28. mock_redis_lock.__enter__ = MagicMock()
  29. mock_redis_lock.__exit__ = MagicMock()
  30. ext_redis.redis_client.lock = mock_redis_lock
  31. class AbstractVectorTest:
  32. def __init__(self):
  33. self.vector = None
  34. self.dataset_id = str(uuid.uuid4())
  35. self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
  36. self.example_doc_id = str(uuid.uuid4())
  37. self.example_embedding = [1.001 * i for i in range(128)]
  38. def create_vector(self) -> None:
  39. self.vector.create(
  40. texts=[get_example_document(doc_id=self.example_doc_id)],
  41. embeddings=[self.example_embedding],
  42. )
  43. def search_by_vector(self):
  44. hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
  45. assert len(hits_by_vector) == 1
  46. assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id
  47. def search_by_full_text(self):
  48. hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
  49. assert len(hits_by_full_text) == 1
  50. assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id
  51. def delete_vector(self):
  52. self.vector.delete()
  53. def delete_by_ids(self, ids: list[str]):
  54. self.vector.delete_by_ids(ids=ids)
  55. def add_texts(self) -> list[str]:
  56. batch_size = 100
  57. documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
  58. embeddings = [self.example_embedding] * batch_size
  59. self.vector.add_texts(documents=documents, embeddings=embeddings)
  60. return [doc.metadata["doc_id"] for doc in documents]
  61. def text_exists(self):
  62. assert self.vector.text_exists(self.example_doc_id)
  63. def get_ids_by_metadata_field(self):
  64. with pytest.raises(NotImplementedError):
  65. self.vector.get_ids_by_metadata_field(key="key", value="value")
  66. def run_all_tests(self):
  67. self.create_vector()
  68. self.search_by_vector()
  69. self.search_by_full_text()
  70. self.text_exists()
  71. self.get_ids_by_metadata_field()
  72. added_doc_ids = self.add_texts()
  73. self.delete_by_ids(added_doc_ids)
  74. self.delete_vector()