|  | @@ -0,0 +1,163 @@
 | 
											
												
													
														|  | 
 |  | +import time
 | 
											
												
													
														|  | 
 |  | +from json import JSONDecodeError, dumps
 | 
											
												
													
														|  | 
 |  | +from typing import Optional
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +import requests
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.entities.common_entities import I18nObject
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.errors.invoke import (
 | 
											
												
													
														|  | 
 |  | +    InvokeAuthorizationError,
 | 
											
												
													
														|  | 
 |  | +    InvokeBadRequestError,
 | 
											
												
													
														|  | 
 |  | +    InvokeConnectionError,
 | 
											
												
													
														|  | 
 |  | +    InvokeError,
 | 
											
												
													
														|  | 
 |  | +    InvokeRateLimitError,
 | 
											
												
													
														|  | 
 |  | +    InvokeServerUnavailableError,
 | 
											
												
													
														|  | 
 |  | +)
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.errors.validate import CredentialsValidateFailedError
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class MixedBreadTextEmbeddingModel(TextEmbeddingModel):
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +    Model class for MixedBread text embedding model.
 | 
											
												
													
														|  | 
 |  | +    """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    api_base: str = "https://api.mixedbread.ai/v1"
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _invoke(
 | 
											
												
													
														|  | 
 |  | +        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
 | 
											
												
													
														|  | 
 |  | +    ) -> TextEmbeddingResult:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Invoke text embedding model
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :param texts: texts to embed
 | 
											
												
													
														|  | 
 |  | +        :param user: unique user id
 | 
											
												
													
														|  | 
 |  | +        :return: embeddings result
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        api_key = credentials["api_key"]
 | 
											
												
													
														|  | 
 |  | +        if not api_key:
 | 
											
												
													
														|  | 
 |  | +            raise CredentialsValidateFailedError("api_key is required")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        base_url = credentials.get("base_url", self.api_base)
 | 
											
												
													
														|  | 
 |  | +        base_url = base_url.removesuffix("/")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        url = base_url + "/embeddings"
 | 
											
												
													
														|  | 
 |  | +        headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        data = {"model": model, "input": texts}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        try:
 | 
											
												
													
														|  | 
 |  | +            response = requests.post(url, headers=headers, data=dumps(data))
 | 
											
												
													
														|  | 
 |  | +        except Exception as e:
 | 
											
												
													
														|  | 
 |  | +            raise InvokeConnectionError(str(e))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if response.status_code != 200:
 | 
											
												
													
														|  | 
 |  | +            try:
 | 
											
												
													
														|  | 
 |  | +                resp = response.json()
 | 
											
												
													
														|  | 
 |  | +                msg = resp["detail"]
 | 
											
												
													
														|  | 
 |  | +                if response.status_code == 401:
 | 
											
												
													
														|  | 
 |  | +                    raise InvokeAuthorizationError(msg)
 | 
											
												
													
														|  | 
 |  | +                elif response.status_code == 429:
 | 
											
												
													
														|  | 
 |  | +                    raise InvokeRateLimitError(msg)
 | 
											
												
													
														|  | 
 |  | +                elif response.status_code == 500:
 | 
											
												
													
														|  | 
 |  | +                    raise InvokeServerUnavailableError(msg)
 | 
											
												
													
														|  | 
 |  | +                else:
 | 
											
												
													
														|  | 
 |  | +                    raise InvokeBadRequestError(msg)
 | 
											
												
													
														|  | 
 |  | +            except JSONDecodeError as e:
 | 
											
												
													
														|  | 
 |  | +                raise InvokeServerUnavailableError(
 | 
											
												
													
														|  | 
 |  | +                    f"Failed to convert response to json: {e} with text: {response.text}"
 | 
											
												
													
														|  | 
 |  | +                )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        try:
 | 
											
												
													
														|  | 
 |  | +            resp = response.json()
 | 
											
												
													
														|  | 
 |  | +            embeddings = resp["data"]
 | 
											
												
													
														|  | 
 |  | +            usage = resp["usage"]
 | 
											
												
													
														|  | 
 |  | +        except Exception as e:
 | 
											
												
													
														|  | 
 |  | +            raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        result = TextEmbeddingResult(
 | 
											
												
													
														|  | 
 |  | +            model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return result
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Get number of tokens for given prompt messages
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :param texts: texts to embed
 | 
											
												
													
														|  | 
 |  | +        :return:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def validate_credentials(self, model: str, credentials: dict) -> None:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Validate model credentials
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :return:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        try:
 | 
											
												
													
														|  | 
 |  | +            self._invoke(model=model, credentials=credentials, texts=["ping"])
 | 
											
												
													
														|  | 
 |  | +        except Exception as e:
 | 
											
												
													
														|  | 
 |  | +            raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    @property
 | 
											
												
													
														|  | 
 |  | +    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
 | 
											
												
													
														|  | 
 |  | +        return {
 | 
											
												
													
														|  | 
 |  | +            InvokeConnectionError: [InvokeConnectionError],
 | 
											
												
													
														|  | 
 |  | +            InvokeServerUnavailableError: [InvokeServerUnavailableError],
 | 
											
												
													
														|  | 
 |  | +            InvokeRateLimitError: [InvokeRateLimitError],
 | 
											
												
													
														|  | 
 |  | +            InvokeAuthorizationError: [InvokeAuthorizationError],
 | 
											
												
													
														|  | 
 |  | +            InvokeBadRequestError: [KeyError, InvokeBadRequestError],
 | 
											
												
													
														|  | 
 |  | +        }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Calculate response usage
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :param tokens: input tokens
 | 
											
												
													
														|  | 
 |  | +        :return: usage
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        # get input price info
 | 
											
												
													
														|  | 
 |  | +        input_price_info = self.get_price(
 | 
											
												
													
														|  | 
 |  | +            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # transform usage
 | 
											
												
													
														|  | 
 |  | +        usage = EmbeddingUsage(
 | 
											
												
													
														|  | 
 |  | +            tokens=tokens,
 | 
											
												
													
														|  | 
 |  | +            total_tokens=tokens,
 | 
											
												
													
														|  | 
 |  | +            unit_price=input_price_info.unit_price,
 | 
											
												
													
														|  | 
 |  | +            price_unit=input_price_info.unit,
 | 
											
												
													
														|  | 
 |  | +            total_price=input_price_info.total_amount,
 | 
											
												
													
														|  | 
 |  | +            currency=input_price_info.currency,
 | 
											
												
													
														|  | 
 |  | +            latency=time.perf_counter() - self.started_at,
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return usage
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        generate custom model entities from credentials
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        entity = AIModelEntity(
 | 
											
												
													
														|  | 
 |  | +            model=model,
 | 
											
												
													
														|  | 
 |  | +            label=I18nObject(en_US=model),
 | 
											
												
													
														|  | 
 |  | +            model_type=ModelType.TEXT_EMBEDDING,
 | 
											
												
													
														|  | 
 |  | +            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
 | 
											
												
													
														|  | 
 |  | +            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512"))},
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return entity
 |