import json
import logging
from typing import Type

import replicate
from replicate.exceptions import ReplicateError

from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
    ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
from models.provider import ProviderType


class ReplicateProvider(BaseModelProvider):
    """Model provider backed by models hosted on Replicate.

    Credentials are stored per model (API token plus model version); the
    provider-level credential hooks are intentionally no-ops.
    """

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'replicate'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """Return the fixed model list, which is always empty for this provider."""
        return []

    def _get_text_generation_model_mode(self, model_name) -> str:
        """Return the chat mode for names ending in ``-chat``, completion mode otherwise."""
        return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type: kind of model requested.
        :return: ``ReplicateModel`` for text generation, ``ReplicateEmbedding`` for embeddings.
        :raises NotImplementedError: for any other model type.
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ReplicateModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = ReplicateEmbedding
        else:
            raise NotImplementedError
        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        Reads the ``Input`` schema of the configured model version and maps its
        numeric ``temperature``, ``top_p`` and ``max_length`` / ``max_new_tokens``
        properties onto the corresponding rule slots.

        :param model_name: Replicate model name.
        :param model_type: kind of model, used to look up the stored credentials.
        :return: the populated ``ModelKwargsRules``.
        :raises CredentialsValidateFailedError: when fetching the version raises ``ReplicateError``.
        """
        model_credentials = self.get_model_credentials(model_name, model_type)

        model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)

        try:
            version = model.versions.get(model_credentials['model_version'])
        except ReplicateError as e:
            raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
                                                 f"cause: {e.__class__.__name__}:{str(e)}") from e
        except Exception:
            logging.exception("Model validate failed.")
            raise

        model_kwargs_rules = ModelKwargsRules()
        input_properties = version.openapi_schema['components']['schemas']['Input']['properties']
        for key, value in input_properties.items():
            # Only numeric inputs can become rules; ``.get`` skips properties
            # that carry no ``type`` key instead of failing with KeyError.
            value_type = value.get('type')
            if key in ('debug', 'prompt') or value_type not in ('number', 'integer'):
                continue

            minimum = value.get('minimum')
            maximum = value.get('maximum')
            default = value.get('default')

            # Membership test: comparing ``key`` to a list with ``==`` is never
            # true, which left temperature / top_p permanently unconfigured.
            if key in ('temperature', 'top_p'):
                kwarg_rule = KwargRule[float](
                    type=KwargRuleType.FLOAT.value if value_type == 'number' else KwargRuleType.INTEGER.value,
                    min=float(minimum) if minimum is not None else None,
                    max=float(maximum) if maximum is not None else None,
                    default=float(default) if default is not None else None,
                    precision=2
                )
                if key == 'temperature':
                    model_kwargs_rules.temperature = kwarg_rule
                else:
                    model_kwargs_rules.top_p = kwarg_rule
            elif key in ('max_length', 'max_new_tokens'):
                # ``alias`` records which of the two schema names this model uses.
                model_kwargs_rules.max_tokens = KwargRule[int](
                    alias=key,
                    type=KwargRuleType.INTEGER.value,
                    min=int(minimum) if minimum is not None else 1,
                    max=int(maximum) if maximum is not None else 8000,
                    default=int(default) if default is not None else 500,
                    precision=0
                )

        return model_kwargs_rules

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name: must contain exactly one ``/``.
        :param model_type: embeddings or text generation; selects the schema check applied.
        :param credentials: must contain ``replicate_api_token`` and ``model_version``.
        :raises CredentialsValidateFailedError: when a required field is missing, the
            name is malformed, fetching the model/version raises ``ReplicateError``,
            or the version's schema does not match ``model_type``.
        """
        if 'replicate_api_token' not in credentials:
            raise CredentialsValidateFailedError('Replicate API Key must be provided.')

        if 'model_version' not in credentials:
            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')

        if model_name.count("/") != 1:
            raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
                                                 'format: {user_name}/{model_name}')

        version = credentials['model_version']

        try:
            model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
            rst = model.versions.get(version)

            if model_type == ModelType.EMBEDDINGS \
                    and 'Embedding' not in rst.openapi_schema['components']['schemas']:
                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
            elif model_type == ModelType.TEXT_GENERATION \
                    and ('items' not in rst.openapi_schema['components']['schemas']['Output']
                         or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
                         or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
        except ReplicateError as e:
            raise CredentialsValidateFailedError(
                f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}") from e
        except Exception:
            logging.exception("Replicate config validation failed.")
            raise

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        The ``replicate_api_token`` entry of ``credentials`` is replaced in place
        with its encrypted form, and the same dict is returned.

        :param tenant_id: tenant passed to the encrypter.
        :param model_name: unused here.
        :param model_type: unused here.
        :param credentials: dict containing ``replicate_api_token``.
        :return: ``credentials`` with the token encrypted.
        """
        credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name: model whose stored config is loaded.
        :param model_type: kind of model.
        :param obfuscated: when true, the decrypted token is obfuscated before being returned.
        :return: the stored credentials with the token decrypted, or
            ``{'replicate_api_token': None}`` when no config is stored.
        :raises NotImplementedError: when the provider is not of the custom type.
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'replicate_api_token': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['replicate_api_token']:
            credentials['replicate_api_token'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['replicate_api_token']
            )

            if obfuscated:
                credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """No-op: this provider performs no provider-level credential validation."""
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """Return an empty dict; nothing is stored at provider level."""
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """Return an empty dict; nothing is stored at provider level."""
        return {}