cached_embedding.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import logging
  2. from typing import List
  3. from langchain.embeddings.base import Embeddings
  4. from sqlalchemy.exc import IntegrityError
  5. from core.model_providers.models.embedding.base import BaseEmbedding
  6. from extensions.ext_database import db
  7. from libs import helper
  8. from models.dataset import Embedding
  9. class CacheEmbedding(Embeddings):
  10. def __init__(self, embeddings: BaseEmbedding):
  11. self._embeddings = embeddings
  12. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  13. """Embed search docs."""
  14. # use doc embedding cache or store if not exists
  15. text_embeddings = []
  16. embedding_queue_texts = []
  17. for text in texts:
  18. hash = helper.generate_text_hash(text)
  19. embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
  20. if embedding:
  21. text_embeddings.append(embedding.get_embedding())
  22. else:
  23. embedding_queue_texts.append(text)
  24. if embedding_queue_texts:
  25. try:
  26. embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
  27. except Exception as ex:
  28. raise self._embeddings.handle_exceptions(ex)
  29. i = 0
  30. for text in embedding_queue_texts:
  31. hash = helper.generate_text_hash(text)
  32. try:
  33. embedding = Embedding(model_name=self._embeddings.name, hash=hash)
  34. embedding.set_embedding(embedding_results[i])
  35. db.session.add(embedding)
  36. db.session.commit()
  37. except IntegrityError:
  38. db.session.rollback()
  39. continue
  40. except:
  41. logging.exception('Failed to add embedding to db')
  42. continue
  43. finally:
  44. i += 1
  45. text_embeddings.extend(embedding_results)
  46. return text_embeddings
  47. def embed_query(self, text: str) -> List[float]:
  48. """Embed query text."""
  49. # use doc embedding cache or store if not exists
  50. hash = helper.generate_text_hash(text)
  51. embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first()
  52. if embedding:
  53. return embedding.get_embedding()
  54. try:
  55. embedding_results = self._embeddings.client.embed_query(text)
  56. except Exception as ex:
  57. raise self._embeddings.handle_exceptions(ex)
  58. try:
  59. embedding = Embedding(model_name=self._embeddings.name, hash=hash)
  60. embedding.set_embedding(embedding_results)
  61. db.session.add(embedding)
  62. db.session.commit()
  63. except IntegrityError:
  64. db.session.rollback()
  65. except:
  66. logging.exception('Failed to add embedding to db')
  67. return embedding_results