Skip to content

Commit bfb21c1

Browse files
authored
Merge pull request #58 from amazeeio/dev
Add Stripe integration
2 parents 0863a5b + c198e71 commit bfb21c1

40 files changed

+5268
-76
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ make backend-test-cov # Run backend tests with coverage report
6767
make backend-test-regex # Waits for a string which pytest will parse to only collect a subset of tests
6868
```
6969

70+
### 💳 Testing Stripe
71+
See [[tests/stripe_test_trigger.md]] for detailed instructions on testing integration with Stripe for billing purposes.
72+
7073
### Frontend Tests
7174
```bash
7275
make frontend-test # Run frontend tests

app/api/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
auth_logger = logging.getLogger(__name__)
6868

6969
router = APIRouter(
70-
tags=["Authentication"]
70+
tags=["auth"]
7171
)
7272

7373
def get_cookie_domain():

app/api/billing.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
from fastapi import APIRouter, Depends, HTTPException, status, Request, Response, BackgroundTasks
2+
from sqlalchemy.orm import Session
3+
import logging
4+
import os
5+
from app.db.database import get_db
6+
from app.core.security import check_specific_team_admin
7+
from app.db.models import DBTeam, DBSystemSecret
8+
from app.schemas.models import PricingTableSession
9+
from app.services.stripe import (
10+
decode_stripe_event,
11+
create_portal_session,
12+
create_stripe_customer,
13+
get_pricing_table_secret,
14+
)
15+
from app.core.worker import handle_stripe_event_background
16+
17+
# Configure logger
18+
logger = logging.getLogger(__name__)
19+
BILLING_WEBHOOK_KEY = "stripe_webhook_secret"
20+
BILLING_WEBHOOK_ROUTE = "/billing/events"
21+
22+
router = APIRouter(
23+
tags=["billing"]
24+
)
25+
26+
# TODO: Verify where we want this to be
27+
def get_return_url(team_id: int) -> str:
28+
"""
29+
Get the return URL for the team dashboard.
30+
31+
Args:
32+
team_id: The ID of the team to get the return URL for
33+
34+
Returns:
35+
The return URL for the team dashboard
36+
"""
37+
# Get the frontend URL from environment
38+
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000")
39+
return f"{frontend_url}/teams/{team_id}/dashboard"
40+
41+
42+
@router.post("/events")
43+
async def handle_events(
44+
request: Request,
45+
background_tasks: BackgroundTasks,
46+
db: Session = Depends(get_db)
47+
):
48+
"""
49+
Handle Stripe webhook events.
50+
51+
This endpoint processes various Stripe events like subscription updates,
52+
payment successes, and failures. Events are processed asynchronously in the background.
53+
"""
54+
try:
55+
# Get the webhook secret from database or environment variable
56+
if os.getenv("WEBHOOK_SIG"):
57+
webhook_secret = os.getenv("WEBHOOK_SIG")
58+
else:
59+
webhook_secret = db.query(DBSystemSecret).filter(
60+
DBSystemSecret.key == BILLING_WEBHOOK_KEY
61+
).first().value
62+
63+
if not webhook_secret:
64+
logger.error("Stripe webhook secret not configured")
65+
# 404 for security reasons - if we're not accepting traffic here, then it doesn't exist
66+
raise HTTPException(
67+
status_code=status.HTTP_404_NOT_FOUND,
68+
detail="Not found"
69+
)
70+
71+
# Get the raw request body
72+
payload = await request.body()
73+
signature = request.headers.get("stripe-signature")
74+
75+
event = decode_stripe_event(payload, signature, webhook_secret)
76+
77+
# Add the event handling to background tasks
78+
background_tasks.add_task(handle_stripe_event_background, event, db)
79+
80+
return Response(
81+
status_code=status.HTTP_200_OK,
82+
content="Webhook received and processing started"
83+
)
84+
85+
except Exception as e:
86+
logger.error(f"Error handling Stripe event: {str(e)}")
87+
raise HTTPException(
88+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
89+
detail="Error processing webhook"
90+
)
91+
92+
@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)])
93+
async def get_portal(
94+
team_id: int,
95+
db: Session = Depends(get_db)
96+
):
97+
"""
98+
Create a Stripe Customer Portal session for team subscription management and redirect to it.
99+
If the team doesn't have a Stripe customer ID, one will be created first.
100+
101+
Args:
102+
team_id: The ID of the team to create the portal session for
103+
104+
Returns:
105+
Redirects to the Stripe Customer Portal URL
106+
"""
107+
# Get the team
108+
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
109+
if not team:
110+
raise HTTPException(
111+
status_code=status.HTTP_404_NOT_FOUND,
112+
detail="Team not found"
113+
)
114+
if not team.stripe_customer_id:
115+
raise HTTPException(
116+
status_code=status.HTTP_400_BAD_REQUEST,
117+
detail="Team has not been registered with Stripe"
118+
)
119+
120+
try:
121+
return_url = get_return_url(team_id)
122+
# Create portal session using the service
123+
portal_url = await create_portal_session(team.stripe_customer_id, return_url)
124+
125+
return Response(
126+
status_code=status.HTTP_303_SEE_OTHER,
127+
headers={"Location": portal_url}
128+
)
129+
except Exception as e:
130+
logger.error(f"Error creating portal session: {str(e)}")
131+
raise HTTPException(
132+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
133+
detail="Error creating portal session"
134+
)
135+
136+
@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession)
137+
async def get_pricing_table_session(
138+
team_id: int,
139+
db: Session = Depends(get_db)
140+
):
141+
"""
142+
Create a Stripe Customer Session client secret for team subscription management.
143+
If the team doesn't have a Stripe customer ID, one will be created first.
144+
145+
Args:
146+
team_id: The ID of the team to create the customer session for
147+
148+
Returns:
149+
JSON response containing the client secret
150+
"""
151+
# Get the team
152+
team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
153+
if not team:
154+
raise HTTPException(
155+
status_code=status.HTTP_404_NOT_FOUND,
156+
detail="Team not found"
157+
)
158+
159+
try:
160+
# Create Stripe customer if one doesn't exist
161+
if not team.stripe_customer_id:
162+
logger.info(f"Creating Stripe customer for team {team.id}")
163+
team.stripe_customer_id = await create_stripe_customer(team)
164+
db.add(team)
165+
db.commit()
166+
167+
logger.info(f"Stripe ID is {team.stripe_customer_id}")
168+
# Create customer session using the service
169+
client_secret = await get_pricing_table_secret(team.stripe_customer_id)
170+
171+
return PricingTableSession(client_secret=client_secret)
172+
except Exception as e:
173+
logger.error(f"Error creating customer session: {str(e)}")
174+
raise HTTPException(
175+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
176+
detail="Error creating customer session"
177+
)

