provider_service.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from typing import Union
  2. from flask import current_app
  3. from core.llm.provider.llm_provider_service import LLMProviderService
  4. from models.account import Tenant
  5. from models.provider import *
  class ProviderService:
      """Tenant-scoped helpers for LLM provider records and provider configs.

      Config-related methods delegate to ``LLMProviderService`` built from
      ``tenant.id`` and the provider name. The ``init_*``/``create_*`` methods
      read and write ``Provider`` rows through ``db.session``.
      """

      @staticmethod
      def init_supported_provider(tenant):
          """Initialize the model provider, check whether the supported provider has a record.

          Ensures ``tenant`` has a CUSTOM ``Provider`` row for each of OpenAI,
          Azure OpenAI and Anthropic. Missing rows are inserted with
          ``is_valid=False``; the session is committed only if something was added.

          :param tenant: tenant to initialize; only ``tenant.id`` is read.
          """
          need_init_provider_names = [
              ProviderName.OPENAI.value,
              ProviderName.AZURE_OPENAI.value,
              ProviderName.ANTHROPIC.value,
          ]

          providers = db.session.query(Provider).filter(
              Provider.tenant_id == tenant.id,
              Provider.provider_type == ProviderType.CUSTOM.value,
              Provider.provider_name.in_(need_init_provider_names)
          ).all()

          exists_provider_names = {provider.provider_name for provider in providers}

          # Keep the declared order (a plain set difference would yield an
          # arbitrary order) so rows are always inserted deterministically.
          not_exists_provider_names = [
              name for name in need_init_provider_names
              if name not in exists_provider_names
          ]
          if not not_exists_provider_names:
              return

          for provider_name in not_exists_provider_names:
              db.session.add(Provider(
                  tenant_id=tenant.id,
                  provider_name=provider_name,
                  provider_type=ProviderType.CUSTOM.value,
                  is_valid=False
              ))

          # Single commit for the whole batch of new placeholder rows.
          db.session.commit()

      @staticmethod
      def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
          """Return the provider's configs with ``obfuscated=True``.

          :param tenant: tenant whose configs are read; only ``tenant.id`` is used.
          :param provider_name: provider enum member; its ``.value`` is passed on.
          :param only_custom: forwarded to ``LLMProviderService.get_provider_configs``.
          :return: whatever ``LLMProviderService.get_provider_configs`` returns.
          """
          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)

      @staticmethod
      def get_token_type(tenant, provider_name: ProviderName):
          """Return ``LLMProviderService.get_token_type()`` for this tenant/provider.

          :param tenant: tenant to query; only ``tenant.id`` is used.
          :param provider_name: provider enum member; its ``.value`` is passed on.
          """
          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.get_token_type()

      @staticmethod
      def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict, str]):
          """Validate ``configs`` for the given provider.

          Skipped entirely (returns ``None`` without building the service) when
          the app config flag ``DISABLE_PROVIDER_CONFIG_VALIDATION`` is truthy.

          :param tenant: tenant the configs belong to; only ``tenant.id`` is used.
          :param provider_name: provider enum member; its ``.value`` is passed on.
          :param configs: config payload, either a dict or a string.
          :return: the result of ``LLMProviderService.config_validate`` (or
              ``None`` when validation is disabled).
          """
          if current_app.config['DISABLE_PROVIDER_CONFIG_VALIDATION']:
              return

          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.config_validate(configs)

      @staticmethod
      def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict, str]):
          """Return ``LLMProviderService.get_encrypted_token(configs)``.

          :param tenant: tenant the token belongs to; only ``tenant.id`` is used.
          :param provider_name: provider enum member; its ``.value`` is passed on.
          :param configs: config payload, either a dict or a string.
          """
          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.get_encrypted_token(configs)

      @staticmethod
      def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
                                 is_valid: bool = True):
          """Create a SYSTEM/TRIAL ``Provider`` row for ``tenant`` if it has none.

          Does nothing unless the app config ``EDITION`` is ``'CLOUD'``. An
          existing SYSTEM row for the same tenant and provider name is left
          untouched. The lookup uses ``one_or_none()``, which raises if more
          than one matching row exists.

          :param tenant: tenant to create the provider for.
          :param provider_name: provider name string stored on the row.
          :param quota_limit: trial quota limit stored on the new row.
          :param is_valid: initial ``is_valid`` flag for the new row.
          """
          if current_app.config['EDITION'] != 'CLOUD':
              return

          provider = db.session.query(Provider).filter(
              Provider.tenant_id == tenant.id,
              Provider.provider_name == provider_name,
              Provider.provider_type == ProviderType.SYSTEM.value
          ).one_or_none()

          if provider:
              return

          db.session.add(Provider(
              tenant_id=tenant.id,
              provider_name=provider_name,
              provider_type=ProviderType.SYSTEM.value,
              quota_type=ProviderQuotaType.TRIAL.value,
              quota_limit=quota_limit,
              encrypted_config='',
              is_valid=is_valid,
          ))
          db.session.commit()