fake_model_provider.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from typing import Type
  2. from core.model_providers.models.base import BaseProviderModel
  3. from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode
  4. from core.model_providers.models.llm.openai_model import OpenAIModel
  5. from core.model_providers.providers.base import BaseModelProvider
  6. class FakeModelProvider(BaseModelProvider):
  7. @property
  8. def provider_name(self):
  9. return 'fake'
  10. def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
  11. return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}]
  12. def _get_text_generation_model_mode(self, model_name) -> str:
  13. return ModelMode.COMPLETION.value
  14. def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
  15. return OpenAIModel
  16. @classmethod
  17. def is_provider_credentials_valid_or_raise(cls, credentials: dict):
  18. pass
  19. @classmethod
  20. def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
  21. return credentials
  22. def get_provider_credentials(self, obfuscated: bool = False) -> dict:
  23. return {}
  24. @classmethod
  25. def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
  26. pass
  27. @classmethod
  28. def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
  29. credentials: dict) -> dict:
  30. return credentials
  31. def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
  32. return ModelKwargsRules()
  33. def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
  34. return {}