diff --git a/api/.env.example b/api/.env.example index e8e81d993a..2af1fddad5 100644 --- a/api/.env.example +++ b/api/.env.example @@ -117,10 +117,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200 HOSTED_ANTHROPIC_ENABLED=false HOSTED_ANTHROPIC_API_BASE= HOSTED_ANTHROPIC_API_KEY= -HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000 +HOSTED_ANTHROPIC_QUOTA_LIMIT=600000 HOSTED_ANTHROPIC_PAID_ENABLED=false HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID= -HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1 +HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000 +HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20 +HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100 STRIPE_API_KEY= STRIPE_WEBHOOK_SECRET= \ No newline at end of file diff --git a/api/commands.py b/api/commands.py index caa5e1ee20..a86e61e059 100644 --- a/api/commands.py +++ b/api/commands.py @@ -258,6 +258,8 @@ def sync_anthropic_hosted_providers(): click.echo(click.style('Start sync anthropic hosted providers.', fg='green')) count = 0 + new_quota_limit = hosted_model_providers.anthropic.quota_limit + page = 1 while True: try: @@ -265,6 +267,7 @@ def sync_anthropic_hosted_providers(): Provider.provider_name == 'anthropic', Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == ProviderQuotaType.TRIAL.value, + Provider.quota_limit != new_quota_limit ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100) except NotFound: break @@ -272,9 +275,9 @@ def sync_anthropic_hosted_providers(): page += 1 for provider in providers: try: - click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id)) + click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}' + .format(provider.tenant_id, provider.quota_limit, provider.quota_used)) original_quota_limit = provider.quota_limit - new_quota_limit = hosted_model_providers.anthropic.quota_limit division = math.ceil(new_quota_limit / 1000) provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \ diff --git a/api/config.py b/api/config.py index aaa45dfe26..d0b5f39cac 100644 --- a/api/config.py +++ b/api/config.py @@ -57,10 +57,12 @@ DEFAULTS = { 'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1, 'HOSTED_AZURE_OPENAI_ENABLED': 'False', 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, - 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000, + 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, 'HOSTED_ANTHROPIC_ENABLED': 'False', 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', - 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1, + 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000, + 'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20, + 'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100, 'TENANT_DOCUMENT_COUNT': 100, 'CLEAN_DAY_SETTING': 30, 'UPLOAD_FILE_SIZE_LIMIT': 15, @@ -211,7 +213,7 @@ class Config: self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY') self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE') self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') - self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT') + self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID') self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA')) @@ -219,15 +221,17 @@ class Config: self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE') - self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT') + self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')) self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED') self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE') self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') - self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT') + self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')) self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID') - self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA') + self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')) + self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY')) + self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY')) self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') diff --git a/api/controllers/console/webhook/stripe.py b/api/controllers/console/webhook/stripe.py index da906b0dc8..15cce34723 100644 --- a/api/controllers/console/webhook/stripe.py +++ b/api/controllers/console/webhook/stripe.py @@ -38,12 +38,20 @@ class StripeWebhookApi(Resource): logging.debug(event['data']['object']['payment_status']) logging.debug(event['data']['object']['metadata']) + session = stripe.checkout.Session.retrieve( + event['data']['object']['id'], + expand=['line_items'], + ) + + logging.debug(session.line_items['data'][0]['quantity']) + # Fulfill the purchase... provider_checkout_service = ProviderCheckoutService() try: - provider_checkout_service.fulfill_provider_order(event) + provider_checkout_service.fulfill_provider_order(event, session.line_items) except Exception as e: + logging.debug(str(e)) return 'success', 200 diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index d9216f6b26..a00ea87504 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -125,6 +125,8 @@ class BaseLLM(BaseProviderModel): completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]) total_tokens = prompt_tokens + completion_tokens + self.model_provider.update_last_used() + if self.deduct_quota: self.model_provider.deduct_quota(total_tokens) diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py index 8daeff44e7..8bab7bb251 100644 --- a/api/core/model_providers/providers/anthropic_provider.py +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -183,6 +183,8 @@ class AnthropicProvider(BaseModelProvider): return { 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id, 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota, + 'min_quantity': hosted_model_providers.anthropic.paid_min_quantity, + 'max_quantity': hosted_model_providers.anthropic.paid_max_quantity, } return None diff --git a/api/core/model_providers/providers/hosted.py b/api/core/model_providers/providers/hosted.py index b34153d0ab..a5f1ce83b6 100644 --- a/api/core/model_providers/providers/hosted.py +++ b/api/core/model_providers/providers/hosted.py @@ -31,7 +31,9 @@ class HostedAnthropic(BaseModel): """Quota limit for the anthropic hosted model. 0 means unlimited.""" paid_enabled: bool = False paid_stripe_price_id: str = None - paid_increase_quota: int = 1 + paid_increase_quota: int = 1000000 + paid_min_quantity: int = 20 + paid_max_quantity: int = 100 class HostedModelProviders(BaseModel): @@ -73,4 +75,6 @@ def init_app(app: Flask): paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"), paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"), paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"), + paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"), + paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"), ) diff --git a/api/core/model_providers/rules/anthropic.json b/api/core/model_providers/rules/anthropic.json index 3b0ef372ee..8e0bee4425 100644 --- a/api/core/model_providers/rules/anthropic.json +++ b/api/core/model_providers/rules/anthropic.json @@ -5,10 +5,11 @@ ], "system_config": { "supported_quota_types": [ + "paid", "trial" ], - "quota_unit": "times", - "quota_limit": 1000 + "quota_unit": "tokens", + "quota_limit": 600000 }, "model_flexibility": "fixed" } \ No newline at end of file diff --git a/api/services/provider_checkout_service.py b/api/services/provider_checkout_service.py index 80391dfac1..4268acf657 100644 --- a/api/services/provider_checkout_service.py +++ b/api/services/provider_checkout_service.py @@ -39,6 +39,8 @@ class ProviderCheckoutService: raise ValueError(f'provider name {provider_name} not support payment') payment_product_id = payment_info['product_id'] + payment_min_quantity = payment_info['min_quantity'] + payment_max_quantity = payment_info['max_quantity'] # create provider order provider_order = ProviderOrder( @@ -53,18 +55,29 @@ class ProviderCheckoutService: db.session.add(provider_order) db.session.flush() + line_item = { + 'price': f'{payment_product_id}', + 'quantity': payment_min_quantity + } + + if payment_min_quantity > 1 and payment_max_quantity != payment_min_quantity: + line_item['adjustable_quantity'] = { + 'enabled': True, + 'minimum': payment_min_quantity, + 'maximum': payment_max_quantity + } + try: # create stripe checkout session checkout_session = stripe.checkout.Session.create( line_items=[ - { - 'price': f'{payment_product_id}', - 'quantity': 1, - }, + line_item ], mode='payment', - success_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=succeeded', - cancel_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=cancelled', + success_url=current_app.config.get("CONSOLE_WEB_URL") + + f'?provider_name={provider_name}&payment_result=succeeded', + cancel_url=current_app.config.get("CONSOLE_WEB_URL") + + f'?provider_name={provider_name}&payment_result=cancelled', automatic_tax={'enabled': True}, ) except Exception as e: @@ -76,7 +89,7 @@ class ProviderCheckoutService: return ProviderCheckout(checkout_session) - def fulfill_provider_order(self, event): + def fulfill_provider_order(self, event, line_items): provider_order = db.session.query(ProviderOrder) \ .filter(ProviderOrder.payment_id == event['data']['object']['id']) \ .first() @@ -85,7 +98,8 @@ class ProviderCheckoutService: raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}') if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value: - raise ValueError(f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}') + raise ValueError( + f'provider order payment status is not wait pay, payment id: {event["data"]["object"]["id"]}') provider_order.transaction_id = event['data']['object']['payment_intent'] provider_order.currency = event['data']['object']['currency'] @@ -110,10 +124,12 @@ class ProviderCheckoutService: model_provider = model_provider_class(provider=provider) payment_info = model_provider.get_payment_info() + quantity = line_items['data'][0]['quantity'] + if not payment_info: increase_quota = 0 else: - increase_quota = int(payment_info['increase_quota']) + increase_quota = int(payment_info['increase_quota']) * quantity if increase_quota > 0: provider.quota_limit += increase_quota diff --git a/api/services/provider_service.py b/api/services/provider_service.py index 9a4ae5ee73..703ec51b6b 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -133,12 +133,14 @@ class ProviderService: provider_parameter_dict[key]['is_valid'] = provider.is_valid provider_parameter_dict[key]['quota_used'] = provider.quota_used provider_parameter_dict[key]['quota_limit'] = provider.quota_limit - provider_parameter_dict[key]['last_used'] = provider.last_used + provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \ + if provider.last_used else None elif provider.provider_type == ProviderType.CUSTOM.value \ and ProviderType.CUSTOM.value in provider_parameter_dict: # if custom key = ProviderType.CUSTOM.value - provider_parameter_dict[key]['last_used'] = provider.last_used + provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \ + if provider.last_used else None provider_parameter_dict[key]['is_valid'] = provider.is_valid if model_provider_rule['model_flexibility'] == 'fixed':