replicate_provider.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import json
  2. import logging
  3. from typing import Type
  4. import replicate
  5. from replicate.exceptions import ReplicateError
  6. from core.helper import encrypter
  7. from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
  8. from core.model_providers.models.llm.replicate_model import ReplicateModel
  9. from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
  10. from core.model_providers.models.base import BaseProviderModel
  11. from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
  12. from models.provider import ProviderType
  13. class ReplicateProvider(BaseModelProvider):
  14. @property
  15. def provider_name(self):
  16. """
  17. Returns the name of a provider.
  18. """
  19. return 'replicate'
  20. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  21. return []
  22. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  23. """
  24. Returns the model class.
  25. :param model_type:
  26. :return:
  27. """
  28. if model_type == ModelType.TEXT_GENERATION:
  29. model_class = ReplicateModel
  30. elif model_type == ModelType.EMBEDDINGS:
  31. model_class = ReplicateEmbedding
  32. else:
  33. raise NotImplementedError
  34. return model_class
  35. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  36. """
  37. get model parameter rules.
  38. :param model_name:
  39. :param model_type:
  40. :return:
  41. """
  42. model_credentials = self.get_model_credentials(model_name, model_type)
  43. model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
  44. try:
  45. version = model.versions.get(model_credentials['model_version'])
  46. except ReplicateError as e:
  47. raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
  48. f"cause: {e.__class__.__name__}:{str(e)}")
  49. except Exception as e:
  50. logging.exception("Model validate failed.")
  51. raise e
  52. model_kwargs_rules = ModelKwargsRules()
  53. for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
  54. if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
  55. if key == ['temperature', 'top_p']:
  56. kwarg_rule = KwargRule[float](
  57. type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
  58. min=float(value.get('minimum')) if value.get('minimum') is not None else None,
  59. max=float(value.get('maximum')) if value.get('maximum') is not None else None,
  60. default=float(value.get('default')) if value.get('default') is not None else None,
  61. precision = 2
  62. )
  63. if key == 'temperature':
  64. model_kwargs_rules.temperature = kwarg_rule
  65. else:
  66. model_kwargs_rules.top_p = kwarg_rule
  67. elif key in ['max_length', 'max_new_tokens']:
  68. model_kwargs_rules.max_tokens = KwargRule[int](
  69. alias=key,
  70. type=KwargRuleType.INTEGER.value,
  71. min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
  72. max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
  73. default=int(value.get('default')) if value.get('default') is not None else 500,
  74. precision = 0
  75. )
  76. return model_kwargs_rules
  77. @classmethod
  78. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  79. """
  80. check model credentials valid.
  81. :param model_name:
  82. :param model_type:
  83. :param credentials:
  84. """
  85. if 'replicate_api_token' not in credentials:
  86. raise CredentialsValidateFailedError('Replicate API Key must be provided.')
  87. if 'model_version' not in credentials:
  88. raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
  89. if model_name.count("/") != 1:
  90. raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
  91. 'format: {user_name}/{model_name}')
  92. version = credentials['model_version']
  93. try:
  94. model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
  95. rst = model.versions.get(version)
  96. if model_type == ModelType.EMBEDDINGS \
  97. and 'Embedding' not in rst.openapi_schema['components']['schemas']:
  98. raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
  99. elif model_type == ModelType.TEXT_GENERATION \
  100. and ('items' not in rst.openapi_schema['components']['schemas']['Output']
  101. or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
  102. or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
  103. raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
  104. except ReplicateError as e:
  105. raise CredentialsValidateFailedError(
  106. f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
  107. except Exception as e:
  108. logging.exception("Replicate config validation failed.")
  109. raise e
  110. @classmethod
  111. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  112. credentials: dict) -> dict:
  113. """
  114. encrypt model credentials for save.
  115. :param tenant_id:
  116. :param model_name:
  117. :param model_type:
  118. :param credentials:
  119. :return:
  120. """
  121. credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
  122. return credentials
  123. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  124. """
  125. get credentials for llm use.
  126. :param model_name:
  127. :param model_type:
  128. :param obfuscated:
  129. :return:
  130. """
  131. if self.provider.provider_type != ProviderType.CUSTOM.value:
  132. raise NotImplementedError
  133. provider_model = self._get_provider_model(model_name, model_type)
  134. if not provider_model.encrypted_config:
  135. return {
  136. 'replicate_api_token': None,
  137. }
  138. credentials = json.loads(provider_model.encrypted_config)
  139. if credentials['replicate_api_token']:
  140. credentials['replicate_api_token'] = encrypter.decrypt_token(
  141. self.provider.tenant_id,
  142. credentials['replicate_api_token']
  143. )
  144. if obfuscated:
  145. credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])
  146. return credentials
  147. @classmethod
  148. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  149. return
  150. @classmethod
  151. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  152. return {}
  153. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  154. return {}