-
Notifications
You must be signed in to change notification settings - Fork 1
Begin the Stripe Integration #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
d875a8e
Add endpoint to extend private AI key duration
PhiRho a4906f5
Update dependencies in requirements files
PhiRho 674d332
Add endpoint to retrieve private AI key details
PhiRho 659abff
Add billing functionality with Stripe integration
PhiRho bc3fbd5
Refactor API routers to standardize tags
PhiRho 66146fb
Add product management functionality
PhiRho d491226
Enhance Stripe billing integration and configuration
PhiRho 6385f16
Implement product management and user limits functionality
PhiRho d48e46e
Add key limit checks and enhance user limit validation
PhiRho 383084c
Add vector DB limit checks and enhance resource limits validation
PhiRho 7b0a776
Refactor billing API to enhance Stripe customer portal functionality
PhiRho 0129d58
Enhance Stripe event handling and product management in billing API
PhiRho 64544e7
Add pricing table session endpoint and improve error handling in bill…
PhiRho 8421411
Implement resource limits checks for team and user management
PhiRho 1666f06
Enhance product management and Stripe integration
PhiRho 3467782
Refactor API key management and enhance resource limits validation
PhiRho 66b91ee
Enhance token restrictions management and resource limits validation
PhiRho 91218f8
Add frontend for purchase from pricing table
PhiRho 0500eae
Add team specific product list
PhiRho 3ba1d8e
Refactor billing API and enhance Stripe event handling
PhiRho 7644c47
Enhance resource limits validation and optimize queries
PhiRho 027f7a1
Update product application logic to include start date
PhiRho ef933ef
Refactor Stripe event handling constants and update model attribute
PhiRho File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
from fastapi import APIRouter, Depends, HTTPException, status, Request, Response, BackgroundTasks | ||
from sqlalchemy.orm import Session | ||
import logging | ||
import os | ||
from app.db.database import get_db | ||
from app.core.security import check_specific_team_admin | ||
from app.db.models import DBTeam, DBSystemSecret | ||
from app.schemas.models import PricingTableSession | ||
from app.services.stripe import ( | ||
decode_stripe_event, | ||
create_portal_session, | ||
create_stripe_customer, | ||
get_pricing_table_secret, | ||
) | ||
from app.core.worker import handle_stripe_event_background | ||
|
||
# Configure logger | ||
logger = logging.getLogger(__name__) | ||
BILLING_WEBHOOK_KEY = "stripe_webhook_secret" | ||
BILLING_WEBHOOK_ROUTE = "/billing/events" | ||
|
||
router = APIRouter( | ||
tags=["billing"] | ||
) | ||
|
||
# TODO: Verify where we want this to be | ||
def get_return_url(team_id: int) -> str: | ||
""" | ||
Get the return URL for the team dashboard. | ||
|
||
Args: | ||
team_id: The ID of the team to get the return URL for | ||
|
||
Returns: | ||
The return URL for the team dashboard | ||
""" | ||
# Get the frontend URL from environment | ||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") | ||
return f"{frontend_url}/teams/{team_id}/dashboard" | ||
|
||
|
||
@router.post("/events") | ||
async def handle_events( | ||
request: Request, | ||
background_tasks: BackgroundTasks, | ||
db: Session = Depends(get_db) | ||
): | ||
""" | ||
Handle Stripe webhook events. | ||
|
||
This endpoint processes various Stripe events like subscription updates, | ||
payment successes, and failures. Events are processed asynchronously in the background. | ||
""" | ||
try: | ||
# Get the webhook secret from database or environment variable | ||
if os.getenv("WEBHOOK_SIG"): | ||
webhook_secret = os.getenv("WEBHOOK_SIG") | ||
else: | ||
webhook_secret = db.query(DBSystemSecret).filter( | ||
DBSystemSecret.key == BILLING_WEBHOOK_KEY | ||
).first().value | ||
|
||
if not webhook_secret: | ||
logger.error("Stripe webhook secret not configured") | ||
# 404 for security reasons - if we're not accepting traffic here, then it doesn't exist | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Not found" | ||
) | ||
|
||
# Get the raw request body | ||
payload = await request.body() | ||
signature = request.headers.get("stripe-signature") | ||
|
||
event = decode_stripe_event(payload, signature, webhook_secret) | ||
|
||
# Add the event handling to background tasks | ||
background_tasks.add_task(handle_stripe_event_background, event, db) | ||
|
||
return Response( | ||
status_code=status.HTTP_200_OK, | ||
content="Webhook received and processing started" | ||
) | ||
|
||
except Exception as e: | ||
logger.error(f"Error handling Stripe event: {str(e)}") | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Error processing webhook" | ||
) | ||
|
||
@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)]) | ||
async def get_portal( | ||
team_id: int, | ||
db: Session = Depends(get_db) | ||
): | ||
""" | ||
Create a Stripe Customer Portal session for team subscription management and redirect to it. | ||
If the team doesn't have a Stripe customer ID, one will be created first. | ||
|
||
Args: | ||
team_id: The ID of the team to create the portal session for | ||
|
||
Returns: | ||
Redirects to the Stripe Customer Portal URL | ||
""" | ||
# Get the team | ||
team = db.query(DBTeam).filter(DBTeam.id == team_id).first() | ||
if not team: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Team not found" | ||
) | ||
if not team.stripe_customer_id: | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail="Team has not been registered with Stripe" | ||
) | ||
|
||
try: | ||
return_url = get_return_url(team_id) | ||
# Create portal session using the service | ||
portal_url = await create_portal_session(team.stripe_customer_id, return_url) | ||
|
||
return Response( | ||
status_code=status.HTTP_303_SEE_OTHER, | ||
headers={"Location": portal_url} | ||
) | ||
except Exception as e: | ||
logger.error(f"Error creating portal session: {str(e)}") | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Error creating portal session" | ||
) | ||
|
||
@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession) | ||
async def get_pricing_table_session( | ||
team_id: int, | ||
db: Session = Depends(get_db) | ||
): | ||
""" | ||
Create a Stripe Customer Session client secret for team subscription management. | ||
If the team doesn't have a Stripe customer ID, one will be created first. | ||
|
||
Args: | ||
team_id: The ID of the team to create the customer session for | ||
|
||
Returns: | ||
JSON response containing the client secret | ||
""" | ||
# Get the team | ||
team = db.query(DBTeam).filter(DBTeam.id == team_id).first() | ||
if not team: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Team not found" | ||
) | ||
|
||
try: | ||
# Create Stripe customer if one doesn't exist | ||
if not team.stripe_customer_id: | ||
logger.info(f"Creating Stripe customer for team {team.id}") | ||
team.stripe_customer_id = await create_stripe_customer(team) | ||
db.add(team) | ||
db.commit() | ||
|
||
logger.info(f"Stripe ID is {team.stripe_customer_id}") | ||
# Create customer session using the service | ||
client_secret = await get_pricing_table_secret(team.stripe_customer_id) | ||
|
||
return PricingTableSession(client_secret=client_secret) | ||
except Exception as e: | ||
logger.error(f"Error creating customer session: {str(e)}") | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Error creating customer session" | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,12 +7,19 @@ | |
from app.db.database import get_db | ||
from app.schemas.models import ( | ||
PrivateAIKey, PrivateAIKeyCreate, PrivateAIKeySpend, | ||
BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB | ||
BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB, | ||
TokenDurationUpdate, PrivateAIKeyDetail | ||
) | ||
from app.db.postgres import PostgresManager | ||
from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam | ||
from app.services.litellm import LiteLLMService | ||
from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole | ||
from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin | ||
from app.core.config import settings | ||
from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY | ||
|
||
router = APIRouter( | ||
tags=["private-ai-keys"] | ||
) | ||
|
||
# Set up logging | ||
logger = logging.getLogger(__name__) | ||
|
@@ -57,10 +64,6 @@ def _validate_permissions_and_get_ownership_info( | |
|
||
return owner_id, team_id | ||
|
||
router = APIRouter( | ||
tags=["Private AI Keys"] | ||
) | ||
|
||
@router.post("/vector-db", response_model=VectorDB) | ||
async def create_vector_db( | ||
vector_db: VectorDBCreate, | ||
|
@@ -112,6 +115,12 @@ async def create_vector_db( | |
detail="Region not found or inactive" | ||
) | ||
|
||
if settings.ENABLE_LIMITS: | ||
if not team_id: # if the team_id is not set we have already validated the owner_id | ||
user = db.query(DBUser).filter(DBUser.id == owner_id).first() | ||
team_id = user.team_id or FAKE_ID | ||
check_vector_db_limits(db, team_id) | ||
|
||
try: | ||
# Create new postgres database | ||
postgres_manager = PostgresManager(region=region) | ||
|
@@ -285,6 +294,16 @@ async def create_llm_token( | |
detail="Team not found" | ||
) | ||
|
||
if owner.team_id or team_id: | ||
if settings.ENABLE_LIMITS: | ||
check_key_limits(db, owner.team_id or team_id, owner_id) | ||
# Limits are conditionally applied in LiteLLM service | ||
days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, owner.team_id or team_id) | ||
else: # Super system users... | ||
days_left_in_period = DEFAULT_KEY_DURATION | ||
max_max_spend = DEFAULT_MAX_SPEND | ||
max_rpm_limit = DEFAULT_RPM_PER_KEY | ||
|
||
if team is not None: | ||
owner_email = team.admin_email | ||
litellm_team = team.id | ||
|
@@ -302,7 +321,10 @@ async def create_llm_token( | |
email=owner_email, | ||
name=private_ai_key.name, | ||
user_id=owner_id, | ||
team_id=f"{region.name.replace(' ', '_')}_{litellm_team}" | ||
team_id=f"{region.name.replace(' ', '_')}_{litellm_team}", | ||
duration=f"{days_left_in_period}d", | ||
max_budget=max_max_spend, | ||
rpm_limit=max_rpm_limit | ||
) | ||
|
||
# Create response object | ||
|
@@ -390,6 +412,67 @@ async def list_private_ai_keys( | |
private_ai_keys = query.all() | ||
return [key.to_dict() for key in private_ai_keys] | ||
|
||
@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)]) | ||
async def get_private_ai_key( | ||
key_id: int, | ||
current_user = Depends(get_current_user_from_auth), | ||
db: Session = Depends(get_db) | ||
): | ||
""" | ||
Get details of a specific private AI key. | ||
|
||
This endpoint will: | ||
1. Verify the user has access to the key | ||
2. Return the full details of the key including LiteLLM-specific data | ||
|
||
Required parameters: | ||
- **key_id**: The ID of the private AI key to retrieve | ||
|
||
The response will include: | ||
- Database connection details (host, database name, username, password) | ||
- LiteLLM API token for authentication | ||
- LiteLLM API URL for making requests | ||
- Owner and team information | ||
- Region information | ||
- LiteLLM-specific data (spend, duration, budget, etc.) | ||
|
||
Note: You must be authenticated to use this endpoint. | ||
Only system administrators can access this endpoint. | ||
""" | ||
private_ai_key = _get_key_if_allowed(key_id, current_user, "system_admin", db) | ||
|
||
# Get the region | ||
region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first() | ||
if not region: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Region not found" | ||
) | ||
|
||
# Create LiteLLM service instance | ||
litellm_service = LiteLLMService( | ||
api_url=region.litellm_api_url, | ||
api_key=region.litellm_api_key | ||
) | ||
|
||
try: | ||
# Get LiteLLM key info | ||
litellm_data = await litellm_service.get_key_info(private_ai_key.litellm_token) | ||
info = litellm_data.get("info", {}) | ||
logger.info(f"LiteLLM key info: {info}") | ||
|
||
# Combine database key info with LiteLLM info | ||
key_data = private_ai_key.to_dict() | ||
key_data.update(info) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was looking at this commit and noticed that it previously changed "expires" to "expires_at". just flagging in case that's a bug/problem. |
||
|
||
return PrivateAIKeyDetail.model_validate(key_data) | ||
except Exception as e: | ||
logger.error(f"Failed to get Private AI Key details: {str(e)}", exc_info=True) | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail=f"Failed to get Private AI Key details: {str(e)}" | ||
) | ||
|
||
def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole, db: Session) -> DBPrivateAIKey: | ||
# First try to find the key | ||
private_ai_key = db.query(DBPrivateAIKey).filter( | ||
|
@@ -571,3 +654,56 @@ async def update_budget_period( | |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail=f"Failed to update budget period: {str(e)}" | ||
) | ||
|
||
@router.put("/{key_id}/extend-token-life") | ||
async def extend_token_life( | ||
key_id: int, | ||
duration_update: TokenDurationUpdate, | ||
current_user = Depends(get_current_user_from_auth), | ||
user_role: UserRole = Depends(get_role_min_team_admin), | ||
db: Session = Depends(get_db) | ||
): | ||
""" | ||
Extend the life of a private AI key. | ||
|
||
This endpoint will: | ||
1. Verify the user has access to the key | ||
2. Update the key's duration in LiteLLM | ||
3. Return the updated key information | ||
|
||
Required parameters: | ||
- **duration**: The amount of time to add to the key's life (e.g. "30d" for 30 days, "1y" for 1 year) | ||
|
||
Note: You must be authenticated to use this endpoint. | ||
Only the owner of the key or an admin can update it. | ||
""" | ||
private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db) | ||
|
||
# Get the region | ||
region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first() | ||
if not region: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Region not found" | ||
) | ||
|
||
litellm_service = LiteLLMService( | ||
api_url=region.litellm_api_url, | ||
api_key=region.litellm_api_key | ||
) | ||
|
||
try: | ||
# Update key duration in LiteLLM | ||
await litellm_service.update_key_duration( | ||
litellm_token=private_ai_key.litellm_token, | ||
duration=duration_update.duration | ||
) | ||
|
||
# Get updated key information | ||
key_data = await litellm_service.get_key_info(private_ai_key.litellm_token) | ||
return key_data | ||
except Exception as e: | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail=f"Failed to extend token life: {str(e)}" | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another nit: sems like these same few lines appear in several places. could extract to
_create_customer_if_necessary()
or similar.