@@ -0,0 +1,204 @@
+import time
+from typing import Optional
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
+
+
+class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Text Embedding Inference text embedding model.
+    """
+
+    def _invoke(
+        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+    ) -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        credentials should be like:
+        {
+            'server_url': 'server url',
+            'model_uid': 'model uid',
+        }
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+
+        # get model properties
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        inputs = []
+        indices = []
+        used_tokens = 0
+
+        # get tokenized results from TEI
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
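+        # Each entry in batched_tokenize_result is expected to be the token list for the
+        # corresponding input text; the 'special' and 'start' fields used below are assumed
+        # to be part of TEI's tokenize response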
+
+        for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
+
+            # Check if the number of tokens is larger than the context size
+            num_tokens = len(tokenize_result)
+
+            if num_tokens >= context_size:
+                # Find the best cutoff point
+                pre_special_token_count = 0
+                for token in tokenize_result:
+                    if token['special']:
+                        pre_special_token_count += 1
+                    else:
+                        break
+                rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
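+                # rest_special_token_count is the number of special tokens outside the leading
+                # run (e.g. trailing separators); room is reserved for them in the cutoff below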
+
+                # Calculate the cutoff point, leaving 20 tokens of headroom to avoid exceeding the limit
+                token_cutoff = context_size - rest_special_token_count - 20
+
+                # Find the cutoff index
+                cutpoint_token = tokenize_result[token_cutoff]
+                cutoff = cutpoint_token['start']
+
+                inputs.append(text[0: cutoff])
+            else:
+                inputs.append(text)
+            indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
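+        # Request embeddings from TEI in batches of at most max_chunks inputs per call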
+
+        try:
+            used_tokens = 0
+            for i in _iter:
+                iter_texts = inputs[i : i + max_chunks]
+                results = TeiHelper.invoke_embeddings(server_url, iter_texts)
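+                # results is assumed to follow an OpenAI-style embedding response shape,
+                # with 'data' (per-input embeddings) and 'usage' (token counts) fields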
+                embeddings = results['data']
+                embeddings = [embedding['embedding'] for embedding in embeddings]
+                batched_embeddings.extend(embeddings)
+
+                usage = results['usage']
+                used_tokens += usage['total_tokens']
+        except RuntimeError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
+
+        result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)
+
+        return result
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for the given texts
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        num_tokens = 0
+        server_url = credentials['server_url']
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
+        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
+        num_tokens = sum(len(tokens) for tokens in batch_tokens)
+        return num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            server_url = credentials['server_url']
+            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
+            if extra_args.model_type != 'embedding':
+                raise CredentialsValidateFailedError('Current model is not an embedding model')
+
+            credentials['context_size'] = extra_args.max_input_length
+            credentials['max_chunks'] = extra_args.max_client_batch_size
+            self._invoke(model=model, credentials=credentials, texts=['ping'])
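+            # The single 'ping' embedding above exercises the full request path to confirm
+            # the TEI server is reachable with the given credentials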
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [KeyError],
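+            # NOTE: mapping KeyError to InvokeBadRequestError is assumed to cover missing
+            # fields in the TEI response payload accessed in _invoke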
+        }
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at,
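+            # self.started_at is assumed to be set by the TextEmbeddingModel base class
+            # when invoke() starts, so the difference above is the request latency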
+        )
+
+        return usage
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+        Used to define the customizable model schema
+        """
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model_properties={
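+                # context_size and max_chunks are populated by validate_credentials from the
+                # TEI server info; the literals below are only fallbacks when those keys are absent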
+                ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
+            },
+            parameter_rules=[],
+        )
+
+        return entity