Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d875a8e
Add endpoint to extend private AI key duration
PhiRho May 14, 2025
a4906f5
Update dependencies in requirements files
PhiRho May 16, 2025
674d332
Add endpoint to retrieve private AI key details
PhiRho May 16, 2025
659abff
Add billing functionality with Stripe integration
PhiRho May 16, 2025
bc3fbd5
Refactor API routers to standardize tags
PhiRho May 16, 2025
66146fb
Add product management functionality
PhiRho May 16, 2025
d491226
Enhance Stripe billing integration and configuration
PhiRho May 19, 2025
6385f16
Implement product management and user limits functionality
PhiRho May 20, 2025
d48e46e
Add key limit checks and enhance user limit validation
PhiRho May 20, 2025
383084c
Add vector DB limit checks and enhance resource limits validation
PhiRho May 20, 2025
7b0a776
Refactor billing API to enhance Stripe customer portal functionality
PhiRho May 20, 2025
0129d58
Enhance Stripe event handling and product management in billing API
PhiRho May 21, 2025
64544e7
Add pricing table session endpoint and improve error handling in bill…
PhiRho May 21, 2025
8421411
Implement resource limits checks for team and user management
PhiRho May 21, 2025
1666f06
Enhance product management and Stripe integration
PhiRho May 22, 2025
3467782
Refactor API key management and enhance resource limits validation
PhiRho May 22, 2025
66b91ee
Enhance token restrictions management and resource limits validation
PhiRho May 23, 2025
91218f8
Add frontend for purchase from pricing table
PhiRho May 30, 2025
0500eae
Add team specific product list
PhiRho Jun 2, 2025
3ba1d8e
Refactor billing API and enhance Stripe event handling
PhiRho Jun 3, 2025
7644c47
Enhance resource limits validation and optimize queries
PhiRho Jun 3, 2025
027f7a1
Update product application logic to include start date
PhiRho Jun 3, 2025
ef933ef
Refactor Stripe event handling constants and update model attribute
PhiRho Jun 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ make backend-test-cov # Run backend tests with coverage report
make backend-test-regex # Waits for a string which pytest will parse to only collect a subset of tests
```

### 💳 Testing Stripe
See [tests/stripe_test_trigger.md](tests/stripe_test_trigger.md) for detailed instructions on testing the Stripe billing integration.

### Frontend Tests
```bash
make frontend-test # Run frontend tests
Expand Down
2 changes: 1 addition & 1 deletion app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
auth_logger = logging.getLogger(__name__)

router = APIRouter(
tags=["Authentication"]
tags=["auth"]
)

def get_cookie_domain():
Expand Down
177 changes: 177 additions & 0 deletions app/api/billing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request, Response, BackgroundTasks
from sqlalchemy.orm import Session
import logging
import os
from app.db.database import get_db
from app.core.security import check_specific_team_admin
from app.db.models import DBTeam, DBSystemSecret
from app.schemas.models import PricingTableSession
from app.services.stripe import (
decode_stripe_event,
create_portal_session,
create_stripe_customer,
get_pricing_table_secret,
)
from app.core.worker import handle_stripe_event_background

# Configure logger
logger = logging.getLogger(__name__)
# Key of the DBSystemSecret row whose value is used as the Stripe webhook secret
BILLING_WEBHOOK_KEY = "stripe_webhook_secret"
# NOTE(review): presumably the full path of the webhook endpoint once the router is mounted under "/billing" — confirm against the app setup
BILLING_WEBHOOK_ROUTE = "/billing/events"

# Router collecting the billing endpoints defined in this module
router = APIRouter(
tags=["billing"]
)

# TODO: Verify where we want this to be
def get_return_url(team_id: int) -> str:
    """
    Get the return URL for the team dashboard.

    The base URL is read from the ``FRONTEND_URL`` environment variable and
    falls back to ``http://localhost:3000`` when the variable is unset.
    Trailing slashes on the base URL are removed so the result never contains
    a double slash before ``teams``.

    Args:
        team_id: The ID of the team to get the return URL for

    Returns:
        The return URL for the team dashboard
    """
    # Get the frontend URL from environment, tolerating a trailing "/"
    frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000").rstrip("/")
    return f"{frontend_url}/teams/{team_id}/dashboard"