app/api/pricing_tables.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from fastapi import APIRouter, Depends, HTTPException, status
2+
from sqlalchemy.orm import Session
3+
from datetime import datetime, UTC
4+
5+
from app.db.database import get_db
6+
from app.db.models import DBSystemSecret
7+
from app.core.security import check_system_admin, get_role_min_team_admin
8+
from app.schemas.models import PricingTableCreate, PricingTableResponse
9+
10+
router = APIRouter(
11+
tags=["pricing-tables"]
12+
)
13+
14+
@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
15+
@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
16+
async def create_pricing_table(
17+
pricing_table: PricingTableCreate,
18+
db: Session = Depends(get_db)
19+
):
20+
"""
21+
Create or update the current pricing table. Only accessible by system admin users.
22+
There can only be one active pricing table at a time.
23+
"""
24+
# Check if a pricing table already exists
25+
existing_table = db.query(DBSystemSecret).filter(DBSystemSecret.key == "CurrentPricingTable").first()
26+
27+
if existing_table:
28+
# Update existing table
29+
existing_table.value = pricing_table.pricing_table_id
30+
existing_table.updated_at = datetime.now(UTC)
31+
db.commit()
32+
db.refresh(existing_table)
33+
return PricingTableResponse(
34+
pricing_table_id=existing_table.value,
35+
updated_at=existing_table.updated_at
36+
)
37+
else:
38+
# Create new table
39+
db_table = DBSystemSecret(
40+
key="CurrentPricingTable",
41+
value=pricing_table.pricing_table_id,
42+
description="Current Stripe pricing table ID",
43+
created_at=datetime.now(UTC)
44+
)
45+
db.add(db_table)
46+
db.commit()
47+
db.refresh(db_table)
48+
return PricingTableResponse(
49+
pricing_table_id=db_table.value,
50+
updated_at=db_table.created_at
51+
)
52+
53+
@router.get("", response_model=PricingTableResponse, dependencies=[Depends(get_role_min_team_admin)])
54+
@router.get("/", response_model=PricingTableResponse, dependencies=[Depends(get_role_min_team_admin)])
55+
async def get_pricing_table(
56+
db: Session = Depends(get_db)
57+
):
58+
"""
59+
Get the current pricing table ID. Only accessible by team admin users or higher privileges.
60+
"""
61+
pricing_table = db.query(DBSystemSecret).filter(DBSystemSecret.key == "CurrentPricingTable").first()
62+
if not pricing_table:
63+
raise HTTPException(
64+
status_code=status.HTTP_404_NOT_FOUND,
65+
detail="No pricing table found"
66+
)
67+
return PricingTableResponse(
68+
pricing_table_id=pricing_table.value,
69+
updated_at=pricing_table.updated_at or pricing_table.created_at
70+
)
71+
72+
@router.delete("", dependencies=[Depends(check_system_admin)])
73+
@router.delete("/", dependencies=[Depends(check_system_admin)])
74+
async def delete_pricing_table(
75+
db: Session = Depends(get_db)
76+
):
77+
"""
78+
Delete the current pricing table. Only accessible by system admin users.
79+
"""
80+
pricing_table = db.query(DBSystemSecret).filter(DBSystemSecret.key == "CurrentPricingTable").first()
81+
if not pricing_table:
82+
raise HTTPException(
83+
status_code=status.HTTP_404_NOT_FOUND,
84+
detail="No pricing table found"
85+
)
86+
87+
db.delete(pricing_table)
88+
db.commit()
89+
90+
return {"message": "Pricing table deleted successfully"}

0 commit comments

Comments
 (0)