provider_service.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from typing import Union
  2. from flask import current_app
  3. from core.llm.provider.llm_provider_service import LLMProviderService
  4. from models.account import Tenant
  5. from models.provider import *
  class ProviderService:
      """Tenant-scoped helpers for model provider records and provider configuration.

      Most methods are thin wrappers that build an ``LLMProviderService`` for
      ``(tenant.id, provider_name.value)`` and delegate to it; the remaining ones
      create ``Provider`` rows directly through ``db.session``.
      """

      @staticmethod
      def init_supported_provider(tenant, edition):
          """Ensure the tenant has a CUSTOM provider record for every supported provider.

          The supported providers are OpenAI and Azure OpenAI. Any missing CUSTOM
          record is created with ``is_valid=False``. All new rows are committed
          together, and nothing is committed when every record already exists.

          Args:
              tenant: Tenant whose ``id`` scopes the provider records.
              edition: Currently unused by this method.
          """
          providers = Provider.query.filter_by(tenant_id=tenant.id).all()

          # TODO: The cloud version needs to construct the data of the SYSTEM type

          # Only CUSTOM-typed rows count; a SYSTEM row for the same provider name
          # does not satisfy the requirement.
          existing_custom_names = {
              provider.provider_name
              for provider in providers
              if provider.provider_type == ProviderType.CUSTOM.value
          }

          # Order matters only for insertion order: OpenAI first, then Azure OpenAI.
          required_names = (ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value)
          missing_names = [name for name in required_names if name not in existing_custom_names]

          # Create default providers if they don't exist
          for name in missing_names:
              db.session.add(Provider(
                  tenant_id=tenant.id,
                  provider_name=name,
                  provider_type=ProviderType.CUSTOM.value,
                  is_valid=False
              ))

          if missing_names:
              db.session.commit()

      @staticmethod
      def get_obfuscated_api_key(tenant, provider_name: ProviderName):
          """Return the tenant's provider configs in obfuscated form.

          Delegates to ``LLMProviderService.get_provider_configs(obfuscated=True)``.
          Presumably secrets such as the API key are masked; confirm in that service.
          """
          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.get_provider_configs(obfuscated=True)

      @staticmethod
      def get_token_type(tenant, provider_name: ProviderName):
          """Return the token type reported by ``LLMProviderService.get_token_type()``."""
          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.get_token_type()

      @staticmethod
      def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict, str]):
          """Validate provider configs via ``LLMProviderService.config_validate``.

          Validation is skipped entirely (returns ``None``) when the Flask config
          flag ``DISABLE_PROVIDER_CONFIG_VALIDATION`` is truthy.

          Raises:
              KeyError: if ``DISABLE_PROVIDER_CONFIG_VALIDATION`` is not defined
                  in the app config (it is read with ``[]``, not ``.get``).
          """
          if current_app.config['DISABLE_PROVIDER_CONFIG_VALIDATION']:
              return

          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.config_validate(configs)

      @staticmethod
      def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict, str]):
          """Return ``configs`` encrypted by ``LLMProviderService.get_encrypted_token``."""
          llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
          return llm_provider_service.get_encrypted_token(configs)

      @staticmethod
      def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
                                 is_valid: bool = True):
          """Create a SYSTEM provider record on a trial quota for the tenant.

          This is a no-op unless the app config ``EDITION`` is ``'CLOUD'``. When a
          matching SYSTEM record already exists it is left unchanged; in particular,
          ``is_valid`` is not updated.

          Args:
              tenant: Tenant that owns the record.
              provider_name: Provider name string (defaults to the OpenAI value).
              is_valid: Validity flag applied only to a newly created record.

          Raises:
              KeyError: if ``EDITION`` is not defined in the app config.
              sqlalchemy.orm.exc.MultipleResultsFound: raised by ``one_or_none()``
                  if more than one matching SYSTEM row exists.
          """
          if current_app.config['EDITION'] != 'CLOUD':
              return

          provider = db.session.query(Provider).filter(
              Provider.tenant_id == tenant.id,
              Provider.provider_name == provider_name,
              Provider.provider_type == ProviderType.SYSTEM.value
          ).one_or_none()

          if not provider:
              provider = Provider(
                  tenant_id=tenant.id,
                  provider_name=provider_name,
                  provider_type=ProviderType.SYSTEM.value,
                  quota_type=ProviderQuotaType.TRIAL.value,
                  quota_limit=200,
                  encrypted_config='',
                  is_valid=is_valid,
              )
              db.session.add(provider)
              db.session.commit()