replicate_provider.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import json
  2. import logging
  3. from typing import Type
  4. import replicate
  5. from replicate.exceptions import ReplicateError
  6. from core.helper import encrypter
  7. from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
  8. ModelMode
  9. from core.model_providers.models.llm.replicate_model import ReplicateModel
  10. from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
  11. from core.model_providers.models.base import BaseProviderModel
  12. from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
  13. from models.provider import ProviderType
  14. class ReplicateProvider(BaseModelProvider):
  15. @property
  16. def provider_name(self):
  17. """
  18. Returns the name of a provider.
  19. """
  20. return 'replicate'
  21. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  22. return []
  23. def _get_text_generation_model_mode(self, model_name) -> str:
  24. return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
  25. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  26. """
  27. Returns the model class.
  28. :param model_type:
  29. :return:
  30. """
  31. if model_type == ModelType.TEXT_GENERATION:
  32. model_class = ReplicateModel
  33. elif model_type == ModelType.EMBEDDINGS:
  34. model_class = ReplicateEmbedding
  35. else:
  36. raise NotImplementedError
  37. return model_class
  38. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  39. """
  40. get model parameter rules.
  41. :param model_name:
  42. :param model_type:
  43. :return:
  44. """
  45. model_credentials = self.get_model_credentials(model_name, model_type)
  46. model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
  47. try:
  48. version = model.versions.get(model_credentials['model_version'])
  49. except ReplicateError as e:
  50. raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} not exists, "
  51. f"cause: {e.__class__.__name__}:{str(e)}")
  52. except Exception as e:
  53. logging.exception("Model validate failed.")
  54. raise e
  55. model_kwargs_rules = ModelKwargsRules()
  56. for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
  57. if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
  58. if key == ['temperature', 'top_p']:
  59. kwarg_rule = KwargRule[float](
  60. type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
  61. min=float(value.get('minimum')) if value.get('minimum') is not None else None,
  62. max=float(value.get('maximum')) if value.get('maximum') is not None else None,
  63. default=float(value.get('default')) if value.get('default') is not None else None,
  64. precision = 2
  65. )
  66. if key == 'temperature':
  67. model_kwargs_rules.temperature = kwarg_rule
  68. else:
  69. model_kwargs_rules.top_p = kwarg_rule
  70. elif key in ['max_length', 'max_new_tokens']:
  71. model_kwargs_rules.max_tokens = KwargRule[int](
  72. alias=key,
  73. type=KwargRuleType.INTEGER.value,
  74. min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
  75. max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
  76. default=int(value.get('default')) if value.get('default') is not None else 500,
  77. precision = 0
  78. )
  79. return model_kwargs_rules
  80. @classmethod
  81. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  82. """
  83. check model credentials valid.
  84. :param model_name:
  85. :param model_type:
  86. :param credentials:
  87. """
  88. if 'replicate_api_token' not in credentials:
  89. raise CredentialsValidateFailedError('Replicate API Key must be provided.')
  90. if 'model_version' not in credentials:
  91. raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
  92. if model_name.count("/") != 1:
  93. raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
  94. 'format: {user_name}/{model_name}')
  95. version = credentials['model_version']
  96. try:
  97. model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
  98. rst = model.versions.get(version)
  99. if model_type == ModelType.EMBEDDINGS \
  100. and 'Embedding' not in rst.openapi_schema['components']['schemas']:
  101. raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Embedding model.")
  102. elif model_type == ModelType.TEXT_GENERATION \
  103. and ('items' not in rst.openapi_schema['components']['schemas']['Output']
  104. or 'type' not in rst.openapi_schema['components']['schemas']['Output']['items']
  105. or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
  106. raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
  107. except ReplicateError as e:
  108. raise CredentialsValidateFailedError(
  109. f"Model {model_name}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
  110. except Exception as e:
  111. logging.exception("Replicate config validation failed.")
  112. raise e
  113. @classmethod
  114. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  115. credentials: dict) -> dict:
  116. """
  117. encrypt model credentials for save.
  118. :param tenant_id:
  119. :param model_name:
  120. :param model_type:
  121. :param credentials:
  122. :return:
  123. """
  124. credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
  125. return credentials
  126. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  127. """
  128. get credentials for llm use.
  129. :param model_name:
  130. :param model_type:
  131. :param obfuscated:
  132. :return:
  133. """
  134. if self.provider.provider_type != ProviderType.CUSTOM.value:
  135. raise NotImplementedError
  136. provider_model = self._get_provider_model(model_name, model_type)
  137. if not provider_model.encrypted_config:
  138. return {
  139. 'replicate_api_token': None,
  140. }
  141. credentials = json.loads(provider_model.encrypted_config)
  142. if credentials['replicate_api_token']:
  143. credentials['replicate_api_token'] = encrypter.decrypt_token(
  144. self.provider.tenant_id,
  145. credentials['replicate_api_token']
  146. )
  147. if obfuscated:
  148. credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token'])
  149. return credentials
  150. @classmethod
  151. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  152. return
  153. @classmethod
  154. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  155. return {}
  156. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  157. return {}