|  | @@ -1,10 +1,12 @@
 | 
	
		
			
				|  |  |  import base64
 | 
	
		
			
				|  |  |  import json
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  | -from typing import List, Optional
 | 
	
		
			
				|  |  | +from typing import List, Optional, cast
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import numpy as np
 | 
	
		
			
				|  |  |  from core.model_manager import ModelInstance
 | 
	
		
			
				|  |  | +from core.model_runtime.entities.model_entities import ModelPropertyKey
 | 
	
		
			
				|  |  | +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 | 
	
		
			
				|  |  |  from extensions.ext_database import db
 | 
	
		
			
				|  |  |  from langchain.embeddings.base import Embeddings
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -22,56 +24,33 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |          self._user = user
 | 
	
		
			
				|  |  |  
 | 
	
		
			
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs, batching by the model's max-chunk capacity.

        Texts are sent to the embedding model in batches of at most
        ``MAX_CHUNKS`` (read from the model schema; falls back to 1 when the
        schema does not declare it).  Each returned vector is L2-normalized
        before being collected.

        :param texts: the documents to embed
        :return: one normalized embedding (list of floats) per input text
        :raises Exception: re-raises any failure from the embedding model
        """
        text_embeddings: List[List[float]] = []
        try:
            model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
            model_schema = model_type_instance.get_model_schema(
                self._model_instance.model, self._model_instance.credentials
            )
            # Batch size comes from the model schema; default to 1 when the
            # schema is missing or does not declare MAX_CHUNKS.
            max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
            for i in range(0, len(texts), max_chunks):
                batch_texts = texts[i:i + max_chunks]

                embedding_result = self._model_instance.invoke_text_embedding(
                    texts=batch_texts,
                    user=self._user
                )

                for vector in embedding_result.embeddings:
                    try:
                        # Normalize to unit length so downstream similarity
                        # comparisons behave like cosine similarity.
                        normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
                        text_embeddings.append(normalized_embedding)
                    except IntegrityError:
                        db.session.rollback()
                    except Exception:
                        # Best-effort: skip a vector that fails to normalize
                        # rather than failing the whole batch.
                        # Fixed: message previously claimed a redis failure,
                        # but no redis operation happens in this loop.
                        logging.exception('Failed to normalize embedding')

        except Exception:
            # Fixed: logger.error('Failed to embed documents: ', ex) passed
            # `ex` as a %-format argument with no placeholder, so the
            # exception detail was silently dropped.  logger.exception logs
            # the message together with the full traceback.
            logger.exception('Failed to embed documents')
            raise

        return text_embeddings
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -82,7 +61,7 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |          embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
 | 
	
		
			
				|  |  |          embedding = redis_client.get(embedding_cache_key)
 | 
	
		
			
				|  |  |          if embedding:
 | 
	
		
			
				|  |  | -            redis_client.expire(embedding_cache_key, 3600)
 | 
	
		
			
				|  |  | +            redis_client.expire(embedding_cache_key, 600)
 | 
	
		
			
				|  |  |              return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -105,7 +84,7 @@ class CacheEmbedding(Embeddings):
 | 
	
		
			
				|  |  |              encoded_vector = base64.b64encode(vector_bytes)
 | 
	
		
			
				|  |  |              # Transform to string
 | 
	
		
			
				|  |  |              encoded_str = encoded_vector.decode("utf-8")
 | 
	
		
			
				|  |  | -            redis_client.setex(embedding_cache_key, 3600, encoded_str)
 | 
	
		
			
				|  |  | +            redis_client.setex(embedding_cache_key, 600, encoded_str)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          except IntegrityError:
 | 
	
		
			
				|  |  |              db.session.rollback()
 |