hosting_configuration.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import os
  2. from typing import Optional
  3. from flask import Flask
  4. from pydantic import BaseModel
  5. from core.entities.provider_entities import QuotaUnit
  6. from models.provider import ProviderQuotaType
  7. class HostingQuota(BaseModel):
  8. quota_type: ProviderQuotaType
  9. restrict_llms: list[str] = []
  10. class TrialHostingQuota(HostingQuota):
  11. quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL
  12. quota_limit: int = 0
  13. """Quota limit for the hosting provider models. -1 means unlimited."""
  14. class PaidHostingQuota(HostingQuota):
  15. quota_type: ProviderQuotaType = ProviderQuotaType.PAID
  16. stripe_price_id: str = None
  17. increase_quota: int = 1
  18. min_quantity: int = 20
  19. max_quantity: int = 100
  20. class FreeHostingQuota(HostingQuota):
  21. quota_type: ProviderQuotaType = ProviderQuotaType.FREE
  22. class HostingProvider(BaseModel):
  23. enabled: bool = False
  24. credentials: Optional[dict] = None
  25. quota_unit: Optional[QuotaUnit] = None
  26. quotas: list[HostingQuota] = []
  27. class HostedModerationConfig(BaseModel):
  28. enabled: bool = False
  29. providers: list[str] = []
  30. class HostingConfiguration:
  31. provider_map: dict[str, HostingProvider] = {}
  32. moderation_config: HostedModerationConfig = None
  33. def init_app(self, app: Flask):
  34. if app.config.get('EDITION') != 'CLOUD':
  35. return
  36. self.provider_map["openai"] = self.init_openai()
  37. self.provider_map["anthropic"] = self.init_anthropic()
  38. self.provider_map["minimax"] = self.init_minimax()
  39. self.provider_map["spark"] = self.init_spark()
  40. self.provider_map["zhipuai"] = self.init_zhipuai()
  41. self.moderation_config = self.init_moderation_config()
  42. def init_openai(self) -> HostingProvider:
  43. quota_unit = QuotaUnit.TIMES
  44. if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
  45. credentials = {
  46. "openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"),
  47. }
  48. if os.environ.get("HOSTED_OPENAI_API_BASE"):
  49. credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE")
  50. if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"):
  51. credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION")
  52. quotas = []
  53. hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
  54. if hosted_quota_limit != -1 or hosted_quota_limit > 0:
  55. trial_quota = TrialHostingQuota(
  56. quota_limit=hosted_quota_limit,
  57. restrict_llms=[
  58. "gpt-3.5-turbo",
  59. "gpt-3.5-turbo-1106",
  60. "gpt-3.5-turbo-instruct",
  61. "gpt-3.5-turbo-16k",
  62. "text-davinci-003"
  63. ]
  64. )
  65. quotas.append(trial_quota)
  66. if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get(
  67. "HOSTED_OPENAI_PAID_ENABLED").lower() == 'true':
  68. paid_quota = PaidHostingQuota(
  69. stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"),
  70. increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")),
  71. min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")),
  72. max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1"))
  73. )
  74. quotas.append(paid_quota)
  75. return HostingProvider(
  76. enabled=True,
  77. credentials=credentials,
  78. quota_unit=quota_unit,
  79. quotas=quotas
  80. )
  81. return HostingProvider(
  82. enabled=False,
  83. quota_unit=quota_unit,
  84. )
  85. def init_anthropic(self) -> HostingProvider:
  86. quota_unit = QuotaUnit.TOKENS
  87. if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true':
  88. credentials = {
  89. "anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"),
  90. }
  91. if os.environ.get("HOSTED_ANTHROPIC_API_BASE"):
  92. credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE")
  93. quotas = []
  94. hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0"))
  95. if hosted_quota_limit != -1 or hosted_quota_limit > 0:
  96. trial_quota = TrialHostingQuota(
  97. quota_limit=hosted_quota_limit
  98. )
  99. quotas.append(trial_quota)
  100. if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get(
  101. "HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true':
  102. paid_quota = PaidHostingQuota(
  103. stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
  104. increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")),
  105. min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")),
  106. max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100"))
  107. )
  108. quotas.append(paid_quota)
  109. return HostingProvider(
  110. enabled=True,
  111. credentials=credentials,
  112. quota_unit=quota_unit,
  113. quotas=quotas
  114. )
  115. return HostingProvider(
  116. enabled=False,
  117. quota_unit=quota_unit,
  118. )
  119. def init_minimax(self) -> HostingProvider:
  120. quota_unit = QuotaUnit.TOKENS
  121. if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true':
  122. quotas = [FreeHostingQuota()]
  123. return HostingProvider(
  124. enabled=True,
  125. credentials=None, # use credentials from the provider
  126. quota_unit=quota_unit,
  127. quotas=quotas
  128. )
  129. return HostingProvider(
  130. enabled=False,
  131. quota_unit=quota_unit,
  132. )
  133. def init_spark(self) -> HostingProvider:
  134. quota_unit = QuotaUnit.TOKENS
  135. if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true':
  136. quotas = [FreeHostingQuota()]
  137. return HostingProvider(
  138. enabled=True,
  139. credentials=None, # use credentials from the provider
  140. quota_unit=quota_unit,
  141. quotas=quotas
  142. )
  143. return HostingProvider(
  144. enabled=False,
  145. quota_unit=quota_unit,
  146. )
  147. def init_zhipuai(self) -> HostingProvider:
  148. quota_unit = QuotaUnit.TOKENS
  149. if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true':
  150. quotas = [FreeHostingQuota()]
  151. return HostingProvider(
  152. enabled=True,
  153. credentials=None, # use credentials from the provider
  154. quota_unit=quota_unit,
  155. quotas=quotas
  156. )
  157. return HostingProvider(
  158. enabled=False,
  159. quota_unit=quota_unit,
  160. )
  161. def init_moderation_config(self) -> HostedModerationConfig:
  162. if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \
  163. and os.environ.get("HOSTED_MODERATION_PROVIDERS"):
  164. return HostedModerationConfig(
  165. enabled=True,
  166. providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',')
  167. )
  168. return HostedModerationConfig(
  169. enabled=False
  170. )