|  | @@ -0,0 +1,209 @@
 | 
											
												
													
														|  | 
 |  | +import json
 | 
											
												
													
														|  | 
 |  | +import time
 | 
											
												
													
														|  | 
 |  | +from typing import Optional
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +import boto3
 | 
											
												
													
														|  | 
 |  | +from botocore.config import Config
 | 
											
												
													
														|  | 
 |  | +from botocore.exceptions import (
 | 
											
												
													
														|  | 
 |  | +    ClientError,
 | 
											
												
													
														|  | 
 |  | +    EndpointConnectionError,
 | 
											
												
													
														|  | 
 |  | +    NoRegionError,
 | 
											
												
													
														|  | 
 |  | +    ServiceNotInRegionError,
 | 
											
												
													
														|  | 
 |  | +    UnknownServiceError,
 | 
											
												
													
														|  | 
 |  | +)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.entities.model_entities import PriceType
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.errors.invoke import (
 | 
											
												
													
														|  | 
 |  | +    InvokeAuthorizationError,
 | 
											
												
													
														|  | 
 |  | +    InvokeBadRequestError,
 | 
											
												
													
														|  | 
 |  | +    InvokeConnectionError,
 | 
											
												
													
														|  | 
 |  | +    InvokeError,
 | 
											
												
													
														|  | 
 |  | +    InvokeRateLimitError,
 | 
											
												
													
														|  | 
 |  | +    InvokeServerUnavailableError,
 | 
											
												
													
														|  | 
 |  | +)
 | 
											
												
													
														|  | 
 |  | +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class BedrockTextEmbeddingModel(TextEmbeddingModel):
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _invoke(self, model: str, credentials: dict,
 | 
											
												
													
														|  | 
 |  | +                texts: list[str], user: Optional[str] = None) \
 | 
											
												
													
														|  | 
 |  | +            -> TextEmbeddingResult:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Invoke text embedding model
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :param texts: texts to embed
 | 
											
												
													
														|  | 
 |  | +        :param user: unique user id
 | 
											
												
													
														|  | 
 |  | +        :return: embeddings result
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        client_config = Config(
 | 
											
												
													
														|  | 
 |  | +            region_name=credentials["aws_region"]
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        bedrock_runtime = boto3.client(
 | 
											
												
													
														|  | 
 |  | +            service_name='bedrock-runtime',
 | 
											
												
													
														|  | 
 |  | +            config=client_config,
 | 
											
												
													
														|  | 
 |  | +            aws_access_key_id=credentials["aws_access_key_id"],
 | 
											
												
													
														|  | 
 |  | +            aws_secret_access_key=credentials["aws_secret_access_key"]
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        embeddings = []
 | 
											
												
													
														|  | 
 |  | +        token_usage = 0
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        model_prefix = model.split('.')[0]
 | 
											
												
													
														|  | 
 |  | +        if model_prefix == "amazon":
 | 
											
												
													
														|  | 
 |  | +            for text in texts:
 | 
											
												
													
														|  | 
 |  | +                body = {
 | 
											
												
													
														|  | 
 |  | +                    "inputText": text,
 | 
											
												
													
														|  | 
 |  | +                }
 | 
											
												
													
														|  | 
 |  | +                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
 | 
											
												
													
														|  | 
 |  | +                embeddings.extend([response_body.get('embedding')])
 | 
											
												
													
														|  | 
 |  | +                token_usage += response_body.get('inputTextTokenCount')
 | 
											
												
													
														|  | 
 |  | +            result = TextEmbeddingResult(
 | 
											
												
													
														|  | 
 |  | +                model=model,
 | 
											
												
													
														|  | 
 |  | +                embeddings=embeddings,
 | 
											
												
													
														|  | 
 |  | +                usage=self._calc_response_usage(
 | 
											
												
													
														|  | 
 |  | +                    model=model,
 | 
											
												
													
														|  | 
 |  | +                    credentials=credentials,
 | 
											
												
													
														|  | 
 |  | +                    tokens=token_usage
 | 
											
												
													
														|  | 
 |  | +                )
 | 
											
												
													
														|  | 
 |  | +            )
 | 
											
												
													
														|  | 
 |  | +        else:
 | 
											
												
													
														|  | 
 |  | +            raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return result
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Get number of tokens for given prompt messages
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :param texts: texts to embed
 | 
											
												
													
														|  | 
 |  | +        :return:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        num_tokens = 0
 | 
											
												
													
														|  | 
 |  | +        for text in texts:
 | 
											
												
													
														|  | 
 |  | +            num_tokens += self._get_num_tokens_by_gpt2(text)
 | 
											
												
													
														|  | 
 |  | +        return num_tokens
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def validate_credentials(self, model: str, credentials: dict) -> None:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Validate model credentials
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :return:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +    
 | 
											
												
													
														|  | 
 |  | +    @property
 | 
											
												
													
														|  | 
 |  | +    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Map model invoke error to unified error
 | 
											
												
													
														|  | 
 |  | +        The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
 | 
											
												
													
														|  | 
 |  | +        The value is the md = genai.GenerativeModel(model)error type thrown by the model,
 | 
											
												
													
														|  | 
 |  | +        which needs to be converted into a unified error type for the caller.
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        return {
 | 
											
												
													
														|  | 
 |  | +            InvokeConnectionError: [],
 | 
											
												
													
														|  | 
 |  | +            InvokeServerUnavailableError: [],
 | 
											
												
													
														|  | 
 |  | +            InvokeRateLimitError: [],
 | 
											
												
													
														|  | 
 |  | +            InvokeAuthorizationError: [],
 | 
											
												
													
														|  | 
 |  | +            InvokeBadRequestError: []
 | 
											
												
													
														|  | 
 |  | +        }
 | 
											
												
													
														|  | 
 |  | +    
 | 
											
												
													
														|  | 
 |  | +    def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Create payload for bedrock api call depending on model provider
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        payload = dict()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if model_prefix == "amazon":
 | 
											
												
													
														|  | 
 |  | +            payload['inputText'] = texts
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    
 | 
											
												
													
														|  | 
 |  | +    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Calculate response usage
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param model: model name
 | 
											
												
													
														|  | 
 |  | +        :param credentials: model credentials
 | 
											
												
													
														|  | 
 |  | +        :param tokens: input tokens
 | 
											
												
													
														|  | 
 |  | +        :return: usage
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        # get input price info
 | 
											
												
													
														|  | 
 |  | +        input_price_info = self.get_price(
 | 
											
												
													
														|  | 
 |  | +            model=model,
 | 
											
												
													
														|  | 
 |  | +            credentials=credentials,
 | 
											
												
													
														|  | 
 |  | +            price_type=PriceType.INPUT,
 | 
											
												
													
														|  | 
 |  | +            tokens=tokens
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # transform usage
 | 
											
												
													
														|  | 
 |  | +        usage = EmbeddingUsage(
 | 
											
												
													
														|  | 
 |  | +            tokens=tokens,
 | 
											
												
													
														|  | 
 |  | +            total_tokens=tokens,
 | 
											
												
													
														|  | 
 |  | +            unit_price=input_price_info.unit_price,
 | 
											
												
													
														|  | 
 |  | +            price_unit=input_price_info.unit,
 | 
											
												
													
														|  | 
 |  | +            total_price=input_price_info.total_amount,
 | 
											
												
													
														|  | 
 |  | +            currency=input_price_info.currency,
 | 
											
												
													
														|  | 
 |  | +            latency=time.perf_counter() - self.started_at
 | 
											
												
													
														|  | 
 |  | +        )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return usage
 | 
											
												
													
														|  | 
 |  | +    
 | 
											
												
													
														|  | 
 |  | +    def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Map client error to invoke error
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        :param error_code: error code
 | 
											
												
													
														|  | 
 |  | +        :param error_msg: error message
 | 
											
												
													
														|  | 
 |  | +        :return: invoke error
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        if error_code == "AccessDeniedException":
 | 
											
												
													
														|  | 
 |  | +            return InvokeAuthorizationError(error_msg)
 | 
											
												
													
														|  | 
 |  | +        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
 | 
											
												
													
														|  | 
 |  | +            return InvokeBadRequestError(error_msg)
 | 
											
												
													
														|  | 
 |  | +        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
 | 
											
												
													
														|  | 
 |  | +            return InvokeRateLimitError(error_msg)
 | 
											
												
													
														|  | 
 |  | +        elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
 | 
											
												
													
														|  | 
 |  | +            return InvokeServerUnavailableError(error_msg)
 | 
											
												
													
														|  | 
 |  | +        elif error_code == "ModelStreamErrorException":
 | 
											
												
													
														|  | 
 |  | +            return InvokeConnectionError(error_msg)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        return InvokeError(error_msg)
 | 
											
												
													
														|  | 
 |  | +    
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
 | 
											
												
													
														|  | 
 |  | +        accept = 'application/json' 
 | 
											
												
													
														|  | 
 |  | +        content_type = 'application/json'
 | 
											
												
													
														|  | 
 |  | +        try:
 | 
											
												
													
														|  | 
 |  | +            response = bedrock_runtime.invoke_model(
 | 
											
												
													
														|  | 
 |  | +                body=json.dumps(body), 
 | 
											
												
													
														|  | 
 |  | +                modelId=model, 
 | 
											
												
													
														|  | 
 |  | +                accept=accept, 
 | 
											
												
													
														|  | 
 |  | +                contentType=content_type
 | 
											
												
													
														|  | 
 |  | +            )
 | 
											
												
													
														|  | 
 |  | +            response_body = json.loads(response.get('body').read().decode('utf-8'))
 | 
											
												
													
														|  | 
 |  | +            return response_body
 | 
											
												
													
														|  | 
 |  | +        except ClientError as ex:
 | 
											
												
													
														|  | 
 |  | +            error_code = ex.response['Error']['Code']
 | 
											
												
													
														|  | 
 |  | +            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
 | 
											
												
													
														|  | 
 |  | +            raise self._map_client_to_invoke_error(error_code, full_error_msg)
 | 
											
												
													
														|  | 
 |  | +        
 | 
											
												
													
														|  | 
 |  | +        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
 | 
											
												
													
														|  | 
 |  | +            raise InvokeConnectionError(str(ex))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        except UnknownServiceError as ex:
 | 
											
												
													
														|  | 
 |  | +            raise InvokeServerUnavailableError(str(ex))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        except Exception as ex:
 | 
											
												
													
														|  | 
 |  | +            raise InvokeError(str(ex))
 |