Skip to content

Commit 027f7a1

Browse files
committed
Update product application logic to include start date
- Modified the `apply_product_for_team` function to accept a `start_date` parameter, allowing for accurate tracking of the last payment date. - Updated the `handle_stripe_event_background` function to pass the correct start date when applying products for new subscriptions and renewals. - Adjusted related tests to ensure compatibility with the new function signature and validate the correct application of products with the specified start date.
1 parent 7644c47 commit 027f7a1

File tree

5 files changed

+36
-19
lines changed

5 files changed

+36
-19
lines changed

app/core/resource_limits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from sqlalchemy.orm import Session
2-
from sqlalchemy import func, and_, or_
2+
from sqlalchemy import func, or_
33
from app.db.models import DBTeam, DBUser, DBPrivateAIKey, DBTeamProduct, DBProduct
44
from fastapi import HTTPException, status
55
from typing import Optional

app/core/worker.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
import logging
66
from collections import defaultdict
77
from app.core.resource_limits import get_token_restrictions
8-
from app.services.stripe import get_product_id_from_session, get_product_id_from_subscription, known_events, subscription_success_events, session_failure_events, subscription_failure_events, invoice_failure_events, invoice_success_events
8+
from app.services.stripe import (
9+
get_product_id_from_session,
10+
get_product_id_from_subscription,
11+
known_events,
12+
subscription_success_events,
13+
session_failure_events,
14+
subscription_failure_events,
15+
invoice_failure_events,
16+
invoice_success_events
17+
)
918

1019
logger = logging.getLogger(__name__)
1120

@@ -28,12 +37,14 @@ async def handle_stripe_event_background(event, db: Session):
2837
if event_type in subscription_success_events:
2938
# new subscription
3039
product_id = await get_product_id_from_subscription(event_object.id)
31-
await apply_product_for_team(db, customer_id, product_id)
40+
start_date = datetime.fromtimestamp(event_object.start_date, tz=UTC)
41+
await apply_product_for_team(db, customer_id, product_id, start_date)
3242
elif event_type in invoice_success_events:
3343
# subscription renewed
3444
subscription = event_object.parent.subscription_details.subscription
3545
product_id = await get_product_id_from_subscription(subscription)
36-
await apply_product_for_team(db, customer_id, product_id)
46+
start_date = datetime.fromtimestamp(event_object.period_start, tz=UTC)
47+
await apply_product_for_team(db, customer_id, product_id, start_date)
3748
# Failure Events
3849
elif event_type in session_failure_events:
3950
product_id = await get_product_id_from_session(event_object.id)
@@ -50,7 +61,7 @@ async def handle_stripe_event_background(event, db: Session):
5061
logger.error(f"Error in background event handler: {str(e)}")
5162

5263

53-
async def apply_product_for_team(db: Session, customer_id: str, product_id: str):
64+
async def apply_product_for_team(db: Session, customer_id: str, product_id: str, start_date: datetime):
5465
"""
5566
Apply a product to a team and update their last payment date.
5667
Also extends all team keys and sets their max budgets via LiteLLM service.
@@ -77,7 +88,7 @@ async def apply_product_for_team(db: Session, customer_id: str, product_id: str)
7788
return
7889

7990
# Update the last payment date
80-
team.last_payment = datetime.now(UTC)
91+
team.last_payment = start_date
8192

8293
# Check if the product is already active for the team
8394
existing_association = db.query(DBTeamProduct).filter(

app/services/stripe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from typing import Optional, Union
21
import stripe
32
import os
43
import logging
54
from urllib.parse import urljoin
65
from fastapi import HTTPException, status
76
from sqlalchemy.orm import Session
8-
from stripe._list_object import ListObject
97
from app.db.models import DBTeam, DBSystemSecret
108

119
# Configure logger

tests/test_resource_limits.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
from app.db.models import DBUser, DBProduct, DBTeamProduct, DBPrivateAIKey
33
from datetime import datetime, UTC, timedelta
44
from fastapi import HTTPException
5-
from app.core.resource_limits import check_key_limits, check_team_user_limit, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
5+
from app.core.resource_limits import (
6+
check_key_limits,
7+
check_team_user_limit,
8+
check_vector_db_limits,
9+
get_token_restrictions,
10+
DEFAULT_KEY_DURATION,
11+
DEFAULT_MAX_SPEND,
12+
DEFAULT_RPM_PER_KEY
13+
)
614

715
def test_add_user_within_product_limit(db, test_team, test_product):
816
"""Test adding a user when within product user limit"""

tests/test_worker.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ async def test_apply_product_success(db, test_team, test_product):
180180
db.commit()
181181

182182
# Apply product to team
183-
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id)
183+
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id, datetime.now(UTC))
184184

185185
# Refresh team from database
186186
db.refresh(test_team)
@@ -200,7 +200,7 @@ async def test_apply_product_team_not_found(db, test_product):
200200
THEN: The operation completes without error
201201
"""
202202
# Try to apply product to non-existent team
203-
await apply_product_for_team(db, "cus_nonexistent", test_product.id)
203+
await apply_product_for_team(db, "cus_nonexistent", test_product.id, datetime.now(UTC))
204204
# No assertions needed as function should complete without error
205205

