provider.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from enum import Enum
  2. from sqlalchemy.dialects.postgresql import UUID
  3. from extensions.ext_database import db
  4. class ProviderType(Enum):
  5. CUSTOM = 'custom'
  6. SYSTEM = 'system'
  7. class ProviderName(Enum):
  8. OPENAI = 'openai'
  9. AZURE_OPENAI = 'azure_openai'
  10. ANTHROPIC = 'anthropic'
  11. COHERE = 'cohere'
  12. HUGGINGFACEHUB = 'huggingfacehub'
  13. @staticmethod
  14. def value_of(value):
  15. for member in ProviderName:
  16. if member.value == value:
  17. return member
  18. raise ValueError(f"No matching enum found for value '{value}'")
  19. class ProviderQuotaType(Enum):
  20. MONTHLY = 'monthly'
  21. TRIAL = 'trial'
  22. class Provider(db.Model):
  23. """
  24. Provider model representing the API providers and their configurations.
  25. """
  26. __tablename__ = 'providers'
  27. __table_args__ = (
  28. db.PrimaryKeyConstraint('id', name='provider_pkey'),
  29. db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'),
  30. db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
  31. )
  32. id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
  33. tenant_id = db.Column(UUID, nullable=False)
  34. provider_name = db.Column(db.String(40), nullable=False)
  35. provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
  36. encrypted_config = db.Column(db.Text, nullable=True)
  37. is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
  38. last_used = db.Column(db.DateTime, nullable=True)
  39. quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
  40. quota_limit = db.Column(db.Integer, nullable=True)
  41. quota_used = db.Column(db.Integer, default=0)
  42. created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
  43. updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
  44. def __repr__(self):
  45. return f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}', provider_type='{self.provider_type}')>"
  46. @property
  47. def token_is_set(self):
  48. """
  49. Returns True if the encrypted_config is not None, indicating that the token is set.
  50. """
  51. return self.encrypted_config is not None
  52. @property
  53. def is_enabled(self):
  54. """
  55. Returns True if the provider is enabled.
  56. """
  57. if self.provider_type == ProviderType.SYSTEM.value:
  58. return self.is_valid
  59. else:
  60. return self.is_valid and self.token_is_set