llm_provider_service.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from typing import Optional, Union
  2. from core.llm.provider.anthropic_provider import AnthropicProvider
  3. from core.llm.provider.azure_provider import AzureProvider
  4. from core.llm.provider.base import BaseProvider
  5. from core.llm.provider.huggingface_provider import HuggingfaceProvider
  6. from core.llm.provider.openai_provider import OpenAIProvider
  7. from models.provider import Provider
  8. class LLMProviderService:
  9. def __init__(self, tenant_id: str, provider_name: str):
  10. self.provider = self.init_provider(tenant_id, provider_name)
  11. def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider:
  12. if provider_name == 'openai':
  13. return OpenAIProvider(tenant_id)
  14. elif provider_name == 'azure_openai':
  15. return AzureProvider(tenant_id)
  16. elif provider_name == 'anthropic':
  17. return AnthropicProvider(tenant_id)
  18. elif provider_name == 'huggingface':
  19. return HuggingfaceProvider(tenant_id)
  20. else:
  21. raise Exception('provider {} not found'.format(provider_name))
  22. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  23. return self.provider.get_models(model_id)
  24. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  25. return self.provider.get_credentials(model_id)
  26. def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
  27. return self.provider.get_provider_configs(obfuscated)
  28. def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
  29. return self.provider.get_provider(prefer_custom)
  30. def config_validate(self, config: Union[dict | str]):
  31. """
  32. Validates the given config.
  33. :param config:
  34. :raises: ValidateFailedError
  35. """
  36. return self.provider.config_validate(config)
  37. def get_token_type(self):
  38. return self.provider.get_token_type()
  39. def get_encrypted_token(self, config: Union[dict | str]):
  40. return self.provider.get_encrypted_token(config)