206206
@pytest.mark.asyncio
@@ -217,7 +217,7 @@ async def test_apply_product_product_not_found(db, test_team):
217217
db.commit()
218218

219219
# Try to apply non-existent product
220-
await apply_product_for_team(db, test_team.stripe_customer_id, "prod_nonexistent")
220+
await apply_product_for_team(db, test_team.stripe_customer_id, "prod_nonexistent", datetime.now(UTC))
221221
# No assertions needed as function should complete without error
222222

223223
@pytest.mark.asyncio
@@ -257,7 +257,7 @@ async def test_apply_product_multiple_products(db, test_team, test_product):
257257

258258
# Apply each product to the team
259259
for product in products:
260-
await apply_product_for_team(db, test_team.stripe_customer_id, product.id)
260+
await apply_product_for_team(db, test_team.stripe_customer_id, product.id, datetime.now(UTC))
261261

262262
# Refresh team from database
263263
db.refresh(test_team)
@@ -283,14 +283,14 @@ async def test_apply_product_already_active(db, test_team, test_product):
283283
db.refresh(test_team) # Refresh to ensure we have the latest data
284284

285285
# First apply the product
286-
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id)
286+
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id, datetime.now(UTC))
287287

288288
# Get the initial last payment date
289289
db.refresh(test_team)
290290
initial_last_payment = test_team.last_payment
291291

292292
# Apply the same product again
293-
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id)
293+
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id, datetime.now(UTC))
294294

295295
# Refresh team from database
296296
db.refresh(test_team)
@@ -357,7 +357,7 @@ async def test_apply_product_extends_keys_and_sets_budget(mock_litellm, db, test
357357
mock_instance.set_key_restrictions = AsyncMock()
358358

359359
# Apply product to team
360-
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id)
360+
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id, datetime.now(UTC))
361361

362362
# Verify LiteLLM service was initialized with correct region settings
363363
mock_litellm.assert_called_once_with(
@@ -400,7 +400,7 @@ async def test_remove_product_success(db, test_team, test_product):
400400
db.commit()
401401

402402
# First apply the product to ensure it exists
403-
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id)
403+
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id, datetime.now(UTC))
404404

405405
# Remove the product
406406
await remove_product_from_team(db, test_team.stripe_customer_id, test_product.id)
@@ -491,8 +491,8 @@ async def test_remove_product_multiple_products(db, test_team, test_product):
491491
db.commit()
492492

493493
# Apply both products to the team
494-
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id)
495-
await apply_product_for_team(db, test_team.stripe_customer_id, second_product.id)
494+
await apply_product_for_team(db, test_team.stripe_customer_id, test_product.id, datetime.now(UTC))
495+
await apply_product_for_team(db, test_team.stripe_customer_id, second_product.id, datetime.now(UTC))
496496

497497
# Remove only the first product
498498
await remove_product_from_team(db, test_team.stripe_customer_id, test_product.id)

0 commit comments

Comments
 (0)