浏览代码

feat: claude paid optimize (#890)

takatost 1 年之前
父节点
当前提交
9adbeadeec

+ 4 - 2
api/.env.example

@@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
 HOSTED_ANTHROPIC_ENABLED=false
 HOSTED_ANTHROPIC_API_BASE=
 HOSTED_ANTHROPIC_API_KEY=
-HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
 HOSTED_ANTHROPIC_PAID_ENABLED=false
 HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
-HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
+HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
+HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
 
 STRIPE_API_KEY=
 STRIPE_WEBHOOK_SECRET=

+ 5 - 2
api/commands.py

@@ -258,6 +258,8 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0
 
+    new_quota_limit = hosted_model_providers.anthropic.quota_limit
+
     page = 1
     while True:
         try:
@@ -265,6 +267,7 @@ def sync_anthropic_hosted_providers():
                 Provider.provider_name == 'anthropic',
                 Provider.provider_type == ProviderType.SYSTEM.value,
                 Provider.quota_type == ProviderQuotaType.TRIAL.value,
+                Provider.quota_limit != new_quota_limit
             ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
@@ -272,9 +275,9 @@ def sync_anthropic_hosted_providers():
         page += 1
         for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
+                           .format(provider.tenant_id, provider.quota_limit, provider.quota_used))
                 original_quota_limit = provider.quota_limit
-                new_quota_limit = hosted_model_providers.anthropic.quota_limit
                 division = math.ceil(new_quota_limit / 1000)
 
                 provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \

+ 10 - 6
api/config.py

@@ -57,10 +57,12 @@ DEFAULTS = {
     'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
-    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
     'HOSTED_ANTHROPIC_ENABLED': 'False',
     'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
-    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
+    'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
+    'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -211,7 +213,7 @@ class Config:
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
-        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
         self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
         self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
@@ -219,15 +221,17 @@ class Config:
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
         self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
-        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))
 
         self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
         self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
         self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
-        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
         self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
         self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
-        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
+        self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
+        self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
 
         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

+ 9 - 1
api/controllers/console/webhook/stripe.py

@@ -38,12 +38,20 @@ class StripeWebhookApi(Resource):
             logging.debug(event['data']['object']['payment_status'])
             logging.debug(event['data']['object']['metadata'])
 
+            session = stripe.checkout.Session.retrieve(
+                event['data']['object']['id'],
+                expand=['line_items'],
+            )
+
+            logging.debug(session.line_items['data'][0]['quantity'])
+
             # Fulfill the purchase...
             provider_checkout_service = ProviderCheckoutService()
 
             try:
-                provider_checkout_service.fulfill_provider_order(event)
+                provider_checkout_service.fulfill_provider_order(event, session.line_items)
             except Exception as e:
+
                 logging.debug(str(e))
                 return 'success', 200
 

+ 2 - 0
api/core/model_providers/models/llm/base.py

@@ -125,6 +125,8 @@ class BaseLLM(BaseProviderModel):
             completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
             total_tokens = prompt_tokens + completion_tokens
 
+        self.model_provider.update_last_used()
+
         if self.deduct_quota:
             self.model_provider.deduct_quota(total_tokens)
 

+ 2 - 0
api/core/model_providers/providers/anthropic_provider.py

@@ -183,6 +183,8 @@ class AnthropicProvider(BaseModelProvider):
             return {
                 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id,
                 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota,
+                'min_quantity': hosted_model_providers.anthropic.paid_min_quantity,
+                'max_quantity': hosted_model_providers.anthropic.paid_max_quantity,
             }
 
         return None

+ 5 - 1
api/core/model_providers/providers/hosted.py

@@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel):
     """Quota limit for the anthropic hosted model. 0 means unlimited."""
     paid_enabled: bool = False
     paid_stripe_price_id: str = None
-    paid_increase_quota: int = 1
+    paid_increase_quota: int = 1000000
+    paid_min_quantity: int = 20
+    paid_max_quantity: int = 100
 
 
 class HostedModelProviders(BaseModel):
@@ -73,4 +75,6 @@ def init_app(app: Flask):
             paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"),
             paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"),
             paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"),