@router.post("/events")
async def handle_events(
    request: Request,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db)
):
    """
    Handle Stripe webhook events.

    The webhook secret is taken from the ``WEBHOOK_SIG`` environment variable
    when it is set, otherwise from the ``DBSystemSecret`` row keyed by
    ``BILLING_WEBHOOK_KEY``. The raw request body and the ``stripe-signature``
    header are handed to ``decode_stripe_event``; the decoded event is then
    queued as a background task so the response is not delayed by processing.

    Returns:
        A 200 response once the event has been queued.

    Raises:
        HTTPException: 404 when no webhook secret is configured, 500 when
            decoding or queueing the event fails.
    """
    try:
        # The environment variable takes precedence over the stored secret
        webhook_secret = os.getenv("WEBHOOK_SIG")
        if not webhook_secret:
            secret_row = db.query(DBSystemSecret).filter(
                DBSystemSecret.key == BILLING_WEBHOOK_KEY
            ).first()
            # A missing row means "not configured", not a server error
            webhook_secret = secret_row.value if secret_row else None

        if not webhook_secret:
            logger.error("Stripe webhook secret not configured")
            # 404 for security reasons - if we're not accepting traffic here, then it doesn't exist
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Not found"
            )

        # Signature verification needs the raw, unparsed request body
        payload = await request.body()
        signature = request.headers.get("stripe-signature")

        event = decode_stripe_event(payload, signature, webhook_secret)

        # NOTE(review): `db` is the request-scoped session; confirm it is
        # still usable by the time the background task actually runs.
        background_tasks.add_task(handle_stripe_event_background, event, db)

        return Response(
            status_code=status.HTTP_200_OK,
            content="Webhook received and processing started"
        )

    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must not be rewritten as 500
        raise
    except Exception as e:
        logger.exception(f"Error handling Stripe event: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing webhook"
        ) from e

@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)])
async def get_portal(
    team_id: int,
    db: Session = Depends(get_db)
):
    """
    Create a Stripe Customer Portal session for team subscription management and redirect to it.
    The team must already have a Stripe customer ID; this endpoint does not create one.

    Args:
        team_id: The ID of the team to create the portal session for

    Returns:
        A 303 response whose Location header is the Stripe Customer Portal URL

    Raises:
        HTTPException: 404 when the team does not exist, 400 when the team has
            no Stripe customer ID, 500 when the portal session cannot be created.
    """
    # Get the team
    team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
    if not team:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Team not found"
        )
    if not team.stripe_customer_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Team has not been registered with Stripe"
        )

    try:
        return_url = get_return_url(team_id)
        # Create portal session using the service
        portal_url = await create_portal_session(team.stripe_customer_id, return_url)

        return Response(
            status_code=status.HTTP_303_SEE_OTHER,
            headers={"Location": portal_url}
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped
        logger.exception(f"Error creating portal session: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating portal session"
        ) from e

async def _create_customer_if_necessary(team: DBTeam, db: Session):
    """Return the team's Stripe customer ID, creating and persisting one if the team has none."""
    if not team.stripe_customer_id:
        logger.info(f"Creating Stripe customer for team {team.id}")
        team.stripe_customer_id = await create_stripe_customer(team)
        db.add(team)
        db.commit()
    return team.stripe_customer_id


@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession)
async def get_pricing_table_session(
    team_id: int,
    db: Session = Depends(get_db)
):
    """
    Create a Stripe Customer Session client secret for team subscription management.
    If the team doesn't have a Stripe customer ID, one will be created first.

    Args:
        team_id: The ID of the team to create the customer session for

    Returns:
        JSON response containing the client secret

    Raises:
        HTTPException: 404 when the team does not exist, 500 when creating the
            Stripe customer or the customer session fails.
    """
    # Get the team
    team = db.query(DBTeam).filter(DBTeam.id == team_id).first()
    if not team:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Team not found"
        )

    try:
        # Create Stripe customer if one doesn't exist
        stripe_customer_id = await _create_customer_if_necessary(team, db)

        logger.info(f"Stripe ID is {stripe_customer_id}")
        # Create customer session using the service
        client_secret = await get_pricing_table_secret(stripe_customer_id)

        return PricingTableSession(client_secret=client_secret)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped
        logger.exception(f"Error creating customer session: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error creating customer session"
        ) from e
150 changes: 143 additions & 7 deletions app/api/private_ai_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
from app.db.database import get_db
from app.schemas.models import (
PrivateAIKey, PrivateAIKeyCreate, PrivateAIKeySpend,
BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB
BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB,
TokenDurationUpdate, PrivateAIKeyDetail
)
from app.db.postgres import PostgresManager
from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam
from app.services.litellm import LiteLLMService
from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole
from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin
from app.core.config import settings
from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY

router = APIRouter(
tags=["private-ai-keys"]
)

# Set up logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,10 +64,6 @@ def _validate_permissions_and_get_ownership_info(

return owner_id, team_id

router = APIRouter(
tags=["Private AI Keys"]
)

@router.post("/vector-db", response_model=VectorDB)
async def create_vector_db(
vector_db: VectorDBCreate,
Expand Down Expand Up @@ -112,6 +115,12 @@ async def create_vector_db(
detail="Region not found or inactive"
)

if settings.ENABLE_LIMITS:
if not team_id: # if the team_id is not set we have already validated the owner_id
user = db.query(DBUser).filter(DBUser.id == owner_id).first()
team_id = user.team_id or FAKE_ID
check_vector_db_limits(db, team_id)

try:
# Create new postgres database
postgres_manager = PostgresManager(region=region)
Expand Down Expand Up @@ -285,6 +294,16 @@ async def create_llm_token(
detail="Team not found"
)

if owner.team_id or team_id:
if settings.ENABLE_LIMITS:
check_key_limits(db, owner.team_id or team_id, owner_id)
# Limits are conditionally applied in LiteLLM service
days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, owner.team_id or team_id)
else: # Super system users...
days_left_in_period = DEFAULT_KEY_DURATION
max_max_spend = DEFAULT_MAX_SPEND
max_rpm_limit = DEFAULT_RPM_PER_KEY

