provider_service.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import Union
  2. from flask import current_app
  3. from core.llm.provider.llm_provider_service import LLMProviderService
  4. from models.account import Tenant
  5. from models.provider import *
  6. class ProviderService:
  7. @staticmethod
  8. def init_supported_provider(tenant, edition):
  9. """Initialize the model provider, check whether the supported provider has a record"""
  10. providers = Provider.query.filter_by(tenant_id=tenant.id).all()
  11. openai_provider_exists = False
  12. azure_openai_provider_exists = False
  13. # TODO: The cloud version needs to construct the data of the SYSTEM type
  14. for provider in providers:
  15. if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
  16. openai_provider_exists = True
  17. if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
  18. azure_openai_provider_exists = True
  19. # Initialize the model provider, check whether the supported provider has a record
  20. # Create default providers if they don't exist
  21. if not openai_provider_exists:
  22. openai_provider = Provider(
  23. tenant_id=tenant.id,
  24. provider_name=ProviderName.OPENAI.value,
  25. provider_type=ProviderType.CUSTOM.value,
  26. is_valid=False
  27. )
  28. db.session.add(openai_provider)
  29. if not azure_openai_provider_exists:
  30. azure_openai_provider = Provider(
  31. tenant_id=tenant.id,
  32. provider_name=ProviderName.AZURE_OPENAI.value,
  33. provider_type=ProviderType.CUSTOM.value,
  34. is_valid=False
  35. )
  36. db.session.add(azure_openai_provider)
  37. if not openai_provider_exists or not azure_openai_provider_exists:
  38. db.session.commit()
  39. @staticmethod
  40. def get_obfuscated_api_key(tenant, provider_name: ProviderName):
  41. llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
  42. return llm_provider_service.get_provider_configs(obfuscated=True)
  43. @staticmethod
  44. def get_token_type(tenant, provider_name: ProviderName):
  45. llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
  46. return llm_provider_service.get_token_type()
  47. @staticmethod
  48. def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict | str]):
  49. llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
  50. return llm_provider_service.config_validate(configs)
  51. @staticmethod
  52. def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict | str]):
  53. llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
  54. return llm_provider_service.get_encrypted_token(configs)
  55. @staticmethod
  56. def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
  57. is_valid: bool = True):
  58. if current_app.config['EDITION'] != 'CLOUD':
  59. return
  60. provider = db.session.query(Provider).filter(
  61. Provider.tenant_id == tenant.id,
  62. Provider.provider_name == provider_name,
  63. Provider.provider_type == ProviderType.SYSTEM.value
  64. ).one_or_none()
  65. if not provider:
  66. provider = Provider(
  67. tenant_id=tenant.id,
  68. provider_name=provider_name,
  69. provider_type=ProviderType.SYSTEM.value,
  70. quota_type=ProviderQuotaType.TRIAL.value,
  71. quota_limit=200,
  72. encrypted_config='',
  73. is_valid=is_valid,
  74. )
  75. db.session.add(provider)
  76. db.session.commit()