+            paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
+            paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
         )

+ 3 - 2
api/core/model_providers/rules/anthropic.json

@@ -5,10 +5,11 @@
     ],
     "system_config": {
         "supported_quota_types": [
+            "paid",
             "trial"
         ],
-        "quota_unit": "times",
-        "quota_limit": 1000
+        "quota_unit": "tokens",
+        "quota_limit": 600000
     },
     "model_flexibility": "fixed"
 }

+ 25 - 9
api/services/provider_checkout_service.py

@@ -39,6 +39,8 @@ class ProviderCheckoutService:
             raise ValueError(f'provider name {provider_name} not support payment')
 
         payment_product_id = payment_info['product_id']
+        payment_min_quantity = payment_info['min_quantity']
+        payment_max_quantity = payment_info['max_quantity']
 
         # create provider order
         provider_order = ProviderOrder(
@@ -53,18 +55,29 @@ class ProviderCheckoutService:
         db.session.add(provider_order)
         db.session.flush()
 
+        line_item = {
+            'price': f'{payment_product_id}',
+            'quantity': payment_min_quantity
+        }
+
+        if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity:
+            line_item['adjustable_quantity'] = {
+                'enabled': True,
+                'minimum': payment_min_quantity,
+                'maximum': payment_max_quantity
+            }
+
         try:
             # create stripe checkout session
             checkout_session = stripe.checkout.Session.create(
                 line_items=[
-                    {
-                        'price': f'{payment_product_id}',
-                        'quantity': 1,
-                    },
+                    line_item
                 ],
                 mode='payment',
-                success_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=succeeded',
-                cancel_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=cancelled',
+                success_url=current_app.config.get("CONSOLE_WEB_URL")
+                            + f'?provider_name={provider_name}&payment_result=succeeded',
+                cancel_url=current_app.config.get("CONSOLE_WEB_URL")
+                           + f'?provider_name={provider_name}&payment_result=cancelled',
                 automatic_tax={'enabled': True},
             )
         except Exception as e:
@@ -76,7 +89,7 @@ class ProviderCheckoutService:
 
         return ProviderCheckout(checkout_session)
 
-    def fulfill_provider_order(self, event):
+    def fulfill_provider_order(self, event, line_items):
         provider_order = db.session.query(ProviderOrder) \
             .filter(ProviderOrder.payment_id == event['data']['object']['id']) \
             .first()
@@ -85,7 +98,8 @@ class ProviderCheckoutService:
             raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')
 
         if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
-            raise ValueError(f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
+            raise ValueError(
+                f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}')
 
         provider_order.transaction_id = event['data']['object']['payment_intent']
         provider_order.currency = event['data']['object']['currency']
@@ -110,10 +124,12 @@ class ProviderCheckoutService:
         model_provider = model_provider_class(provider=provider)
         payment_info = model_provider.get_payment_info()
 
+        quantity = line_items['data'][0]['quantity']
+
         if not payment_info:
             increase_quota = 0
         else:
-            increase_quota = int(payment_info['increase_quota'])
+            increase_quota = int(payment_info['increase_quota']) * quantity
 
         if increase_quota > 0:
             provider.quota_limit += increase_quota

+ 4 - 2
api/services/provider_service.py

@@ -133,12 +133,14 @@ class ProviderService:
                         provider_parameter_dict[key]['is_valid'] = provider.is_valid
                         provider_parameter_dict[key]['quota_used'] = provider.quota_used
                         provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
-                        provider_parameter_dict[key]['last_used'] = provider.last_used
+                        provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
+                            if provider.last_used else None
                 elif provider.provider_type == ProviderType.CUSTOM.value \
                         and ProviderType.CUSTOM.value in provider_parameter_dict:
                     # if custom
                     key = ProviderType.CUSTOM.value
-                    provider_parameter_dict[key]['last_used'] = provider.last_used
+                    provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
+                            if provider.last_used else None
                     provider_parameter_dict[key]['is_valid'] = provider.is_valid
 
                     if model_provider_rule['model_flexibility'] == 'fixed':