diff --git a/README.md b/README.md index a202446..4834231 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,9 @@ make backend-test-cov # Run backend tests with coverage report make backend-test-regex # Waits for a string which pytest will parse to only collect a subset of tests ``` +### 💳 Testing Stripe +See [[tests/stripe_test_trigger.md]] for detailed instructions on testing integration with Stripe for billing purposes. + ### Frontend Tests ```bash make frontend-test # Run frontend tests diff --git a/app/api/auth.py b/app/api/auth.py index b11625a..b3df166 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -67,7 +67,7 @@ auth_logger = logging.getLogger(__name__) router = APIRouter( - tags=["Authentication"] + tags=["auth"] ) def get_cookie_domain(): diff --git a/app/api/billing.py b/app/api/billing.py new file mode 100644 index 0000000..99998f4 --- /dev/null +++ b/app/api/billing.py @@ -0,0 +1,177 @@ +from fastapi import APIRouter, Depends, HTTPException, status, Request, Response, BackgroundTasks +from sqlalchemy.orm import Session +import logging +import os +from app.db.database import get_db +from app.core.security import check_specific_team_admin +from app.db.models import DBTeam, DBSystemSecret +from app.schemas.models import PricingTableSession +from app.services.stripe import ( + decode_stripe_event, + create_portal_session, + create_stripe_customer, + get_pricing_table_secret, +) +from app.core.worker import handle_stripe_event_background + +# Configure logger +logger = logging.getLogger(__name__) +BILLING_WEBHOOK_KEY = "stripe_webhook_secret" +BILLING_WEBHOOK_ROUTE = "/billing/events" + +router = APIRouter( + tags=["billing"] +) + +# TODO: Verify where we want this to be +def get_return_url(team_id: int) -> str: + """ + Get the return URL for the team dashboard. + + Args: + team_id: The ID of the team to get the return URL for + + Returns: + The return URL for the team dashboard + """ + # Get the frontend URL from environment + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:3000") + return f"{frontend_url}/teams/{team_id}/dashboard" + + +@router.post("/events") +async def handle_events( + request: Request, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db) +): + """ + Handle Stripe webhook events. + + This endpoint processes various Stripe events like subscription updates, + payment successes, and failures. Events are processed asynchronously in the background. + """ + try: + # Get the webhook secret from database or environment variable + if os.getenv("WEBHOOK_SIG"): + webhook_secret = os.getenv("WEBHOOK_SIG") + else: + webhook_secret = db.query(DBSystemSecret).filter( + DBSystemSecret.key == BILLING_WEBHOOK_KEY + ).first().value + + if not webhook_secret: + logger.error("Stripe webhook secret not configured") + # 404 for security reasons - if we're not accepting traffic here, then it doesn't exist + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Not found" + ) + + # Get the raw request body + payload = await request.body() + signature = request.headers.get("stripe-signature") + + event = decode_stripe_event(payload, signature, webhook_secret) + + # Add the event handling to background tasks + background_tasks.add_task(handle_stripe_event_background, event, db) + + return Response( + status_code=status.HTTP_200_OK, + content="Webhook received and processing started" + ) + + except Exception as e: + logger.error(f"Error handling Stripe event: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error processing webhook" + ) + +@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)]) +async def get_portal( + team_id: int, + db: Session = Depends(get_db) +): + """ + Create a Stripe Customer Portal session for team subscription management and redirect to it. + If the team doesn't have a Stripe customer ID, one will be created first. + + Args: + team_id: The ID of the team to create the portal session for + + Returns: + Redirects to the Stripe Customer Portal URL + """ + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + if not team.stripe_customer_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Team has not been registered with Stripe" + ) + + try: + return_url = get_return_url(team_id) + # Create portal session using the service + portal_url = await create_portal_session(team.stripe_customer_id, return_url) + + return Response( + status_code=status.HTTP_303_SEE_OTHER, + headers={"Location": portal_url} + ) + except Exception as e: + logger.error(f"Error creating portal session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating portal session" + ) + +@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession) +async def get_pricing_table_session( + team_id: int, + db: Session = Depends(get_db) +): + """ + Create a Stripe Customer Session client secret for team subscription management. + If the team doesn't have a Stripe customer ID, one will be created first. + + Args: + team_id: The ID of the team to create the customer session for + + Returns: + JSON response containing the client secret + """ + # Get the team + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + + try: + # Create Stripe customer if one doesn't exist + if not team.stripe_customer_id: + logger.info(f"Creating Stripe customer for team {team.id}") + team.stripe_customer_id = await create_stripe_customer(team) + db.add(team) + db.commit() + + logger.info(f"Stripe ID is {team.stripe_customer_id}") + # Create customer session using the service + client_secret = await get_pricing_table_secret(team.stripe_customer_id) + + return PricingTableSession(client_secret=client_secret) + except Exception as e: + logger.error(f"Error creating customer session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating customer session" + ) diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index 9264da0..dadab8c 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -7,12 +7,19 @@ from app.db.database import get_db from app.schemas.models import ( PrivateAIKey, PrivateAIKeyCreate, PrivateAIKeySpend, - BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB + BudgetPeriodUpdate, LiteLLMToken, VectorDBCreate, VectorDB, + TokenDurationUpdate, PrivateAIKeyDetail ) from app.db.postgres import PostgresManager from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam from app.services.litellm import LiteLLMService -from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole +from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin +from app.core.config import settings +from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY + +router = APIRouter( + tags=["private-ai-keys"] +) # Set up logging logger = logging.getLogger(__name__) @@ -57,10 +64,6 @@ def _validate_permissions_and_get_ownership_info( return owner_id, team_id -router = APIRouter( - tags=["Private AI Keys"] -) - @router.post("/vector-db", response_model=VectorDB) async def create_vector_db( vector_db: VectorDBCreate, @@ -112,6 +115,12 @@ async def create_vector_db( detail="Region not found or inactive" ) + if settings.ENABLE_LIMITS: + if not team_id: # if the team_id is not set we have already validated the owner_id + user = db.query(DBUser).filter(DBUser.id == owner_id).first() + team_id = user.team_id or FAKE_ID + check_vector_db_limits(db, team_id) + try: # Create new postgres database postgres_manager = PostgresManager(region=region) @@ -285,6 +294,16 @@ async def create_llm_token( detail="Team not found" ) + if owner.team_id or team_id: + if settings.ENABLE_LIMITS: + check_key_limits(db, owner.team_id or team_id, owner_id) + # Limits are conditionally applied in LiteLLM service + days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, owner.team_id or team_id) + else: # Super system users... + days_left_in_period = DEFAULT_KEY_DURATION + max_max_spend = DEFAULT_MAX_SPEND + max_rpm_limit = DEFAULT_RPM_PER_KEY + if team is not None: owner_email = team.admin_email litellm_team = team.id @@ -302,7 +321,10 @@ async def create_llm_token( email=owner_email, name=private_ai_key.name, user_id=owner_id, - team_id=f"{region.name.replace(' ', '_')}_{litellm_team}" + team_id=f"{region.name.replace(' ', '_')}_{litellm_team}", + duration=f"{days_left_in_period}d", + max_budget=max_max_spend, + rpm_limit=max_rpm_limit ) # Create response object @@ -390,6 +412,67 @@ async def list_private_ai_keys( private_ai_keys = query.all() return [key.to_dict() for key in private_ai_keys] +@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)]) +async def get_private_ai_key( + key_id: int, + current_user = Depends(get_current_user_from_auth), + db: Session = Depends(get_db) +): + """ + Get details of a specific private AI key. + + This endpoint will: + 1. Verify the user has access to the key + 2. Return the full details of the key including LiteLLM-specific data + + Required parameters: + - **key_id**: The ID of the private AI key to retrieve + + The response will include: + - Database connection details (host, database name, username, password) + - LiteLLM API token for authentication + - LiteLLM API URL for making requests + - Owner and team information + - Region information + - LiteLLM-specific data (spend, duration, budget, etc.) + + Note: You must be authenticated to use this endpoint. + Only system administrators can access this endpoint. + """ + private_ai_key = _get_key_if_allowed(key_id, current_user, "system_admin", db) + + # Get the region + region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first() + if not region: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Region not found" + ) + + # Create LiteLLM service instance + litellm_service = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + try: + # Get LiteLLM key info + litellm_data = await litellm_service.get_key_info(private_ai_key.litellm_token) + info = litellm_data.get("info", {}) + logger.info(f"LiteLLM key info: {info}") + + # Combine database key info with LiteLLM info + key_data = private_ai_key.to_dict() + key_data.update(info) + + return PrivateAIKeyDetail.model_validate(key_data) + except Exception as e: + logger.error(f"Failed to get Private AI Key details: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get Private AI Key details: {str(e)}" + ) + def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole, db: Session) -> DBPrivateAIKey: # First try to find the key private_ai_key = db.query(DBPrivateAIKey).filter( @@ -571,3 +654,56 @@ async def update_budget_period( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update budget period: {str(e)}" ) + +@router.put("/{key_id}/extend-token-life") +async def extend_token_life( + key_id: int, + duration_update: TokenDurationUpdate, + current_user = Depends(get_current_user_from_auth), + user_role: UserRole = Depends(get_role_min_team_admin), + db: Session = Depends(get_db) +): + """ + Extend the life of a private AI key. + + This endpoint will: + 1. Verify the user has access to the key + 2. Update the key's duration in LiteLLM + 3. Return the updated key information + + Required parameters: + - **duration**: The amount of time to add to the key's life (e.g. "30d" for 30 days, "1y" for 1 year) + + Note: You must be authenticated to use this endpoint. + Only the owner of the key or an admin can update it. + """ + private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db) + + # Get the region + region = db.query(DBRegion).filter(DBRegion.id == private_ai_key.region_id).first() + if not region: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Region not found" + ) + + litellm_service = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + try: + # Update key duration in LiteLLM + await litellm_service.update_key_duration( + litellm_token=private_ai_key.litellm_token, + duration=duration_update.duration + ) + + # Get updated key information + key_data = await litellm_service.get_key_info(private_ai_key.litellm_token) + return key_data + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to extend token life: {str(e)}" + ) diff --git a/app/api/products.py b/app/api/products.py new file mode 100644 index 0000000..68475f8 --- /dev/null +++ b/app/api/products.py @@ -0,0 +1,161 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import List, Optional +from datetime import datetime, UTC + +from app.db.database import get_db +from app.db.models import DBProduct, DBTeamProduct, DBTeam +from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin +from app.schemas.models import Product, ProductCreate, ProductUpdate + +router = APIRouter( + tags=["products"] +) + +@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) +@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) +async def create_product( + product: ProductCreate, + db: Session = Depends(get_db) +): + """ + Create a new product. Only accessible by system admin users. + """ + # Check if product ID already exists + existing_product = db.query(DBProduct).filter(DBProduct.id == product.id).first() + if existing_product: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Product with this ID already exists" + ) + + # Create the product with all fields + db_product = DBProduct( + id=product.id, + name=product.name, + user_count=product.user_count, + keys_per_user=product.keys_per_user, + total_key_count=product.total_key_count, + service_key_count=product.service_key_count, + max_budget_per_key=product.max_budget_per_key, + rpm_per_key=product.rpm_per_key, + vector_db_count=product.vector_db_count, + vector_db_storage=product.vector_db_storage, + renewal_period_days=product.renewal_period_days, + active=product.active, + created_at=datetime.now(UTC) + ) + + db.add(db_product) + db.commit() + db.refresh(db_product) + + return db_product + +@router.get("", response_model=List[Product], dependencies=[Depends(get_role_min_team_admin)]) +@router.get("/", response_model=List[Product], dependencies=[Depends(get_role_min_team_admin)]) +async def list_products( + team_id: Optional[int] = None, + db: Session = Depends(get_db), + current_user = Depends(get_current_user_from_auth) +): + """ + List all products. Only accessible by team admin users or higher privileges. + If team_id is provided, only returns products associated with that team. + Team admins can only view products for their own team. + """ + # If team_id is provided, verify the user has access to that team + if team_id is not None: + # First check if the team exists + team = db.query(DBTeam).filter(DBTeam.id == team_id).first() + if not team: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Team not found" + ) + + # System admins can view any team's products + if not current_user.is_admin: + # Team admins can only view their own team's products + if current_user.team_id != team_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You can only view products for your own team" + ) + + # Get products associated with the team + return db.query(DBProduct).join(DBTeamProduct).filter(DBTeamProduct.team_id == team_id).all() + + # If no team_id provided, return all products + return db.query(DBProduct).all() + +@router.get("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_team_admin)]) +async def get_product( + product_id: str, + db: Session = Depends(get_db) +): + """ + Get a specific product by ID. Only accessible by team admin users or higher privileges. + """ + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + if not product: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Product not found" + ) + return product + +@router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)]) +async def update_product( + product_id: str, + product_update: ProductUpdate, + db: Session = Depends(get_db) +): + """ + Update a product. Only accessible by system admin users. + """ + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + if not product: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Product not found" + ) + + # Update the product with all provided fields + update_data = product_update.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(product, key, value) + + product.updated_at = datetime.now(UTC) + db.commit() + db.refresh(product) + + return product + +@router.delete("/{product_id}", dependencies=[Depends(check_system_admin)]) +async def delete_product( + product_id: str, + db: Session = Depends(get_db) +): + """ + Delete a product. Only accessible by system admin users. + """ + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + if not product: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Product not found" + ) + + # Check if the product is associated with any teams + team_association = db.query(DBTeamProduct).filter(DBTeamProduct.product_id == product_id).first() + if team_association: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot delete product that is associated with one or more teams" + ) + + db.delete(product) + db.commit() + + return {"message": "Product deleted successfully"} \ No newline at end of file diff --git a/app/api/regions.py b/app/api/regions.py index 19f03f9..d96bde6 100644 --- a/app/api/regions.py +++ b/app/api/regions.py @@ -7,7 +7,9 @@ from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate from app.db.models import DBRegion, DBPrivateAIKey -router = APIRouter() +router = APIRouter( + tags=["regions"] +) @router.post("", response_model=Region) @router.post("/", response_model=Region) diff --git a/app/api/teams.py b/app/api/teams.py index 7fda28a..75c1632 100644 --- a/app/api/teams.py +++ b/app/api/teams.py @@ -11,7 +11,9 @@ TeamWithUsers ) -router = APIRouter() +router = APIRouter( + tags=["teams"] +) @router.post("", response_model=Team, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=Team, status_code=status.HTTP_201_CREATED) diff --git a/app/api/users.py b/app/api/users.py index c8208c4..c78ca3f 100644 --- a/app/api/users.py +++ b/app/api/users.py @@ -1,14 +1,17 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from typing import List, get_args - +from app.core.config import settings +from app.core.resource_limits import check_team_user_limit from app.db.database import get_db from app.schemas.models import User, UserUpdate, UserCreate, TeamOperation, UserRoleUpdate from app.db.models import DBUser, DBTeam from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, UserRole, get_role_min_team_admin from datetime import datetime, UTC -router = APIRouter() +router = APIRouter( + tags=["users"] +) @router.get("/search", response_model=List[User], dependencies=[Depends(check_system_admin)]) async def search_users( @@ -70,6 +73,9 @@ async def create_user( detail="Not authorized to perform this action" ) + if settings.ENABLE_LIMITS and user.team_id is not None: + check_team_user_limit(db, user.team_id) + # Validate role if provided if user.role and user.role not in get_args(UserRole): raise HTTPException( diff --git a/app/core/config.py b/app/core/config.py index 260d849..8864b85 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -21,6 +21,17 @@ class Settings(BaseSettings): ALLOWED_HOSTS: list[str] = ["*"] # In production, restrict this PUBLIC_PATHS: list[str] = ["/health", "/docs", "/openapi.json", "/metrics"] + AWS_ACCESS_KEY_ID: str = "AKIATEST" + AWS_SECRET_ACCESS_KEY: str = "sk-string" + SES_SENDER_EMAIL: str = "info@example.com" + PASSWORDLESS_SIGN_IN: str = "true" + ENV_SUFFIX: str = os.getenv("ENV_SUFFIX", "local") + DYNAMODB_REGION: str = "eu-west-1" + SES_REGION: str = "eu-west-1" + ENABLE_LIMITS: bool = os.getenv("ENABLE_LIMITS", "false") == "true" + STRIPE_SECRET_KEY: str = os.getenv("STRIPE_SECRET_KEY", "sk_test_string") + WEBHOOK_SIG: str = os.getenv("WEBHOOK_SIG", "whsec_test_1234567890") + model_config = ConfigDict(env_file=".env") def model_post_init(self, values): @@ -28,4 +39,4 @@ def model_post_init(self, values): lagoon_routes = os.getenv("LAGOON_ROUTES", "").split(",") self.CORS_ORIGINS.extend([route.strip() for route in lagoon_routes if route.strip()]) -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/app/core/resource_limits.py b/app/core/resource_limits.py new file mode 100644 index 0000000..58f3ea1 --- /dev/null +++ b/app/core/resource_limits.py @@ -0,0 +1,199 @@ +from sqlalchemy.orm import Session +from sqlalchemy import func, or_ +from app.db.models import DBTeam, DBUser, DBPrivateAIKey, DBTeamProduct, DBProduct +from fastapi import HTTPException, status +from typing import Optional +from datetime import datetime, UTC +import logging + +logger = logging.getLogger(__name__) + +# Default limits across all customers and products +DEFAULT_USER_COUNT = 1 +DEFAULT_KEYS_PER_USER = 1 +DEFAULT_TOTAL_KEYS = 2 +DEFAULT_SERVICE_KEYS = 1 +DEFAULT_VECTOR_DB_COUNT = 1 +DEFAULT_KEY_DURATION = 30 +DEFAULT_MAX_SPEND = 20.0 +DEFAULT_RPM_PER_KEY = 500 + +def check_team_user_limit(db: Session, team_id: int) -> None: + """ + Check if adding a user would exceed the team's product limits. + Raises HTTPException if the limit would be exceeded. + + Args: + db: Database session + team_id: ID of the team to check + """ + # Get current user count and max allowed users in a single query + result = db.query( + func.count(DBUser.id).label('current_user_count'), + func.coalesce(func.max(DBProduct.user_count), DEFAULT_USER_COUNT).label('max_users') + ).select_from(DBUser).filter( + DBUser.team_id == team_id + ).outerjoin( + DBTeamProduct, + DBTeamProduct.team_id == team_id + ).outerjoin( + DBProduct, + DBProduct.id == DBTeamProduct.product_id + ).first() + + if not result: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") + + if result.current_user_count >= result.max_users: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Team has reached the maximum user limit of {result.max_users} users" + ) + +def check_key_limits(db: Session, team_id: int, owner_id: Optional[int] = None) -> None: + """ + Check if creating a new LLM token would exceed the team's or user's key limits. + Raises HTTPException if any limit would be exceeded. + + Args: + db: Database session + team_id: ID of the team to check + owner_id: Optional ID of the user who will own the key + """ + # Get all limits and current counts in a single query + result = db.query( + func.coalesce(func.max(DBProduct.total_key_count), DEFAULT_TOTAL_KEYS).label('max_total_keys'), + func.coalesce(func.max(DBProduct.keys_per_user), DEFAULT_KEYS_PER_USER).label('max_keys_per_user'), + func.coalesce(func.max(DBProduct.service_key_count), DEFAULT_SERVICE_KEYS).label('max_service_keys'), + func.count(DBPrivateAIKey.id).filter( + DBPrivateAIKey.litellm_token.isnot(None) + ).label('current_team_keys'), + func.count(DBPrivateAIKey.id).filter( + DBPrivateAIKey.owner_id == owner_id, + DBPrivateAIKey.litellm_token.isnot(None) + ).label('current_user_keys') if owner_id else None, + func.count(DBPrivateAIKey.id).filter( + DBPrivateAIKey.owner_id.is_(None), + DBPrivateAIKey.litellm_token.isnot(None) + ).label('current_service_keys') + ).select_from(DBTeam).filter( # Have to use Teams table as the base because not every team has a product + DBTeam.id == team_id + ).outerjoin( + DBTeamProduct, + DBTeamProduct.team_id == DBTeam.id + ).outerjoin( + DBProduct, + DBProduct.id == DBTeamProduct.product_id + ).outerjoin( + DBPrivateAIKey, + or_( + DBPrivateAIKey.team_id == DBTeam.id, + DBPrivateAIKey.owner_id.in_( + db.query(DBUser.id).filter(DBUser.team_id == DBTeam.id) + ) + ) + ).first() + + if not result: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") + + if result.current_team_keys >= result.max_total_keys: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Team has reached the maximum LLM token limit of {result.max_total_keys} tokens" + ) + + if owner_id is not None and result.current_user_keys >= result.max_keys_per_user: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"User has reached the maximum LLM token limit of {result.max_keys_per_user} tokens" + ) + + if owner_id is None and result.current_service_keys >= result.max_service_keys: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Team has reached the maximum service LLM token limit of {result.max_service_keys} tokens" + ) + +def check_vector_db_limits(db: Session, team_id: int) -> None: + """ + Check if creating a new vector DB would exceed the team's vector DB limits. + Raises HTTPException if the limit would be exceeded. + + Args: + db: Database session + team_id: ID of the team to check + """ + # Get vector DB limits and current count in a single query + result = db.query( + func.coalesce(func.max(DBProduct.vector_db_count), DEFAULT_VECTOR_DB_COUNT).label('max_vector_db_count'), + func.count(DBPrivateAIKey.id).filter( + DBPrivateAIKey.database_name.isnot(None) + ).label('current_vector_db_count') + ).select_from(DBTeam).filter( + DBTeam.id == team_id + ).outerjoin( + DBTeamProduct, + DBTeamProduct.team_id == DBTeam.id + ).outerjoin( + DBProduct, + DBProduct.id == DBTeamProduct.product_id + ).outerjoin( + DBPrivateAIKey, + or_( + DBPrivateAIKey.team_id == DBTeam.id, + DBPrivateAIKey.owner_id.in_( + db.query(DBUser.id).filter(DBUser.team_id == DBTeam.id) + ) + ) + ).first() + + if not result: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") + + if result.current_vector_db_count >= result.max_vector_db_count: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Team has reached the maximum vector DB limit of {result.max_vector_db_count} databases" + ) + +def get_token_restrictions(db: Session, team_id: int) -> tuple[int, float, int]: + """ + Get the token restrictions for a team. + """ + # Get all token restrictions in a single query + result = db.query( + func.coalesce(func.max(DBProduct.renewal_period_days), DEFAULT_KEY_DURATION).label('max_key_duration'), + func.coalesce(func.max(DBProduct.max_budget_per_key), DEFAULT_MAX_SPEND).label('max_max_spend'), + func.coalesce(func.max(DBProduct.rpm_per_key), DEFAULT_RPM_PER_KEY).label('max_rpm_limit'), + DBTeam.created_at, + DBTeam.last_payment + ).select_from(DBTeam).filter( + DBTeam.id == team_id + ).outerjoin( + DBTeamProduct, + DBTeamProduct.team_id == DBTeam.id + ).outerjoin( + DBProduct, + DBProduct.id == DBTeamProduct.product_id + ).group_by( + DBTeam.created_at, + DBTeam.last_payment + ).first() + + if not result: + logger.error(f"Team not found for team_id: {team_id}") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Team not found") + + if result.last_payment is None: + days_left_in_period = result.max_key_duration + else: + days_left_in_period = result.max_key_duration - ( + datetime.now(UTC) - max(result.created_at.replace(tzinfo=UTC), result.last_payment.replace(tzinfo=UTC)) + ).days + + return days_left_in_period, result.max_max_spend, result.max_rpm_limit + +def get_team_limits(db: Session, team_id: int): + # TODO: Go through all products, and create a master list of the limits on all fields for this team. + pass \ No newline at end of file diff --git a/app/core/worker.py b/app/core/worker.py new file mode 100644 index 0000000..89252c1 --- /dev/null +++ b/app/core/worker.py @@ -0,0 +1,190 @@ +from datetime import datetime, UTC +from sqlalchemy.orm import Session +from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser +from app.services.litellm import LiteLLMService +import logging +from collections import defaultdict +from app.core.resource_limits import get_token_restrictions +from app.services.stripe import ( + get_product_id_from_session, + get_product_id_from_subscription, + KNOWN_EVENTS, + SUBSCRIPTION_SUCCESS_EVENTS, + SESSION_FAILURE_EVENTS, + SUBSCRIPTION_FAILURE_EVENTS, + INVOICE_FAILURE_EVENTS, + INVOICE_SUCCESS_EVENTS +) + +logger = logging.getLogger(__name__) + +async def handle_stripe_event_background(event, db: Session): + """ + Background task to handle Stripe webhook events. + This runs in a separate thread to avoid blocking the webhook response. + """ + try: + event_type = event.type + if not event_type in KNOWN_EVENTS: + logger.info(f"Unknown event type: {event_type}") + return + event_object = event.data.object + customer_id = event_object.customer + if not customer_id: + logger.warning(f"No customer ID found in event, cannot complete processing") + return + # Success Events + if event_type in SUBSCRIPTION_SUCCESS_EVENTS: + # new subscription + product_id = await get_product_id_from_subscription(event_object.id) + start_date = datetime.fromtimestamp(event_object.start_date, tz=UTC) + await apply_product_for_team(db, customer_id, product_id, start_date) + elif event_type in INVOICE_SUCCESS_EVENTS: + # subscription renewed + subscription = event_object.parent.subscription_details.subscription + product_id = await get_product_id_from_subscription(subscription) + start_date = datetime.fromtimestamp(event_object.period_start, tz=UTC) + await apply_product_for_team(db, customer_id, product_id, start_date) + # Failure Events + elif event_type in SESSION_FAILURE_EVENTS: + product_id = await get_product_id_from_session(event_object.id) + await remove_product_from_team(db, customer_id, product_id) + elif event_type in SUBSCRIPTION_FAILURE_EVENTS: + product_id = await get_product_id_from_subscription(event_object.id) + await remove_product_from_team(db, customer_id, product_id) + elif event_type in INVOICE_FAILURE_EVENTS: + # We assume that the invoice is related to a subscription + subscription = event_object.parent.subscription_details.subscription + product_id = await get_product_id_from_subscription(subscription) + await remove_product_from_team(db, customer_id, product_id) + except Exception as e: + logger.error(f"Error in background event handler: {str(e)}") + + +async def apply_product_for_team(db: Session, customer_id: str, product_id: str, start_date: datetime): + """ + Apply a product to a team and update their last payment date. + Also extends all team keys and sets their max budgets via LiteLLM service. + + Args: + db: Database session + customer_id: Stripe customer ID + product_id: Product ID from the database + + Returns: + bool: True if update was successful, False otherwise + """ + logger.info(f"Applying product {product_id} to team {customer_id}") + try: + # Find the team and product + team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == customer_id).first() + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + + if not team: + logger.error(f"Team not found for customer ID: {customer_id}") + return + if not product: + logger.error(f"Product not found for ID: {product_id}") + return + + # Update the last payment date + team.last_payment = start_date + + # Check if the product is already active for the team + existing_association = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == team.id, + DBTeamProduct.product_id == product.id + ).first() + + # Only create new association if it doesn't exist + if not existing_association: + team_product = DBTeamProduct( + team_id=team.id, + product_id=product.id + ) + db.add(team_product) + db.commit() # Commit the product association + + days_left_in_period, max_max_spend, max_rpm_limit = get_token_restrictions(db, team.id) + + # Get all keys for the team with their regions + team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() + team_user_ids = [user.id for user in team_users] + # Return keys owned by users in the team OR owned by the team + team_keys = db.query(DBPrivateAIKey).filter( + (DBPrivateAIKey.owner_id.in_(team_user_ids)) | + (DBPrivateAIKey.team_id == team.id) + ).all() + + # Group keys by region + keys_by_region = defaultdict(list) + for key in team_keys: + if not key.litellm_token: + logger.warning(f"Key {key.id} has no LiteLLM token, skipping") + continue + if not key.region: + logger.warning(f"Key {key.id} has no region, skipping") + continue + keys_by_region[key.region].append(key) + + # Update keys for each region + for region, keys in keys_by_region.items(): + # Initialize LiteLLM service for this region + litellm_service = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + # Update each key's duration and budget via LiteLLM + for key in keys: + try: + await litellm_service.set_key_restrictions( + litellm_token=key.litellm_token, + duration=f"{days_left_in_period}d", + budget_duration=f"{days_left_in_period}d", + budget_amount=max_max_spend, + rpm_limit=max_rpm_limit + ) + except Exception as e: + logger.error(f"Failed to update key {key.id} via LiteLLM: {str(e)}") + # Continue with other keys even if one fails + continue + + db.commit() + + except Exception as e: + db.rollback() + logger.error(f"Error applying product to team: {str(e)}") + raise e + +async def remove_product_from_team(db: Session, customer_id: str, product_id: str): + logger.info(f"Removing product {product_id} from team {customer_id}") + try: + # Find the team and product + team = db.query(DBTeam).filter(DBTeam.stripe_customer_id == customer_id).first() + product = db.query(DBProduct).filter(DBProduct.id == product_id).first() + + if not team: + logger.error(f"Team not found for customer ID: {customer_id}") + return + if not product: + logger.error(f"Product not found for ID: {product_id}") + return + # Check if the product is already active for the team + existing_association = db.query(DBTeamProduct).filter( + DBTeamProduct.team_id == team.id, + DBTeamProduct.product_id == product.id + ).first() + if not existing_association: + logger.error(f"Product {product_id} not found for team {customer_id}") + return + # Remove the product association + db.delete(existing_association) + + # TODO: Send notification + # TODO: Expire keys if applicable + db.commit() + except Exception as e: + db.rollback() + logger.error(f"Error removing product from team: {str(e)}") + raise e diff --git a/app/db/models.py b/app/db/models.py index e95fb86..a59cdda 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -1,10 +1,30 @@ -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, JSON +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, DateTime, JSON, Float, Table from sqlalchemy.orm import relationship, declarative_base from datetime import datetime, UTC from sqlalchemy.sql import func Base = declarative_base() +class DBTeamProduct(Base): + """ + Association table for team-product relationship. + This model is required to implement a many-to-many relationship between teams and products. + It allows: + - Teams to subscribe to multiple products + - Products to be used by multiple teams + - Tracking when products were added to teams + - Maintaining referential integrity between teams and products + """ + __tablename__ = "team_products" + + team_id = Column(Integer, ForeignKey('teams.id', ondelete='CASCADE'), primary_key=True, nullable=False) + product_id = Column(String, ForeignKey('products.id', ondelete='CASCADE'), primary_key=True, nullable=False) + created_at = Column(DateTime(timezone=True), default=func.now(), nullable=False) + updated_at = Column(DateTime(timezone=True), onupdate=func.now(), nullable=True) + + team = relationship("DBTeam", back_populates="active_products") + product = relationship("DBProduct", back_populates="teams") + class DBRegion(Base): __tablename__ = "regions" @@ -62,9 +82,12 @@ class DBTeam(Base): is_active = Column(Boolean, default=True) created_at = Column(DateTime(timezone=True), default=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + stripe_customer_id = Column(String, nullable=True, unique=True, index=True) + last_payment = Column(DateTime(timezone=True), nullable=True) users = relationship("DBUser", back_populates="team") private_ai_keys = relationship("DBPrivateAIKey", back_populates="team") + active_products = relationship("DBTeamProduct", back_populates="team") class DBPrivateAIKey(Base): __tablename__ = "ai_tokens" @@ -118,4 +141,34 @@ class DBAuditLog(Base): user_agent = Column(String, nullable=True) request_source = Column(String, nullable=True) # Values: 'frontend', 'api', or None - user = relationship("DBUser", back_populates="audit_logs") \ No newline at end of file + user = relationship("DBUser", back_populates="audit_logs") + +class DBSystemSecret(Base): + __tablename__ = "system_secrets" + + id = Column(Integer, primary_key=True, index=True) + key = Column(String, unique=True, index=True, nullable=False) + value = Column(String, nullable=False) + description = Column(String, nullable=True) + created_at = Column(DateTime(timezone=True), default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + +class DBProduct(Base): + __tablename__ = "products" + + id = Column(String, primary_key=True, index=True) + name = Column(String, nullable=False) + user_count = Column(Integer, default=0) + keys_per_user = Column(Integer, default=0) + total_key_count = Column(Integer, default=0) + service_key_count = Column(Integer, default=0) + max_budget_per_key = Column(Float, default=0.0) + rpm_per_key = Column(Integer, default=0) + vector_db_count = Column(Integer, default=0) + vector_db_storage = Column(Integer, default=0) + renewal_period_days = Column(Integer, default=30) + active = Column(Boolean, default=True) + created_at = Column(DateTime(timezone=True), default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + teams = relationship("DBTeamProduct", back_populates="product") \ No newline at end of file diff --git a/app/main.py b/app/main.py index c514898..df2029a 100644 --- a/app/main.py +++ b/app/main.py @@ -5,6 +5,13 @@ from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.utils import get_openapi from prometheus_fastapi_instrumentator import Instrumentator, metrics +from app.api import auth, private_ai_keys, users, regions, audit, teams, billing, products +from app.core.config import settings +from app.db.database import get_db +from app.middleware.audit import AuditLogMiddleware +from app.middleware.prometheus import PrometheusMiddleware +from app.middleware.auth import AuthMiddleware + import os import logging @@ -22,13 +29,6 @@ async def dispatch(self, request, call_next): request.scope["scheme"] = "https" return await call_next(request) -from app.api import auth, private_ai_keys, users, regions, audit, teams -from app.core.config import settings -from app.db.database import get_db -from app.middleware.audit import AuditLogMiddleware -from app.middleware.prometheus import PrometheusMiddleware -from app.middleware.auth import AuthMiddleware - app = FastAPI( title="Private AI Keys as a Service", description=""" @@ -132,6 +132,8 @@ async def health_check(): app.include_router(regions.router, prefix="/regions", tags=["regions"]) app.include_router(audit.router, prefix="/audit", tags=["audit"]) app.include_router(teams.router, prefix="/teams", tags=["teams"]) +app.include_router(billing.router, prefix="/billing", tags=["billing"]) +app.include_router(products.router, prefix="/products", tags=["products"]) @app.get("/", include_in_schema=False) async def custom_swagger_ui_html(): diff --git a/app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py b/app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py new file mode 100644 index 0000000..875def7 --- /dev/null +++ b/app/migrations/versions/20250516_090727_891e02a9ce6e_add_system_settings.py @@ -0,0 +1,42 @@ +"""Add system settings + +Revision ID: 891e02a9ce6e +Revises: 2bb3f48b650d +Create Date: 2025-05-16 09:07:27.786841+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '891e02a9ce6e' +down_revision: Union[str, None] = '2bb3f48b650d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('system_secrets', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('key', sa.String(), nullable=False), + sa.Column('value', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_system_secrets_id'), 'system_secrets', ['id'], unique=False) + op.create_index(op.f('ix_system_secrets_key'), 'system_secrets', ['key'], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_system_secrets_key'), table_name='system_secrets') + op.drop_index(op.f('ix_system_secrets_id'), table_name='system_secrets') + op.drop_table('system_secrets') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/migrations/versions/20250516_113604_1069b3617025_manage_products.py b/app/migrations/versions/20250516_113604_1069b3617025_manage_products.py new file mode 100644 index 0000000..18be89e --- /dev/null +++ b/app/migrations/versions/20250516_113604_1069b3617025_manage_products.py @@ -0,0 +1,43 @@ +"""manage products + +Revision ID: 1069b3617025 +Revises: 891e02a9ce6e +Create Date: 2025-05-16 11:36:04.879027+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '1069b3617025' +down_revision: Union[str, None] = '891e02a9ce6e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('products', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('stripe_lookup_key', sa.String(), nullable=False), + sa.Column('active', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_products_id'), 'products', ['id'], unique=False) + op.create_index(op.f('ix_products_stripe_lookup_key'), 'products', ['stripe_lookup_key'], unique=True) + op.add_column('teams', sa.Column('stripe_customer_id', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('teams', 'stripe_customer_id') + op.drop_index(op.f('ix_products_stripe_lookup_key'), table_name='products') + op.drop_index(op.f('ix_products_id'), table_name='products') + op.drop_table('products') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py b/app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py new file mode 100644 index 0000000..f1f1956 --- /dev/null +++ b/app/migrations/versions/20250520_112722_c2f3f1999f62_team_product_relationship.py @@ -0,0 +1,73 @@ +"""team product relationship + +Revision ID: c2f3f1999f62 +Revises: 1069b3617025 +Create Date: 2025-05-20 11:27:22.451900+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'c2f3f1999f62' +down_revision: Union[str, None] = '1069b3617025' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('products', sa.Column('user_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('keys_per_user', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('total_key_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('service_key_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('max_budget_per_key', sa.Float(), nullable=True)) + op.add_column('products', sa.Column('rpm_per_key', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('vector_db_count', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('vector_db_storage', sa.Integer(), nullable=True)) + op.add_column('products', sa.Column('renewal_period_days', sa.Integer(), nullable=True)) + op.alter_column('products', 'id', + existing_type=sa.INTEGER(), + type_=sa.String(), + existing_nullable=False) + op.drop_index('ix_products_stripe_lookup_key', table_name='products') + op.drop_column('products', 'stripe_lookup_key') + + op.create_table('team_products', + sa.Column('team_id', sa.Integer(), nullable=False), + sa.Column('product_id', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['product_id'], ['products.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['team_id'], ['teams.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('team_id', 'product_id') + ) + op.add_column('teams', sa.Column('last_payment', sa.DateTime(timezone=True), nullable=True)) + op.create_index(op.f('ix_teams_stripe_customer_id'), 'teams', ['stripe_customer_id'], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_teams_stripe_customer_id'), table_name='teams') + op.drop_column('teams', 'last_payment') + op.add_column('products', sa.Column('stripe_lookup_key', sa.VARCHAR(), autoincrement=False, nullable=False)) + op.create_index('ix_products_stripe_lookup_key', 'products', ['stripe_lookup_key'], unique=True) + op.drop_table('team_products') + op.alter_column('products', 'id', + existing_type=sa.String(), + type_=sa.INTEGER(), + existing_nullable=False) + op.drop_column('products', 'renewal_period_days') + op.drop_column('products', 'vector_db_storage') + op.drop_column('products', 'vector_db_count') + op.drop_column('products', 'rpm_per_key') + op.drop_column('products', 'max_budget_per_key') + op.drop_column('products', 'service_key_count') + op.drop_column('products', 'total_key_count') + op.drop_column('products', 'keys_per_user') + op.drop_column('products', 'user_count') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/schemas/models.py b/app/schemas/models.py index b243c11..a7b9b72 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -147,9 +147,31 @@ class PrivateAIKey(PrivateAIKeyBase): team_id: Optional[int] = None model_config = ConfigDict(from_attributes=True) +class PrivateAIKeyDetail(PrivateAIKey): + spend: Optional[float] = None + key_name: Optional[str] = None + key_alias: Optional[str] = None + soft_budget_cooldown: Optional[bool] = None + models: Optional[List[str]] = None + max_parallel_requests: Optional[int] = None + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None + max_budget: Optional[float] = None + budget_duration: Optional[str] = None + budget_reset_at: Optional[datetime] = None + expires: Optional[datetime] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + metadata: Optional[dict] = None + model_config = ConfigDict(from_attributes=True) + class BudgetPeriodUpdate(BaseModel): budget_duration: str +class TokenDurationUpdate(BaseModel): + """Schema for updating a token's duration""" + duration: str # e.g. "30d" for 30 days, "1y" for 1 year + class PrivateAIKeySpend(BaseModel): spend: float expires: datetime @@ -246,4 +268,47 @@ class UserRoleUpdate(BaseModel): class SignInData(BaseModel): username: EmailStr - verification_code: str \ No newline at end of file + verification_code: str + +class CheckoutSessionCreate(BaseModel): + price_lookup_token: str + +class ProductBase(BaseModel): + name: str + id: str # This is the Stripe product ID, format should be prod_XXX + user_count: Optional[int] = 1 + keys_per_user: Optional[int] = 1 + total_key_count: Optional[int] = 6 + service_key_count: Optional[int] = 5 + max_budget_per_key: Optional[float] = 20.0 + rpm_per_key: Optional[int] = 500 + vector_db_count: Optional[int] = 1 + vector_db_storage: Optional[int] = 50 # Not used yet, should be a number in GiB + renewal_period_days: int = 30 + active: bool = True + +class ProductCreate(ProductBase): + pass + +class ProductUpdate(BaseModel): + name: Optional[str] = None + user_count: Optional[int] = None + keys_per_user: Optional[int] = None + total_key_count: Optional[int] = None + service_key_count: Optional[int] = None + max_budget_per_key: Optional[float] = None + rpm_per_key: Optional[int] = None + vector_db_count: Optional[int] = None + vector_db_storage: Optional[int] = None + renewal_period_days: Optional[int] = None + active: Optional[bool] = None + model_config = ConfigDict(from_attributes=True) + +class Product(ProductBase): + created_at: datetime + updated_at: Optional[datetime] = None + model_config = ConfigDict(from_attributes=True) + +class PricingTableSession(BaseModel): + client_secret: str + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/app/services/litellm.py b/app/services/litellm.py index a31a9fa..4780b08 100644 --- a/app/services/litellm.py +++ b/app/services/litellm.py @@ -1,7 +1,9 @@ import requests from fastapi import HTTPException, status -import os import logging +from app.core.resource_limits import DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY +from app.core.config import settings +from typing import Optional logger = logging.getLogger(__name__) @@ -15,12 +17,11 @@ def __init__(self, api_url: str, api_key: str): if not self.master_key: raise ValueError("LiteLLM API key is required") - async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> str: + async def create_key(self, email: str, name: str, user_id: int, team_id: str, duration: str = f"{DEFAULT_KEY_DURATION}d", max_budget: float = DEFAULT_MAX_SPEND, rpm_limit: int = DEFAULT_RPM_PER_KEY) -> str: """Create a new API key for LiteLLM""" try: logger.info(f"Creating new LiteLLM API key for email: {email}, name: {name}, user_id: {user_id}, team_id: {team_id}") request_data = { - "duration": "8760h", # Set token duration to 1 year (365 days * 24 hours) "models": ["all-team-models"], # Allow access to all models "aliases": {}, "config": {}, @@ -39,6 +40,15 @@ async def create_key(self, email: str, name: str, user_id: int, team_id: str) -> request_data["key_alias"] = key_alias request_data["metadata"] = metadata request_data["team_id"] = team_id + + if settings.ENABLE_LIMITS: + request_data["duration"] = duration + request_data["budget_duration"] = duration + request_data["max_budget"] = max_budget + request_data["rpm_limit"] = rpm_limit + else: + request_data["duration"] = "365d" + if user_id is not None: request_data["user_id"] = str(user_id) @@ -114,10 +124,41 @@ async def get_key_info(self, litellm_token: str) -> dict: detail=f"Failed to get LiteLLM key information: {error_msg}" ) - async def update_budget(self, litellm_token: str, budget_duration: str): + async def update_budget(self, litellm_token: str, budget_duration: str, budget_amount: Optional[float] = None): """Update the budget for a LiteLLM API key""" try: # Update budget period in LiteLLM + request_data = { + "key": litellm_token, + "budget_duration": budget_duration + } + if budget_amount: + request_data["max_budget"] = budget_amount + + response = requests.post( + f"{self.api_url}/key/update", + headers={ + "Authorization": f"Bearer {self.master_key}" + }, + json=request_data + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + error_msg = str(e) + if hasattr(e, 'response') and e.response is not None: + try: + error_details = e.response.json() + error_msg = f"Status {e.response.status_code}: {error_details}" + except ValueError: + error_msg = f"Status {e.response.status_code}: {e.response.text}" + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update LiteLLM budget: {error_msg}" + ) + + async def update_key_duration(self, litellm_token: str, duration: str): + """Update the duration for a LiteLLM API key""" + try: response = requests.post( f"{self.api_url}/key/update", headers={ @@ -125,7 +166,7 @@ async def update_budget(self, litellm_token: str, budget_duration: str): }, json={ "key": litellm_token, - "budget_duration": budget_duration + "duration": duration } ) response.raise_for_status() @@ -139,5 +180,35 @@ async def update_budget(self, litellm_token: str, budget_duration: str): error_msg = f"Status {e.response.status_code}: {e.response.text}" raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update LiteLLM budget: {error_msg}" + detail=f"Failed to update LiteLLM key duration: {error_msg}" + ) + + async def set_key_restrictions(self, litellm_token: str, duration: str, budget_amount: float, rpm_limit: int, budget_duration: Optional[str] = None): + """Set the restrictions for a LiteLLM API key""" + try: + response = requests.post( + f"{self.api_url}/key/update", + headers={ + "Authorization": f"Bearer {self.master_key}" + }, + json={ + "key": litellm_token, + "duration": duration, + "budget_duration": budget_duration, + "max_budget": budget_amount, + "rpm_limit": rpm_limit + } + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + error_msg = str(e) + if hasattr(e, 'response') and e.response is not None: + try: + error_details = e.response.json() + error_msg = f"Status {e.response.status_code}: {error_details}" + except ValueError: + error_msg = f"Status {e.response.status_code}: {e.response.text}" + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to set LiteLLM key restrictions: {error_msg}" ) diff --git a/app/services/stripe.py b/app/services/stripe.py new file mode 100644 index 0000000..8c36df8 --- /dev/null +++ b/app/services/stripe.py @@ -0,0 +1,257 @@ +import stripe +import os +import logging +from urllib.parse import urljoin +from fastapi import HTTPException, status +from sqlalchemy.orm import Session +from app.db.models import DBTeam, DBSystemSecret + +# Configure logger +logger = logging.getLogger(__name__) + +# Initialize Stripe +stripe.api_key = os.getenv("STRIPE_SECRET_KEY") + +# Full list of possible events: https://docs.stripe.com/api/events/types +INVOICE_SUCCESS_EVENTS = ["invoice.paid"] # Renewal +SUBSCRIPTION_SUCCESS_EVENTS = ["customer.subscription.resumed", "customer.subscription.created"] # New subscription +SESSION_FAILURE_EVENTS = ["checkout.session.async_payment_failed", "checkout.session.expired"] # Checkout failure +SUBSCRIPTION_FAILURE_EVENTS = ["customer.subscription.deleted", "customer.subscription.paused"] # Subscription failure +INVOICE_FAILURE_EVENTS = ["invoice.payment_failed"] # Invoice failure + +SUCCESS_EVENTS = INVOICE_SUCCESS_EVENTS + SUBSCRIPTION_SUCCESS_EVENTS +FAILURE_EVENTS = SESSION_FAILURE_EVENTS + SUBSCRIPTION_FAILURE_EVENTS + INVOICE_FAILURE_EVENTS +KNOWN_EVENTS = SUCCESS_EVENTS + FAILURE_EVENTS + +def decode_stripe_event( payload: bytes, signature: str, webhook_secret: str) -> stripe.Event: + """ + Decode Stripe webhook events. + + Args: + payload: The raw request body + signature: The Stripe signature header + webhook_secret: The webhook signing secret + + Returns: + stripe.Event: The Stripe event + """ + try: + event = stripe.Webhook.construct_event( + payload, signature, webhook_secret + ) + logger.info(f"Decoded event of type: {event.type}") + return event + + # If the signature doesn't match, assume bad intent + except stripe.error.SignatureVerificationError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Not found" + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid payload" + ) + except Exception as e: + logger.error(f"Error handling Stripe event: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error processing webhook" + ) + +async def create_portal_session( + stripe_customer_id: str, + return_url: str +) -> str: + """ + Create a Stripe Customer Portal session for team subscription management. + + Args: + stripe_customer_id: The Stripe customer ID to create the portal session for + frontend_url: The frontend URL for return redirect + + Returns: + str: The portal session URL + """ + try: + # Create the portal session + portal_session = stripe.billing_portal.Session.create( + customer=stripe_customer_id, + return_url=return_url + ) + + return portal_session.url + except Exception as e: + logger.error(f"Error creating portal session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating portal session" + ) + +async def setup_stripe_webhook(webhook_key: str, webhook_route: str, db: Session) -> None: + """ + Set up the Stripe webhook endpoint if it doesn't exist and store its signing secret. + + Args: + webhook_key: The key to store the webhook secret under + db: Database session + """ + try: + # Check if we already have a webhook secret stored + existing_secret = db.query(DBSystemSecret).filter( + DBSystemSecret.key == webhook_key + ).first() + + if existing_secret: + return + + # Get the base URL from environment + base_url = os.getenv("BACKEND_URL", "http://localhost:8800") + webhook_url = urljoin(base_url, webhook_route) + + # List existing webhook endpoints + endpoints = stripe.WebhookEndpoint.list() + + # Check if we already have an endpoint for this URL + existing_endpoint = None + for endpoint in endpoints.data: + if endpoint.url == webhook_url: + existing_endpoint = endpoint + break + + if existing_endpoint: + # For existing endpoints, we need to create a new one to get the secret + # First delete the old endpoint + stripe.WebhookEndpoint.delete(existing_endpoint.id) + logger.info(f"Deleted existing webhook endpoint: {existing_endpoint.id}") + + # Create new webhook endpoint + endpoint = stripe.WebhookEndpoint.create( + url=webhook_url, + enabled_events=KNOWN_EVENTS + ) + + # Store the signing secret + secret = DBSystemSecret( + key="stripe_webhook_secret", + value=endpoint.secret, + description="Stripe webhook signing secret for handling events" + ) + db.add(secret) + db.commit() + + except Exception as e: + logger.error(f"Error setting up Stripe webhook: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error setting up Stripe webhook, {str(e)}" + ) + +async def create_stripe_customer( + team: DBTeam +) -> str: + """ + Create a Stripe customer for a team. + + Args: + team: The team to create a Stripe customer for + + Returns: + str: The Stripe customer ID + + Raises: + HTTPException: If error creating customer + """ + try: + # Check if team already has a Stripe customer + if team.stripe_customer_id: + return team.stripe_customer_id + + # Create Stripe customer + customer = stripe.Customer.create( + email=team.admin_email, + name=team.name, + metadata={ + "team_id": team.id, + "team_name": team.name + } + ) + + return customer.id + + except Exception as e: + logger.error(f"Error creating Stripe customer: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating Stripe customer" + ) + +async def get_product_id_from_subscription(subscription_id: str) -> str: + """ + Get the Stripe product ID for the team's subscription. + + Args: + subscription_id: The Stripe subscription ID + + Returns: + str: The Stripe product ID + """ + # Get the list of subscription items + subscription_items = stripe.SubscriptionItem.list( + subscription=subscription_id, + expand=['data.price.product'] + ) + + if not subscription_items.data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No items found in subscription" + ) + + return subscription_items.data[0].price.product.id + +async def get_product_id_from_session(session_id: str) -> str: + """ + Get the Stripe product ID for the team's subscription from a checkout session. + + Args: + session_id: The Stripe checkout session ID + """ + line_items = stripe.checkout.Session.list_line_items(session_id) + return line_items.data[0].price.product + +async def get_customer_from_pi(payment_intent: str) -> str: + """ + Get the Stripe customer ID from a payment intent. + """ + payment_intent = stripe.PaymentIntent.retrieve(payment_intent) + logger.info(f"Payment intent is:\n{payment_intent}") + return payment_intent.customer + +async def get_pricing_table_secret(customer_id: str) -> str: + """ + Create a Stripe Customer Session client secret for a customer. + + Args: + customer_id: The Stripe customer ID to create the session for + + Returns: + str: The customer session client secret + """ + try: + # Create the customer session + session = stripe.CustomerSession.create( + customer=customer_id, + components={ + "pricing_table": {"enabled": True} + } + ) + + return session.client_secret + except Exception as e: + logger.error(f"Error creating customer session: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error creating customer session" + ) diff --git a/docs/design/StripeFlow.diagram b/docs/design/StripeFlow.diagram new file mode 100644 index 0000000..2006c62 --- /dev/null +++ b/docs/design/StripeFlow.diagram @@ -0,0 +1,22 @@ +actor #green:0.5 Lauren +actor #blue Customer +participant #green amazee.ai +rparticipant #red Stripe + +Lauren -> Stripe: Create Products & Prices +Lauren -> amazee.ai: createPricingOptions + +Customer -> amazee.ai: chooseProduct +amazee.ai -> Stripe: createCustomer +amazee.ai <-- Stripe: customerID +amazee.ai -> Stripe: createCheckoutSession +amazee.ai <-- Stripe: redirectURL +Customer <-- amazee.ai: redirectURL +Customer -> Stripe: makePayment +Stripe -> amazee.ai: paymentSucceeded +activate amazee.ai +amazee.ai -> amazee.ai: extendKey +amazee.ai -> amazee.ai: setPaymentDate +deactivateafter amazee.ai +Customer <-- amazee.ai: Success + diff --git a/frontend/src/app/admin/products/page.tsx b/frontend/src/app/admin/products/page.tsx new file mode 100644 index 0000000..2399499 --- /dev/null +++ b/frontend/src/app/admin/products/page.tsx @@ -0,0 +1,473 @@ +'use client'; + +import { useState } from 'react'; +import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { useToast } from '@/hooks/use-toast'; +import { get, post, put, del } from '@/utils/api'; + +interface Product { + id: string; + name: string; + user_count: number; + keys_per_user: number; + total_key_count: number; + service_key_count: number; + max_budget_per_key: number; + rpm_per_key: number; + vector_db_count: number; + vector_db_storage: number; + renewal_period_days: number; + active: boolean; + created_at: string; +} + +export default function ProductsPage() { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + const [isEditDialogOpen, setIsEditDialogOpen] = useState(false); + const [selectedProduct, setSelectedProduct] = useState(null); + const [formData, setFormData] = useState>({}); + + // Update form data + const updateFormData = (newData: Partial) => { + setFormData(newData); + }; + + // Queries + const { data: products = [] } = useQuery({ + queryKey: ['products'], + queryFn: async () => { + const response = await get('/products'); + return response.json(); + }, + }); + + // Mutations + const createProductMutation = useMutation({ + mutationFn: async (productData: Partial) => { + const response = await post('/products', productData); + return response.json(); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['products'] }); + setIsCreateDialogOpen(false); + setFormData({}); + toast({ + title: "Success", + description: "Product created successfully" + }); + }, + onError: (error: Error) => { + toast({ + variant: "destructive", + title: "Error", + description: error.message + }); + }, + }); + + const updateProductMutation = useMutation({ + mutationFn: async ({ id, data }: { id: string; data: Partial }) => { + const response = await put(`/products/${id}`, data); + return response.json(); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['products'] }); + setIsEditDialogOpen(false); + setSelectedProduct(null); + setFormData({}); + toast({ + title: "Success", + description: "Product updated successfully" + }); + }, + onError: (error: Error) => { + toast({ + variant: "destructive", + title: "Error", + description: error.message + }); + }, + }); + + const deleteProductMutation = useMutation({ + mutationFn: async (id: string) => { + await del(`/products/${id}`); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ['products'] }); + toast({ + title: "Success", + description: "Product deleted successfully" + }); + }, + onError: (error: Error) => { + toast({ + variant: "destructive", + title: "Error", + description: error.message + }); + }, + }); + + const handleCreate = () => { + createProductMutation.mutate(formData); + }; + + const handleUpdate = () => { + if (!selectedProduct) return; + updateProductMutation.mutate({ id: selectedProduct.id, data: formData }); + }; + + const handleDelete = (id: string) => { + if (!confirm('Are you sure you want to delete this product?')) return; + deleteProductMutation.mutate(id); + }; + + return ( +
+
+

