|  | @@ -1,3 +1,5 @@
 | 
	
		
			
				|  |  | +import base64
 | 
	
		
			
				|  |  | +import json
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  |  from typing import List, Optional
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -5,6 +7,8 @@ import numpy as np
 | 
	
		
			
				|  |  |  from core.model_manager import ModelInstance
 | 
	
		
			
				|  |  |  from extensions.ext_database import db
 | 
	
		
			
				|  |  |  from langchain.embeddings.base import Embeddings
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from extensions.ext_redis import redis_client
 | 
	
		
			
				|  |  |  from libs import helper
 | 
	
		
			
				|  |  |  from models.dataset import Embedding
 | 
	
		
			
				|  |  |  from sqlalchemy.exc import IntegrityError
 | 
	
	
		
			
				|  | @@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |          embedding_queue_indices = []
 | 
	
		
			
				|  |  |          for i, text in enumerate(texts):
 | 
	
		
			
				|  |  |              hash = helper.generate_text_hash(text)
 | 
	
		
			
				|  |  | -            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
 | 
	
		
			
				|  |  | +            embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
 | 
	
		
			
				|  |  | +            embedding = redis_client.get(embedding_cache_key)
 | 
	
		
			
				|  |  |              if embedding:
 | 
	
		
			
				|  |  | -                text_embeddings[i] = embedding.get_embedding()
 | 
	
		
			
				|  |  | +                redis_client.expire(embedding_cache_key, 3600)
 | 
	
		
			
				|  |  | +                text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |              else:
 | 
	
		
			
				|  |  |                  embedding_queue_indices.append(i)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |                  hash = helper.generate_text_hash(texts[indice])
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |                  try:
 | 
	
		
			
				|  |  | -                    embedding = Embedding(model_name=self._model_instance.model, hash=hash)
 | 
	
		
			
				|  |  | +                    embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
 | 
	
		
			
				|  |  |                      vector = embedding_results[i]
 | 
	
		
			
				|  |  |                      normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
 | 
	
		
			
				|  |  |                      text_embeddings[indice] = normalized_embedding
 | 
	
		
			
				|  |  | -                    embedding.set_embedding(normalized_embedding)
 | 
	
		
			
				|  |  | -                    db.session.add(embedding)
 | 
	
		
			
				|  |  | -                    db.session.commit()
 | 
	
		
			
				|  |  | +                    # encode embedding to base64
 | 
	
		
			
				|  |  | +                    embedding_vector = np.array(normalized_embedding)
 | 
	
		
			
				|  |  | +                    vector_bytes = embedding_vector.tobytes()
 | 
	
		
			
				|  |  | +                    # Transform to Base64
 | 
	
		
			
				|  |  | +                    encoded_vector = base64.b64encode(vector_bytes)
 | 
	
		
			
				|  |  | +                    # Transform to string
 | 
	
		
			
				|  |  | +                    encoded_str = encoded_vector.decode("utf-8")
 | 
	
		
			
				|  |  | +                    redis_client.setex(embedding_cache_key, 3600, encoded_str)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |                  except IntegrityError:
 | 
	
		
			
				|  |  |                      db.session.rollback()
 | 
	
		
			
				|  |  |                      continue
 | 
	
		
			
				|  |  |                  except:
 | 
	
		
			
				|  |  | -                    logging.exception('Failed to add embedding to db')
 | 
	
		
			
				|  |  | +                    logging.exception('Failed to add embedding to redis')
 | 
	
		
			
				|  |  |                      continue
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return text_embeddings
 | 
	
	
		
			
				|  | @@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |          """Embed query text."""
 | 
	
		
			
				|  |  |          # use doc embedding cache or store if not exists
 | 
	
		
			
				|  |  |          hash = helper.generate_text_hash(text)
 | 
	
		
			
				|  |  | -        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first()
 | 
	
		
			
				|  |  | +        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
 | 
	
		
			
				|  |  | +        embedding = redis_client.get(embedding_cache_key)
 | 
	
		
			
				|  |  |          if embedding:
 | 
	
		
			
				|  |  | -            return embedding.get_embedding()
 | 
	
		
			
				|  |  | +            redis_client.expire(embedding_cache_key, 3600)
 | 
	
		
			
				|  |  | +            return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  |              embedding_result = self._model_instance.invoke_text_embedding(
 | 
	
	
		
			
				|  | @@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |              raise ex
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            embedding = Embedding(model_name=self._model_instance.model, hash=hash)
 | 
	
		
			
				|  |  | -            embedding.set_embedding(embedding_results)
 | 
	
		
			
				|  |  | -            db.session.add(embedding)
 | 
	
		
			
				|  |  | -            db.session.commit()
 | 
	
		
			
				|  |  | +            # encode embedding to base64
 | 
	
		
			
				|  |  | +            embedding_vector = np.array(embedding_results)
 | 
	
		
			
				|  |  | +            vector_bytes = embedding_vector.tobytes()
 | 
	
		
			
				|  |  | +            # Transform to Base64
 | 
	
		
			
				|  |  | +            encoded_vector = base64.b64encode(vector_bytes)
 | 
	
		
			
				|  |  | +            # Transform to string
 | 
	
		
			
				|  |  | +            encoded_str = encoded_vector.decode("utf-8")
 | 
	
		
			
				|  |  | +            redis_client.setex(embedding_cache_key, 3600, encoded_str)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          except IntegrityError:
 | 
	
		
			
				|  |  |              db.session.rollback()
 | 
	
		
			
				|  |  |          except:
 | 
	
		
			
				|  |  | -            logging.exception('Failed to add embedding to db')
 | 
	
		
			
				|  |  | +            logging.exception('Failed to add embedding to redis')
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return embedding_results
 |