if team is not None:
owner_email = team.admin_email
litellm_team = team.id
Expand All @@ -302,7 +321,10 @@ async def create_llm_token(
email=owner_email,
name=private_ai_key.name,
user_id=owner_id,
team_id=f"{region.name.replace(' ', '_')}_{litellm_team}"
team_id=f"{region.name.replace(' ', '_')}_{litellm_team}",
duration=f"{days_left_in_period}d",
max_budget=max_max_spend,
rpm_limit=max_rpm_limit
)

# Create response object
Expand Down Expand Up @@ -390,6 +412,67 @@ async def list_private_ai_keys(
private_ai_keys = query.all()
return [key.to_dict() for key in private_ai_keys]

@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)])
async def get_private_ai_key(
    key_id: int,
    current_user = Depends(get_current_user_from_auth),
    db: Session = Depends(get_db)
):
    """
    Retrieve the full details of a single private AI key.

    Steps performed:
    1. Look the key up through `_get_key_if_allowed` with the "system_admin" role
    2. Fetch the key's info from the LiteLLM instance of the key's region
    3. Merge the stored key data with the LiteLLM info and return the result

    Required parameters:
    - **key_id**: The ID of the private AI key to retrieve

    Responds with 404 when the key's region cannot be found and with 500 when
    fetching or validating the key details fails.

    Note: You must be authenticated to use this endpoint.
    Only system administrators can access this endpoint.
    """
    key_record = _get_key_if_allowed(key_id, current_user, "system_admin", db)

    # The region holds the LiteLLM connection settings for this key
    region = db.query(DBRegion).filter(DBRegion.id == key_record.region_id).first()
    if not region:
        region_missing = HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Region not found"
        )
        raise region_missing

    litellm_client = LiteLLMService(
        api_url=region.litellm_api_url,
        api_key=region.litellm_api_key
    )

    try:
        litellm_response = await litellm_client.get_key_info(key_record.litellm_token)
        key_info = litellm_response.get("info", {})
        logger.info(f"LiteLLM key info: {key_info}")

        # LiteLLM fields overwrite same-named fields from the database record
        merged = key_record.to_dict()
        merged.update(key_info)
        # NOTE(review): a reviewer flagged that an earlier commit renamed
        # "expires" to "expires_at" — confirm the merged data still matches
        # what PrivateAIKeyDetail expects.
        return PrivateAIKeyDetail.model_validate(merged)
    except Exception as e:
        logger.error(f"Failed to get Private AI Key details: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get Private AI Key details: {str(e)}"
        )

def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole, db: Session) -> DBPrivateAIKey:
# First try to find the key
private_ai_key = db.query(DBPrivateAIKey).filter(
Expand Down Expand Up @@ -571,3 +654,56 @@ async def update_budget_period(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update budget period: {str(e)}"
)

@router.put("/{key_id}/extend-token-life")
async def extend_token_life(
    key_id: int,
    duration_update: TokenDurationUpdate,
    current_user = Depends(get_current_user_from_auth),
    user_role: UserRole = Depends(get_role_min_team_admin),
    db: Session = Depends(get_db)
):
    """
    Extend the life of a private AI key.

    This endpoint will:
    1. Verify the user has access to the key
    2. Update the key's duration in LiteLLM
    3. Return the updated key information

    Required parameters:
    - **duration**: The amount of time to add to the key's life (e.g. "30d" for 30 days, "1y" for 1 year)

    Responds with 404 when the key's region cannot be found and with 500 when
    the LiteLLM update or the follow-up key lookup fails.

    Note: You must be authenticated to use this endpoint.
    Only the owner of the key or an admin can update it.
    """
    private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db)

    # Get the region
    region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first()
    if not region:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Region not found"
        )

    litellm_service = LiteLLMService(
        api_url=region.litellm_api_url,
        api_key=region.litellm_api_key
    )

    try:
        # Update key duration in LiteLLM
        await litellm_service.update_key_duration(
            litellm_token=private_ai_key.litellm_token,
            duration=duration_update.duration
        )

        # Get updated key information
        key_data = await litellm_service.get_key_info(private_ai_key.litellm_token)
        return key_data
    except Exception as e:
        # Log with traceback, as get_private_ai_key does; the failure was previously silent server-side
        logger.error(f"Failed to extend token life: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to extend token life: {str(e)}"
        ) from e
Loading