Product Management

+ + + + + + + Create New Product + +
+
+ + updateFormData({ ...formData, id: e.target.value })} + /> +
+
+ + updateFormData({ ...formData, name: e.target.value })} + /> +
+
+ + updateFormData({ ...formData, user_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, keys_per_user: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, total_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, service_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, max_budget_per_key: parseFloat(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, rpm_per_key: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_storage: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, renewal_period_days: parseInt(e.target.value) })} + /> +
+
+
+ updateFormData({ ...formData, active: e.target.checked })} + className="h-4 w-4 rounded border-gray-300" + /> + +
+
+
+
+ +
+
+
+
+ + + + + ID + Name + User Count + Keys/User + Total Keys + Service Keys + Budget/Key + RPM/Key + Vector DBs + Storage (GiB) + Renewal (Days) + Status + Created + Actions + + + + {products.map((product) => ( + + {product.id} + {product.name} + {product.user_count} + {product.keys_per_user} + {product.total_key_count} + {product.service_key_count} + ${product.max_budget_per_key.toFixed(2)} + {product.rpm_per_key} + {product.vector_db_count} + {product.vector_db_storage} + {product.renewal_period_days} + + + {product.active ? 'Active' : 'Inactive'} + + + {new Date(product.created_at).toLocaleDateString()} + +
+ + +
+
+
+ ))} +
+
+ + + + + Edit Product + +
+
+ + +
+
+ + updateFormData({ ...formData, name: e.target.value })} + /> +
+
+ + updateFormData({ ...formData, user_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, keys_per_user: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, total_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, service_key_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, max_budget_per_key: parseFloat(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, rpm_per_key: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_count: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, vector_db_storage: parseInt(e.target.value) })} + /> +
+
+ + updateFormData({ ...formData, renewal_period_days: parseInt(e.target.value) })} + /> +
+
+
+ updateFormData({ ...formData, active: e.target.checked })} + className="h-4 w-4 rounded border-gray-300" + /> + +
+
+
+
+ +
+
+
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/app/admin/teams/page.tsx b/frontend/src/app/admin/teams/page.tsx index e0d30da..78b42ff 100644 --- a/frontend/src/app/admin/teams/page.tsx +++ b/frontend/src/app/admin/teams/page.tsx @@ -55,6 +55,22 @@ interface TeamUser { role: string; } +interface Product { + id: string; + name: string; + user_count: number; + keys_per_user: number; + total_key_count: number; + service_key_count: number; + max_budget_per_key: number; + rpm_per_key: number; + vector_db_count: number; + vector_db_storage: number; + renewal_period_days: number; + active: boolean; + created_at: string; +} + interface Team { id: string; name: string; @@ -65,6 +81,7 @@ interface Team { created_at: string; updated_at: string; users?: TeamUser[]; + products?: Product[]; } interface User { @@ -120,6 +137,17 @@ export default function TeamsPage() { enabled: !!expandedTeamId, }); + // Get products for expanded team + const { data: teamProducts = [], isLoading: isLoadingTeamProducts } = useQuery({ + queryKey: ['team-products', expandedTeamId], + queryFn: async () => { + if (!expandedTeamId) return []; + const response = await get(`/products?team_id=${expandedTeamId}`); + return response.json(); + }, + enabled: !!expandedTeamId, + }); + // Search users query const searchUsersMutation = useMutation({ mutationFn: async (query: string) => { @@ -540,6 +568,7 @@ export default function TeamsPage() { Team Details Users + Products @@ -659,6 +688,60 @@ export default function TeamsPage() { )} + +
+ {isLoadingTeamProducts ? ( +
+ +
+ ) : teamProducts.length > 0 ? ( +
+ + + + Name + User Count + Keys/User + Total Keys + Service Keys + Budget/Key + RPM/Key + Vector DBs + Storage (GiB) + Renewal (Days) + Status + + + + {teamProducts.map((product) => ( + + {product.name} + {product.user_count} + {product.keys_per_user} + {product.total_key_count} + {product.service_key_count} + ${product.max_budget_per_key.toFixed(2)} + {product.rpm_per_key} + {product.vector_db_count} + {product.vector_db_storage} + {product.renewal_period_days} + + + {product.active ? "Active" : "Inactive"} + + + + ))} + +
+
+ ) : ( +
+

No products associated with this team.

+
+ )} +
+
) : ( diff --git a/frontend/src/app/team-admin/pricing/page.tsx b/frontend/src/app/team-admin/pricing/page.tsx new file mode 100644 index 0000000..d79a293 --- /dev/null +++ b/frontend/src/app/team-admin/pricing/page.tsx @@ -0,0 +1,86 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { useAuth } from '@/hooks/use-auth'; +import { get, post } from '@/utils/api'; +import Script from 'next/script'; + +declare module 'react' { + interface HTMLAttributes extends AriaAttributes, DOMAttributes { + 'pricing-table-id'?: string; + 'publishable-key'?: string; + 'customer-session-client-secret'?: string; + } +} + +declare module 'react/jsx-runtime' { + interface Element { + 'stripe-pricing-table': HTMLElement; + } +} + +declare global { + interface HTMLElementTagNameMap { + 'stripe-pricing-table': HTMLElement; + } +} + +export default function PricingPage() { + const { user } = useAuth(); + const [clientSecret, setClientSecret] = useState(null); + const [error, setError] = useState(null); + + useEffect(() => { + const fetchSessionToken = async () => { + try { + if (!user?.team_id) return; + const response = await get(`/billing/teams/${user.team_id}/pricing-table-session`); + const data = await response.json(); + setClientSecret(data.client_secret); + } catch (err) { + setError('Failed to load pricing table. Please try again later.'); + console.error('Error fetching pricing table session:', err); + } + }; + + fetchSessionToken(); + }, [user?.team_id]); + + const handleManageSubscription = async () => { + try { + const response = await post(`/billing/teams/${user?.team_id}/portal`, {}); + if (response.redirected) { + window.location.href = response.url; + } + } catch (error) { + console.error('Error accessing portal:', error); + } + }; + + if (error) { + return
{error}
; + } + + return ( +
+
+

Subscription Plans

+ +
+