anthropic_provider.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import json
  2. import logging
  3. from json import JSONDecodeError
  4. from typing import Type, Optional
  5. import anthropic
  6. from flask import current_app
  7. from langchain.schema import HumanMessage
  8. from core.helper import encrypter
  9. from core.model_providers.models.base import BaseProviderModel
  10. from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
  11. from core.model_providers.models.entity.provider import ModelFeature
  12. from core.model_providers.models.llm.anthropic_model import AnthropicModel
  13. from core.model_providers.models.llm.base import ModelType
  14. from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
  15. from core.model_providers.providers.hosted import hosted_model_providers
  16. from core.third_party.langchain.llms.anthropic_llm import AnthropicLLM
  17. from models.provider import ProviderType
  18. class AnthropicProvider(BaseModelProvider):
  19. @property
  20. def provider_name(self):
  21. """
  22. Returns the name of a provider.
  23. """
  24. return 'anthropic'
  25. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  26. if model_type == ModelType.TEXT_GENERATION:
  27. return [
  28. {
  29. 'id': 'claude-instant-1',
  30. 'name': 'claude-instant-1',
  31. 'mode': ModelMode.CHAT.value,
  32. },
  33. {
  34. 'id': 'claude-2',
  35. 'name': 'claude-2',
  36. 'mode': ModelMode.CHAT.value,
  37. 'features': [
  38. ModelFeature.AGENT_THOUGHT.value
  39. ]
  40. },
  41. ]
  42. else:
  43. return []
  44. def _get_text_generation_model_mode(self, model_name) -> str:
  45. return ModelMode.CHAT.value
  46. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  47. """
  48. Returns the model class.
  49. :param model_type:
  50. :return:
  51. """
  52. if model_type == ModelType.TEXT_GENERATION:
  53. model_class = AnthropicModel
  54. else:
  55. raise NotImplementedError
  56. return model_class
  57. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  58. """
  59. get model parameter rules.
  60. :param model_name:
  61. :param model_type:
  62. :return:
  63. """
  64. return ModelKwargsRules(
  65. temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
  66. top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
  67. presence_penalty=KwargRule[float](enabled=False),
  68. frequency_penalty=KwargRule[float](enabled=False),
  69. max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
  70. )
  71. @classmethod
  72. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  73. """
  74. Validates the given credentials.
  75. """
  76. if 'anthropic_api_key' not in credentials:
  77. raise CredentialsValidateFailedError('Anthropic API Key must be provided.')
  78. try:
  79. credential_kwargs = {
  80. 'anthropic_api_key': credentials['anthropic_api_key']
  81. }
  82. if 'anthropic_api_url' in credentials:
  83. credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url']
  84. chat_llm = AnthropicLLM(
  85. model='claude-instant-1',
  86. max_tokens_to_sample=10,
  87. temperature=0,
  88. default_request_timeout=60,
  89. **credential_kwargs
  90. )
  91. messages = [
  92. HumanMessage(
  93. content="ping"
  94. )
  95. ]
  96. chat_llm(messages)
  97. except anthropic.APIConnectionError as ex:
  98. raise CredentialsValidateFailedError(str(ex))
  99. except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
  100. raise CredentialsValidateFailedError(str(ex))
  101. except Exception as ex:
  102. logging.exception('Anthropic config validation failed')
  103. raise ex
  104. @classmethod
  105. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  106. credentials['anthropic_api_key'] = encrypter.encrypt_token(tenant_id, credentials['anthropic_api_key'])
  107. return credentials
  108. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  109. if self.provider.provider_type == ProviderType.CUSTOM.value:
  110. try:
  111. credentials = json.loads(self.provider.encrypted_config)
  112. except JSONDecodeError:
  113. credentials = {
  114. 'anthropic_api_url': None,
  115. 'anthropic_api_key': None
  116. }
  117. if credentials['anthropic_api_key']:
  118. credentials['anthropic_api_key'] = encrypter.decrypt_token(
  119. self.provider.tenant_id,
  120. credentials['anthropic_api_key']
  121. )
  122. if obfuscated:
  123. credentials['anthropic_api_key'] = encrypter.obfuscated_token(credentials['anthropic_api_key'])
  124. if 'anthropic_api_url' not in credentials:
  125. credentials['anthropic_api_url'] = None
  126. return credentials
  127. else:
  128. if hosted_model_providers.anthropic:
  129. return {
  130. 'anthropic_api_url': hosted_model_providers.anthropic.api_base,
  131. 'anthropic_api_key': hosted_model_providers.anthropic.api_key,
  132. }
  133. else:
  134. return {
  135. 'anthropic_api_url': None,
  136. 'anthropic_api_key': None
  137. }
  138. @classmethod
  139. def is_provider_type_system_supported(cls) -> bool:
  140. if current_app.config['EDITION'] != 'CLOUD':
  141. return False
  142. if hosted_model_providers.anthropic:
  143. return True
  144. return False
  145. def should_deduct_quota(self):
  146. if hosted_model_providers.anthropic and \
  147. hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > -1:
  148. return True
  149. return False
  150. def get_payment_info(self) -> Optional[dict]:
  151. """
  152. get product info if it payable.
  153. :return:
  154. """
  155. if hosted_model_providers.anthropic \
  156. and hosted_model_providers.anthropic.paid_enabled:
  157. return {
  158. 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
  159. 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
  160. 'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
  161. 'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
  162. }
  163. return None
  164. @classmethod
  165. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  166. """
  167. check model credentials valid.
  168. :param model_name:
  169. :param model_type:
  170. :param credentials:
  171. """
  172. return
  173. @classmethod
  174. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  175. credentials: dict) -> dict:
  176. """
  177. encrypt model credentials for save.
  178. :param tenant_id:
  179. :param model_name:
  180. :param model_type:
  181. :param credentials:
  182. :return:
  183. """
  184. return {}
  185. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  186. """
  187. get credentials for llm use.
  188. :param model_name:
  189. :param model_type:
  190. :param obfuscated:
  191. :return:
  192. """
  193. return self.get_provider_credentials(obfuscated)