openai_provider.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import logging
  2. from typing import Optional, Union
  3. import openai
  4. from openai.error import AuthenticationError, OpenAIError
  5. from core.llm.moderation import Moderation
  6. from core.llm.provider.base import BaseProvider
  7. from core.llm.provider.errors import ValidateFailedError
  8. from models.provider import ProviderName
  9. class OpenAIProvider(BaseProvider):
  10. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  11. credentials = self.get_credentials(model_id)
  12. response = openai.Model.list(**credentials)
  13. return [{
  14. 'id': model['id'],
  15. 'name': model['id'],
  16. } for model in response['data']]
  17. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  18. """
  19. Returns the credentials for the given tenant_id and provider_name.
  20. """
  21. return {
  22. 'openai_api_key': self.get_provider_api_key(model_id=model_id)
  23. }
  24. def get_provider_name(self):
  25. return ProviderName.OPENAI
  26. def config_validate(self, config: Union[dict | str]):
  27. """
  28. Validates the given config.
  29. """
  30. try:
  31. Moderation(self.get_provider_name().value, config).moderate('test')
  32. except (AuthenticationError, OpenAIError) as ex:
  33. raise ValidateFailedError(str(ex))
  34. except Exception as ex:
  35. logging.exception('OpenAI config validation failed')
  36. raise ex