| 
					
				 | 
			
			
				@@ -1,3 +1,5 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import List, Optional 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -5,6 +7,8 @@ import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_manager import ModelInstance 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from extensions.ext_database import db 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from langchain.embeddings.base import Embeddings 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from extensions.ext_redis import redis_client 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from libs import helper 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from models.dataset import Embedding 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from sqlalchemy.exc import IntegrityError 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         embedding_queue_indices = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for i, text in enumerate(texts): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             hash = helper.generate_text_hash(text) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            embedding = redis_client.get(embedding_cache_key) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if embedding: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                text_embeddings[i] = embedding.get_embedding() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                redis_client.expire(embedding_cache_key, 3600) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 embedding_queue_indices.append(i) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 hash = helper.generate_text_hash(texts[indice]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    embedding = Embedding(model_name=self._model_instance.model, hash=hash) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     vector = embedding_results[i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     text_embeddings[indice] = normalized_embedding 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    embedding.set_embedding(normalized_embedding) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    db.session.add(embedding) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    db.session.commit() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # encode embedding to base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    embedding_vector = np.array(normalized_embedding) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    vector_bytes = embedding_vector.tobytes() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # Transform to Base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    encoded_vector = base64.b64encode(vector_bytes) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # Transform to string 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    encoded_str = encoded_vector.decode("utf-8") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    redis_client.setex(embedding_cache_key, 3600, encoded_str) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 except IntegrityError: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     db.session.rollback() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 except: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    logging.exception('Failed to add embedding to db') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    logging.exception('Failed to add embedding to redis') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return text_embeddings 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """Embed query text.""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # use doc embedding cache or store if not exists 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         hash = helper.generate_text_hash(text) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        embedding = redis_client.get(embedding_cache_key) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if embedding: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return embedding.get_embedding() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            redis_client.expire(embedding_cache_key, 3600) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             embedding_result = self._model_instance.invoke_text_embedding( 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             raise ex 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            embedding = Embedding(model_name=self._model_instance.model, hash=hash) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            embedding.set_embedding(embedding_results) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            db.session.add(embedding) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            db.session.commit() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # encode embedding to base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            embedding_vector = np.array(embedding_results) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            vector_bytes = embedding_vector.tobytes() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Transform to Base64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            encoded_vector = base64.b64encode(vector_bytes) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Transform to string 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            encoded_str = encoded_vector.decode("utf-8") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            redis_client.setex(embedding_cache_key, 3600, encoded_str) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         except IntegrityError: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             db.session.rollback() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         except: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            logging.exception('Failed to add embedding to db') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logging.exception('Failed to add embedding to redis') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return embedding_results 
			 |