replicate_provider.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import json
  2. import logging
  3. from typing import Type
  4. import replicate
  5. from replicate.exceptions import ReplicateError
  6. from core.helper import encrypter
  7. from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
  8. from core.model_providers.models.llm.replicate_model import ReplicateModel
  9. from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
  10. from core.model_providers.models.base import BaseProviderModel
  11. from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
  12. from models.provider import ProviderType
  13. class ReplicateProvider(BaseModelProvider):
  14. @property
  15. def provider_name(self):
  16. """
  17. Returns the name of a provider.
  18. """
  19. return 'replicate'
  20. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  21. return []
  22. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  23. """
  24. Returns the model class.
  25. :param model_type:
  26. :return:
  27. """
  28. if model_type == ModelType.TEXT_GENERATION:
  29. model_class = ReplicateModel
  30. elif model_type == ModelType.EMBEDDINGS:
  31. model_class = ReplicateEmbedding
  32. else:
  33. raise NotImplementedError
  34. return model_class
  35. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  36. """
  37. get model parameter rules.
  38. :param model_name:
  39. :param model_type:
  40. :return:
  41. """
  42. model_credentials = self.get_model_credentials(model_name, model_type)
  43. model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
  44. try:
  45. version = model.versions.get(model_credentials['model_version'])
  46. except ReplicateError as e:
  47. raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
  48. f"cause: {e.__class__.__name__}:{str(e)}")
  49. except Exception as e:
  50. logging.exception("Model validate failed.")
  51. raise e
  52. model_kwargs_rules = ModelKwargsRules()
  53. for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
  54. if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
  55. if key == ['temperature', 'top_p']:
  56. kwarg_rule = KwargRule[float](
  57. type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
  58. min=float(value.get('minimum')) if value.get('minimum') is not None else None,
  59. max=float(value.get('maximum')) if value.get('maximum') is not None else None,
  60. default=float(value.get('default')) if value.get('default') is not None else None,
  61. )
  62. if key == 'temperature':
  63. model_kwargs_rules.temperature = kwarg_rule
  64. else:
  65. model_kwargs_rules.top_p = kwarg_rule
  66. elif key in ['max_length', 'max_new_tokens']:
  67. model_kwargs_rules.max_tokens = KwargRule[int](
  68. alias=key,
  69. type=KwargRuleType.INTEGER.value,
  70. min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
  71. max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
  72. default=int(value.get('default')) if value.get('default') is not None else 500,
  73. )
  74. return model_kwargs_rules
  75. @classmethod
  76. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  77. """
  78. check model credentials valid.
  79. :param model_name:
  80. :param model_type:
  81. :param credentials:
  82. """
  83. if 'replicate_api_token' not in credentials:
  84. raise CredentialsValidateFailedError('Replicate API Key must be provided.')
  85. if 'model_version' not in credentials:
  86. raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
  87. if model_name.count("/") != 1:
  88. raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
  89. 'format: {user_name}/{model_name}')
  90. version = credentials['model_version']
  91. try:
  92. model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
  93. rst = model.versions.get(version)
  94. if model_type == ModelType.EMBEDDINGS \
  95. and 'Embedding' not in rst.openapi_schema['components']['schemas']:
  96. raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
  97. elif model_type == ModelType.TEXT_GENERATION \
  98. and ('items' not in rst.openapi_schema['components']['schemas']['Output']
  99. or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
  100. or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
  101. raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
  102. except ReplicateError as e:
  103. raise CredentialsValidateFailedError(
  104. f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
  105. except Exception as e:
  106. logging.exception("Replicate config validation failed.")
  107. raise e
  108. @classmethod
  109. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  110. credentials: dict) -> dict:
  111. """
  112. encrypt model credentials for save.
  113. :param tenant_id:
  114. :param model_name:
  115. :param model_type:
  116. :param credentials:
  117. :return:
  118. """
  119. credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
  120. return credentials
  121. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  122. """
  123. get credentials for llm use.
  124. :param model_name:
  125. :param model_type:
  126. :param obfuscated:
  127. :return:
  128. """
  129. if self.provider.provider_type != ProviderType.CUSTOM.value:
  130. raise NotImplementedError
  131. provider_model = self._get_provider_model(model_name, model_type)
  132. if not provider_model.encrypted_config:
  133. return {
  134. 'replicate_api_token': None,
  135. }
  136. credentials = json.loads(provider_model.encrypted_config)
  137. if credentials['replicate_api_token']:
  138. credentials['replicate_api_token'] = encrypter.decrypt_token(
  139. self.provider.tenant_id,
  140. credentials['replicate_api_token']
  141. )
  142. if obfuscated:
  143. credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])
  144. return credentials
  145. @classmethod
  146. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  147. return
  148. @classmethod
  149. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  150. return {}
  151. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  152. return {}