cached_embedding.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import base64
  2. import logging
  3. from typing import Optional, cast
  4. import numpy as np
  5. from sqlalchemy.exc import IntegrityError
  6. from core.model_manager import ModelInstance
  7. from core.model_runtime.entities.model_entities import ModelPropertyKey
  8. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  9. from core.rag.datasource.entity.embedding import Embeddings
  10. from extensions.ext_database import db
  11. from extensions.ext_redis import redis_client
  12. from libs import helper
  13. from models.dataset import Embedding
  14. logger = logging.getLogger(__name__)
  15. class CacheEmbedding(Embeddings):
  16. def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None:
  17. self._model_instance = model_instance
  18. self._user = user
  19. def embed_documents(self, texts: list[str]) -> list[list[float]]:
  20. """Embed search docs in batches of 10."""
  21. # use doc embedding cache or store if not exists
  22. text_embeddings = [None for _ in range(len(texts))]
  23. embedding_queue_indices = []
  24. for i, text in enumerate(texts):
  25. hash = helper.generate_text_hash(text)
  26. embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
  27. hash=hash,
  28. provider_name=self._model_instance.provider).first()
  29. if embedding:
  30. text_embeddings[i] = embedding.get_embedding()
  31. else:
  32. embedding_queue_indices.append(i)
  33. if embedding_queue_indices:
  34. embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
  35. embedding_queue_embeddings = []
  36. try:
  37. model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
  38. model_schema = model_type_instance.get_model_schema(self._model_instance.model,
  39. self._model_instance.credentials)
  40. max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
  41. if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
  42. for i in range(0, len(embedding_queue_texts), max_chunks):
  43. batch_texts = embedding_queue_texts[i:i + max_chunks]
  44. embedding_result = self._model_instance.invoke_text_embedding(
  45. texts=batch_texts,
  46. user=self._user
  47. )
  48. for vector in embedding_result.embeddings:
  49. try:
  50. normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
  51. embedding_queue_embeddings.append(normalized_embedding)
  52. except IntegrityError:
  53. db.session.rollback()
  54. except Exception as e:
  55. logging.exception('Failed transform embedding: ', e)
  56. cache_embeddings = []
  57. try:
  58. for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
  59. text_embeddings[i] = embedding
  60. hash = helper.generate_text_hash(texts[i])
  61. if hash not in cache_embeddings:
  62. embedding_cache = Embedding(model_name=self._model_instance.model,
  63. hash=hash,
  64. provider_name=self._model_instance.provider)
  65. embedding_cache.set_embedding(embedding)
  66. db.session.add(embedding_cache)
  67. cache_embeddings.append(hash)
  68. db.session.commit()
  69. except IntegrityError:
  70. db.session.rollback()
  71. except Exception as ex:
  72. db.session.rollback()
  73. logger.error('Failed to embed documents: ', ex)
  74. raise ex
  75. return text_embeddings
  76. def embed_query(self, text: str) -> list[float]:
  77. """Embed query text."""
  78. # use doc embedding cache or store if not exists
  79. hash = helper.generate_text_hash(text)
  80. embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
  81. embedding = redis_client.get(embedding_cache_key)
  82. if embedding:
  83. redis_client.expire(embedding_cache_key, 600)
  84. return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
  85. try:
  86. embedding_result = self._model_instance.invoke_text_embedding(
  87. texts=[text],
  88. user=self._user
  89. )
  90. embedding_results = embedding_result.embeddings[0]
  91. embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
  92. except Exception as ex:
  93. raise ex
  94. try:
  95. # encode embedding to base64
  96. embedding_vector = np.array(embedding_results)
  97. vector_bytes = embedding_vector.tobytes()
  98. # Transform to Base64
  99. encoded_vector = base64.b64encode(vector_bytes)
  100. # Transform to string
  101. encoded_str = encoded_vector.decode("utf-8")
  102. redis_client.setex(embedding_cache_key, 600, encoded_str)
  103. except IntegrityError:
  104. db.session.rollback()
  105. except:
  106. logging.exception('Failed to add embedding to redis')
  107. return embedding_results