diff --git a/app/api/billing.py b/app/api/billing.py
index 981d4ed..f32f6b5 100644
--- a/app/api/billing.py
+++ b/app/api/billing.py
@@ -4,7 +4,7 @@
import os
from datetime import datetime, UTC
from app.db.database import get_db
-from app.core.security import check_specific_team_admin, check_system_admin
+from app.core.security import get_role_min_specific_team_admin, get_role_min_system_admin
from app.db.models import DBTeam, DBSystemSecret, DBProduct, DBTeamProduct
from app.schemas.models import PricingTableSession, SubscriptionCreate, SubscriptionResponse
from app.services.stripe import (
@@ -91,7 +91,7 @@ async def handle_events(
detail="Error processing webhook"
)
-@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)])
+@router.post("/teams/{team_id}/portal", dependencies=[Depends(get_role_min_specific_team_admin)])
async def get_portal(
team_id: int,
db: Session = Depends(get_db)
@@ -135,7 +135,7 @@ async def get_portal(
detail="Error creating portal session"
)
-@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession)
+@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(get_role_min_specific_team_admin)], response_model=PricingTableSession)
async def get_pricing_table_session(
team_id: int,
db: Session = Depends(get_db)
@@ -178,7 +178,7 @@ async def get_pricing_table_session(
detail="Error creating customer session"
)
-@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(check_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED)
+@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(get_role_min_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED)
async def create_team_subscription(
team_id: int,
subscription_data: SubscriptionCreate,
diff --git a/app/api/pricing_tables.py b/app/api/pricing_tables.py
index d77d30b..be8449b 100644
--- a/app/api/pricing_tables.py
+++ b/app/api/pricing_tables.py
@@ -5,7 +5,7 @@
from app.db.database import get_db
from app.db.models import DBTeam, DBPricingTable
-from app.core.security import check_system_admin, get_role_min_team_admin, get_current_user_from_auth
+from app.core.security import get_role_min_system_admin, get_role_min_team_admin, get_current_user_from_auth
from app.schemas.models import PricingTableCreate, PricingTableResponse, PricingTablesResponse
from app.core.config import settings
@@ -16,8 +16,8 @@
tags=["pricing-tables"]
)
-@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
-@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
+@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
+@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
async def create_pricing_table(
pricing_table: PricingTableCreate,
db: Session = Depends(get_db)
@@ -113,8 +113,8 @@ async def get_pricing_table(
updated_at=pricing_table.updated_at or pricing_table.created_at
)
-@router.delete("", dependencies=[Depends(check_system_admin)])
-@router.delete("/", dependencies=[Depends(check_system_admin)])
+@router.delete("", dependencies=[Depends(get_role_min_system_admin)])
+@router.delete("/", dependencies=[Depends(get_role_min_system_admin)])
async def delete_pricing_table(
table_type: str,
db: Session = Depends(get_db)
@@ -146,7 +146,7 @@ async def delete_pricing_table(
return {"message": f"Pricing table of type '{table_type}' deleted successfully"}
-@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(check_system_admin)])
+@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(get_role_min_system_admin)])
async def get_all_pricing_tables(
db: Session = Depends(get_db)
):
diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py
index 447d749..bf6eec3 100644
--- a/app/api/private_ai_keys.py
+++ b/app/api/private_ai_keys.py
@@ -13,9 +13,22 @@
from app.db.postgres import PostgresManager
from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam
from app.services.litellm import LiteLLMService
-from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin
+from app.core.security import (
+ get_current_user_from_auth,
+ get_role_min_team_admin,
+ get_private_ai_access,
+ get_role_min_system_admin
+)
+from app.core.roles import UserRole
from app.core.config import settings
-from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
+from app.core.resource_limits import (
+ check_key_limits,
+ check_vector_db_limits,
+ get_token_restrictions,
+ DEFAULT_KEY_DURATION,
+ DEFAULT_MAX_SPEND,
+ DEFAULT_RPM_PER_KEY
+)
router = APIRouter(
tags=["private-ai-keys"]
@@ -68,7 +81,7 @@ def _validate_permissions_and_get_ownership_info(
async def create_vector_db(
vector_db: VectorDBCreate,
current_user = Depends(get_current_user_from_auth),
- user_role: UserRole = Depends(get_role_min_key_creator),
+ user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db),
store_result: bool = True
):
@@ -164,7 +177,7 @@ async def create_vector_db(
async def create_private_ai_key(
private_ai_key: PrivateAIKeyCreate,
current_user = Depends(get_current_user_from_auth),
- user_role: UserRole = Depends(get_role_min_key_creator),
+ user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db)
):
"""
@@ -279,7 +292,7 @@ async def create_private_ai_key(
async def create_llm_token(
private_ai_key: PrivateAIKeyCreate,
current_user = Depends(get_current_user_from_auth),
- user_role: UserRole = Depends(get_role_min_key_creator),
+ user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db),
store_result: bool = True
):
@@ -464,7 +477,7 @@ async def list_private_ai_keys(
private_ai_keys = query.all()
return [key.to_dict() for key in private_ai_keys]
-@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)])
+@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(get_role_min_system_admin)])
async def get_private_ai_key(
key_id: int,
current_user = Depends(get_current_user_from_auth),
@@ -571,7 +584,7 @@ def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole,
async def delete_private_ai_key(
key_id: int,
current_user = Depends(get_current_user_from_auth),
- user_role: UserRole = Depends(get_role_min_key_creator),
+ user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db)
):
private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db)
diff --git a/app/api/products.py b/app/api/products.py
index 68475f8..25e7512 100644
--- a/app/api/products.py
+++ b/app/api/products.py
@@ -5,15 +5,15 @@
from app.db.database import get_db
from app.db.models import DBProduct, DBTeamProduct, DBTeam
-from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin
+from app.core.security import get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin
from app.schemas.models import Product, ProductCreate, ProductUpdate
router = APIRouter(
tags=["products"]
)
-@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
-@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
+@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
+@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
async def create_product(
product: ProductCreate,
db: Session = Depends(get_db)
@@ -105,7 +105,7 @@ async def get_product(
)
return product
-@router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)])
+@router.put("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_system_admin)])
async def update_product(
product_id: str,
product_update: ProductUpdate,
@@ -132,7 +132,7 @@ async def update_product(
return product
-@router.delete("/{product_id}", dependencies=[Depends(check_system_admin)])
+@router.delete("/{product_id}", dependencies=[Depends(get_role_min_system_admin)])
async def delete_product(
product_id: str,
db: Session = Depends(get_db)
diff --git a/app/api/regions.py b/app/api/regions.py
index f2d40d4..6c462e6 100644
--- a/app/api/regions.py
+++ b/app/api/regions.py
@@ -9,7 +9,7 @@
from app.api.auth import get_current_user_from_auth
from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate, TeamSummary
from app.db.models import DBRegion, DBPrivateAIKey, DBTeamRegion, DBTeam
-from app.core.security import check_system_admin
+from app.core.security import get_role_min_system_admin
logger = logging.getLogger(__name__)
@@ -89,8 +89,8 @@ async def validate_database_connection(host: str, port: int, user: str, password
detail=f"Database connection validation failed: {str(e)}"
)
-@router.post("", response_model=Region, dependencies=[Depends(check_system_admin)])
-@router.post("/", response_model=Region, dependencies=[Depends(check_system_admin)])
+@router.post("", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
+@router.post("/", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
async def create_region(
region: RegionCreate,
db: Session = Depends(get_db)
@@ -158,13 +158,13 @@ async def list_regions(
return non_dedicated_regions + team_dedicated_regions
-@router.get("/admin", response_model=List[Region], dependencies=[Depends(check_system_admin)])
+@router.get("/admin", response_model=List[Region], dependencies=[Depends(get_role_min_system_admin)])
async def list_admin_regions(
db: Session = Depends(get_db)
):
return db.query(DBRegion).all()
-@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(check_system_admin)])
+@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(get_role_min_system_admin)])
async def get_region(
region_id: int,
db: Session = Depends(get_db)
@@ -177,7 +177,7 @@ async def get_region(
)
return region
-@router.delete("/{region_id}", dependencies=[Depends(check_system_admin)])
+@router.delete("/{region_id}", dependencies=[Depends(get_role_min_system_admin)])
async def delete_region(
region_id: int,
db: Session = Depends(get_db)
@@ -209,7 +209,7 @@ async def delete_region(
)
return {"message": "Region deleted successfully"}
-@router.put("/{region_id}", response_model=Region, dependencies=[Depends(check_system_admin)])
+@router.put("/{region_id}", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
async def update_region(
region_id: int,
region: RegionUpdate,
@@ -251,7 +251,7 @@ async def update_region(
)
return db_region
-@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
+@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
async def associate_team_with_region(
region_id: int,
team_id: int,
@@ -311,7 +311,7 @@ async def associate_team_with_region(
return {"message": "Team associated with region successfully"}
-@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
+@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
async def disassociate_team_from_region(
region_id: int,
team_id: int,
@@ -345,7 +345,7 @@ async def disassociate_team_from_region(
return {"message": "Team disassociated from region successfully"}
-@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(check_system_admin)])
+@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(get_role_min_system_admin)])
async def list_teams_for_region(
region_id: int,
db: Session = Depends(get_db)
diff --git a/app/api/teams.py b/app/api/teams.py
index 2a5591f..c181f4b 100644
--- a/app/api/teams.py
+++ b/app/api/teams.py
@@ -6,8 +6,8 @@
import logging
from app.db.database import get_db
-from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion
-from app.core.security import check_system_admin, check_specific_team_admin, get_current_user_from_auth
+from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion, DBProduct
+from app.core.security import get_role_min_system_admin, get_role_min_specific_team_admin, get_current_user_from_auth, check_sales_or_higher
from app.schemas.models import (
Team, TeamCreate, TeamUpdate,
TeamWithUsers, TeamMergeRequest, TeamMergeResponse
@@ -17,6 +17,7 @@
from app.services.ses import SESService
from app.core.worker import get_team_keys_by_region, generate_pricing_url, get_team_admin_email
from app.api.private_ai_keys import delete_private_ai_key
+from app.schemas.models import SalesTeamsResponse, SalesProduct, SalesTeam
logger = logging.getLogger(__name__)
@@ -66,8 +67,8 @@ async def register_team(
return db_team
-@router.get("", response_model=List[Team], dependencies=[Depends(check_system_admin)])
-@router.get("/", response_model=List[Team], dependencies=[Depends(check_system_admin)])
+@router.get("", response_model=List[Team], dependencies=[Depends(get_role_min_system_admin)])
+@router.get("/", response_model=List[Team], dependencies=[Depends(get_role_min_system_admin)])
async def list_teams(
db: Session = Depends(get_db)
):
@@ -76,7 +77,7 @@ async def list_teams(
"""
return db.query(DBTeam).all()
-@router.get("/{team_id}", response_model=TeamWithUsers, dependencies=[Depends(check_specific_team_admin)])
+@router.get("/{team_id}", response_model=TeamWithUsers, dependencies=[Depends(get_role_min_specific_team_admin)])
async def get_team(
team_id: int,
db: Session = Depends(get_db)
@@ -92,7 +93,7 @@ async def get_team(
# Convert directly to TeamWithUsers model
return TeamWithUsers.model_validate(db_team)
-@router.put("/{team_id}", response_model=Team, dependencies=[Depends(check_specific_team_admin)])
+@router.put("/{team_id}", response_model=Team, dependencies=[Depends(get_role_min_specific_team_admin)])
async def update_team(
team_id: int,
team_update: TeamUpdate,
@@ -143,7 +144,7 @@ async def update_team(
return db_team
-@router.delete("/{team_id}", dependencies=[Depends(check_system_admin)])
+@router.delete("/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
async def delete_team(
team_id: int,
db: Session = Depends(get_db)
@@ -166,7 +167,7 @@ async def delete_team(
return {"message": "Team deleted successfully"}
-@router.post("/{team_id}/extend-trial", dependencies=[Depends(check_system_admin)])
+@router.post("/{team_id}/extend-trial", dependencies=[Depends(get_role_min_system_admin)])
async def extend_team_trial(
team_id: int,
db: Session = Depends(get_db)
@@ -230,6 +231,146 @@ async def extend_team_trial(
return {"message": "Team trial extended successfully"}
+@router.get("/sales/list-teams", response_model=SalesTeamsResponse, dependencies=[Depends(check_sales_or_higher)])
+async def list_teams_for_sales(
+ db: Session = Depends(get_db)
+):
+ """
+ Get consolidated team information for sales dashboard.
+ Returns all teams with their products, regions, spend data, and trial status.
+ Accessible by system admin and sales users.
+ """
+ try:
+ # Track unreachable endpoints for logging at the end (use set to avoid duplicates)
+ unreachable_endpoints = set()
+
+ # Pre-fetch all regions once to avoid repeated queries
+ all_regions = db.query(DBRegion).filter(DBRegion.is_active == True).all()
+ regions_map = {r.id: r for r in all_regions}
+
+ # Pre-create LiteLLM services for each region to avoid re-instantiation
+ litellm_services = {}
+ for region in all_regions:
+ litellm_services[region.id] = LiteLLMService(
+ api_url=region.litellm_api_url,
+ api_key=region.litellm_api_key
+ )
+
+ # Get all teams with their basic information
+ teams = db.query(DBTeam).all()
+
+ sales_teams = []
+
+ for team in teams:
+ # Get team products
+ team_products = db.query(DBTeamProduct).join(DBProduct).filter(
+ DBTeamProduct.team_id == team.id,
+ DBProduct.active == True
+ ).all()
+
+ products = [
+ SalesProduct(
+ id=team_product.product.id,
+ name=team_product.product.name,
+ active=team_product.product.active
+ )
+ for team_product in team_products
+ ]
+
+ # Get team AI keys (both team-owned and user-owned) and calculate total spend
+ team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all()
+ team_user_ids = [user.id for user in team_users]
+
+ team_keys = db.query(DBPrivateAIKey).filter(
+ (DBPrivateAIKey.team_id == team.id) | # Team-owned keys
+ (DBPrivateAIKey.owner_id.in_(team_user_ids)) # User-owned keys by team members
+ ).all()
+
+ # Calculate total spend from all AI keys and build regions list as we go
+ total_spend = 0.0
+ regions_set = set()
+
+ for key in team_keys:
+ if key.litellm_token and key.region_id in regions_map:
+ try:
+ # Use pre-fetched region info and pre-created LiteLLM service
+ region = regions_map[key.region_id]
+ litellm_service = litellm_services[region.id]
+
+ # Add region name to our set
+ regions_set.add(region.name)
+
+ # Get spend data from LiteLLM
+ key_data = await litellm_service.get_key_info(key.litellm_token)
+ key_spend = key_data.get("info", {}).get("spend", 0.0)
+ total_spend += float(key_spend)
+ except Exception as e:
+ # Track unreachable endpoint for logging at the end (only once per region)
+ region = regions_map[key.region_id]
+ endpoint_info = f"Region: {region.name}"
+ unreachable_endpoints.add(endpoint_info)
+
+ # Convert set to list for the response
+ regions = list(regions_set)
+
+ # Calculate trial status
+ trial_status = _calculate_trial_status(team, products)
+
+ sales_team = SalesTeam(
+ id=team.id,
+ name=team.name,
+ admin_email=team.admin_email,
+ created_at=team.created_at,
+ last_payment=team.last_payment,
+ is_always_free=team.is_always_free,
+ products=products,
+ regions=regions,
+ total_spend=round(total_spend, 4),
+ trial_status=trial_status
+ )
+
+ sales_teams.append(sales_team)
+
+ # Log all unreachable endpoints at the end
+ if unreachable_endpoints:
+ logger.warning(f"Unreachable LiteLLM endpoints encountered: {len(unreachable_endpoints)} unique endpoints")
+ for endpoint in unreachable_endpoints:
+ logger.warning(f" - {endpoint}")
+
+ return SalesTeamsResponse(teams=sales_teams)
+
+ except Exception as e:
+ logger.error(f"Error in list_teams_for_sales: {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to retrieve sales data: {str(e)}"
+ )
+
+def _calculate_trial_status(team: DBTeam, products: List[SalesProduct]) -> str:
+ """
+ Calculate trial status based on team creation, last payment, and active products.
+ """
+ if team.is_always_free:
+ return "Always Free"
+
+ if len(products) > 0:
+ return "Active Product"
+
+ # Calculate days until expiry
+ trial_period_days = 30
+ if team.last_payment:
+ days_since_last_payment = (datetime.now(UTC) - team.last_payment.replace(tzinfo=UTC)).days
+ days_remaining = trial_period_days - days_since_last_payment
+ else:
+ days_since_creation = (datetime.now(UTC) - team.created_at.replace(tzinfo=UTC)).days
+ days_remaining = trial_period_days - days_since_creation
+
+ if days_remaining <= 0:
+ return "Expired"
+ else:
+ # Always show days remaining for active trials
+ return f"{days_remaining} days left"
+
def _check_key_name_conflicts(team1_keys: List[DBPrivateAIKey], team2_keys: List[DBPrivateAIKey]) -> List[str]:
"""Return list of conflicting key names between two teams"""
team1_names = {key.name for key in team1_keys if key.name}
@@ -278,7 +419,7 @@ async def _resolve_key_conflicts(
else:
raise ValueError(f"Unknown conflict resolution strategy: {strategy}")
-@router.post("/{target_team_id}/merge", dependencies=[Depends(check_system_admin)])
+@router.post("/{target_team_id}/merge", dependencies=[Depends(get_role_min_system_admin)])
async def merge_teams(
target_team_id: int,
merge_request: TeamMergeRequest,
diff --git a/app/api/users.py b/app/api/users.py
index c78ca3f..6b60aea 100644
--- a/app/api/users.py
+++ b/app/api/users.py
@@ -1,19 +1,20 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
-from typing import List, get_args
+from typing import List
from app.core.config import settings
from app.core.resource_limits import check_team_user_limit
from app.db.database import get_db
from app.schemas.models import User, UserUpdate, UserCreate, TeamOperation, UserRoleUpdate
from app.db.models import DBUser, DBTeam
-from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, UserRole, get_role_min_team_admin
+from app.core.security import get_password_hash, get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin
+from app.core.roles import UserRole
from datetime import datetime, UTC
router = APIRouter(
tags=["users"]
)
-@router.get("/search", response_model=List[User], dependencies=[Depends(check_system_admin)])
+@router.get("/search", response_model=List[User], dependencies=[Depends(get_role_min_system_admin)])
async def search_users(
email: str,
db: Session = Depends(get_db)
@@ -77,10 +78,10 @@ async def create_user(
check_team_user_limit(db, user.team_id)
# Validate role if provided
- if user.role and user.role not in get_args(UserRole):
+ if user.role and user.role not in UserRole.get_all_roles():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Invalid role. Must be one of: {', '.join(get_args(UserRole))}"
+ detail=f"Invalid role. Must be one of: {', '.join(UserRole.get_all_roles())}"
)
# Default to the lowest permissions for a user in a team
@@ -198,7 +199,7 @@ async def add_user_to_team(
db.refresh(db_user)
return db_user
-@router.post("/{user_id}/remove-from-team", response_model=User, dependencies=[Depends(check_system_admin)])
+@router.post("/{user_id}/remove-from-team", response_model=User, dependencies=[Depends(get_role_min_system_admin)])
async def remove_user_from_team(
user_id: int,
current_user: DBUser = Depends(get_current_user_from_auth),
@@ -224,7 +225,7 @@ async def remove_user_from_team(
db.refresh(db_user)
return db_user
-@router.delete("/{user_id}", dependencies=[Depends(check_system_admin)])
+@router.delete("/{user_id}", dependencies=[Depends(get_role_min_system_admin)])
async def delete_user(
user_id: int,
current_user: DBUser = Depends(get_current_user_from_auth),
@@ -256,10 +257,10 @@ async def update_user_role(
Update a user's role. Accessible by admin users or team admins for their team members.
"""
# Validate role
- if role_update.role not in get_args(UserRole):
+ if role_update.role not in UserRole.get_all_roles():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Invalid role. Must be one of: {', '.join(get_args(UserRole))}"
+ detail=f"Invalid role. Must be one of: {', '.join(UserRole.get_all_roles())}"
)
# Get the user to update
diff --git a/app/core/rbac.py b/app/core/rbac.py
new file mode 100644
index 0000000..618b92e
--- /dev/null
+++ b/app/core/rbac.py
@@ -0,0 +1,108 @@
+from fastapi import Depends, HTTPException, status
+from typing import List, Union, Set, Callable
+from app.core.roles import UserRole, UserType
+from app.db.models import DBUser
+import logging
+
+class RBACDependency:
+ """Base class for role-based access control dependencies"""
+ logger = logging.getLogger(__name__)
+
+ def __init__(self, allowed_roles: List[str], require_team_membership: bool = False):
+ self.allowed_roles = set(allowed_roles)
+ self.require_team_membership = require_team_membership
+
+ def __call__(self, current_user: DBUser) -> str:
+ return self.check_access(current_user)
+
+ def check_access(self, user: DBUser) -> str:
+ """Check if user has access and return their effective role"""
+ # Validate user type constraints
+ if self._validate_user_type_constraints(user):
+ self.logger.info(f"User {user.id} has invalid user type constraints")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to perform this action"
+ )
+
+ # Check role permissions
+ effective_role = self._get_effective_role(user)
+ if effective_role not in self.allowed_roles:
+ self.logger.info(f"User {user.id} has invalid role {effective_role}")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to perform this action"
+ )
+
+ # Check team membership if required (but allow system admins to bypass this)
+ if self.require_team_membership and not user.team_id and not user.is_admin:
+ self.logger.info(f"User {user.id} is not a team member")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to perform this action"
+ )
+
+ return effective_role
+
+ def _validate_user_type_constraints(self, user: DBUser) -> bool:
+ """Validate that user type matches role constraints"""
+ # System admins (is_admin=True) cannot be team members
+ if user.is_admin and user.team_id is not None:
+ return True
+
+ # Get the effective role for validation
+ effective_role = self._get_effective_role(user)
+
+ # System users (team_id is None) cannot have team roles
+ if user.team_id is None and effective_role in UserRole.get_team_roles():
+ return True
+
+ # Team users (team_id is not None) cannot have system roles
+ if user.team_id is not None and effective_role in UserRole.get_system_roles():
+ return True
+
+ return False
+
+ def _get_effective_role(self, user: DBUser) -> str:
+ """Get the effective role for a user"""
+ if user.is_admin:
+ return UserRole.SYSTEM_ADMIN
+ return user.role or UserRole.USER
+
+# Pre-defined dependency functions for common use cases
+def require_system_admin():
+ """Require system admin role"""
+ return RBACDependency([UserRole.SYSTEM_ADMIN])
+
+def require_team_admin():
+ """Require team admin role or system admin"""
+ return RBACDependency(UserRole.ADMIN_ROLES, require_team_membership=True)
+
+def require_key_creator_or_higher():
+ """Require key creator role or higher (team context)"""
+ return RBACDependency(UserRole.KEY_MANAGEMENT_ROLES, require_team_membership=True)
+
+def require_private_ai_access():
+ """Require access to private AI operations - allows system users or team key creators"""
+ return RBACDependency(UserRole.KEY_MANAGEMENT_ROLES + [UserRole.USER], require_team_membership=False)
+
+def require_read_only_or_higher():
+ """Require read only role or higher (team context)"""
+ return RBACDependency(UserRole.READ_ACCESS_ROLES, require_team_membership=True)
+
+def require_sales_or_higher():
+ """Require sales role or higher (system context)"""
+ return RBACDependency(UserRole.SYSTEM_ACCESS_ROLES)
+
+def require_any_role():
+ """Allow any authenticated user"""
+ return RBACDependency(UserRole.get_all_roles())
+
+# Custom role dependency creator
+def require_roles(*roles: str):
+ """Create a dependency that requires specific roles"""
+ return RBACDependency(list(roles))
+
+def require_roles_with_team(*roles: str):
+ """Create a dependency that requires specific roles and team membership"""
+ return RBACDependency(list(roles), require_team_membership=True)
diff --git a/app/core/roles.py b/app/core/roles.py
new file mode 100644
index 0000000..6e38210
--- /dev/null
+++ b/app/core/roles.py
@@ -0,0 +1,53 @@
+
+from typing import List, Set, Literal
+from enum import Enum
+
+class UserType(Enum):
+ SYSTEM = "system"
+ TEAM = "team"
+
+class UserRole:
+ # System roles
+ SYSTEM_ADMIN = "system_admin"
+ USER = "user" # Default system user
+ SALES = "sales"
+
+ # Team roles
+ TEAM_ADMIN = "admin"
+ KEY_CREATOR = "key_creator"
+ READ_ONLY = "read_only"
+
+ # Legacy support - these MUST match existing string values exactly
+ ADMIN = TEAM_ADMIN # "admin"
+ DEFAULT = USER # "user"
+
+ # Role combinations for better readability
+ ADMIN_ROLES = [TEAM_ADMIN, SYSTEM_ADMIN]
+ KEY_MANAGEMENT_ROLES = [KEY_CREATOR] + ADMIN_ROLES
+ READ_ACCESS_ROLES = [READ_ONLY] + KEY_MANAGEMENT_ROLES
+ SYSTEM_ACCESS_ROLES = [SYSTEM_ADMIN, SALES]
+
+ @staticmethod
+ def get_system_roles() -> List[str]:
+ """Get all valid system user roles"""
+ return [UserRole.SYSTEM_ADMIN, UserRole.USER, UserRole.SALES]
+
+ @staticmethod
+ def get_team_roles() -> List[str]:
+ """Get all valid team user roles"""
+ return [UserRole.TEAM_ADMIN, UserRole.KEY_CREATOR, UserRole.READ_ONLY]
+
+ @staticmethod
+ def get_all_roles() -> List[str]:
+ """Get all valid roles (backwards compatible)"""
+ return UserRole.get_system_roles() + UserRole.get_team_roles()
+
+ @staticmethod
+ def is_system_role(role: str) -> bool:
+ """Check if role is valid for system users"""
+ return role in UserRole.get_system_roles()
+
+ @staticmethod
+ def is_team_role(role: str) -> bool:
+ """Check if role is valid for team users"""
+ return role in UserRole.get_team_roles()
diff --git a/app/core/security.py b/app/core/security.py
index 43583ad..5a37861 100644
--- a/app/core/security.py
+++ b/app/core/security.py
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta, UTC
-from typing import Optional, Literal, Dict
+from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status, Cookie, Header, Request
@@ -9,6 +9,13 @@
from app.db.database import get_db
from sqlalchemy.orm import Session
from app.db.models import DBUser, DBAPIToken
+from app.core.rbac import (
+ require_system_admin,
+ require_team_admin,
+ require_key_creator_or_higher,
+ require_sales_or_higher,
+ require_private_ai_access,
+)
logger = logging.getLogger(__name__)
@@ -18,18 +25,6 @@
# Custom bearer scheme
bearer_scheme = HTTPBearer(auto_error=False)
-# Define valid user roles as a Literal type
-UserRole = Literal["admin", "key_creator", "read_only", "user", "system_admin"]
-
-# Define a hierarchy for roles
-user_role_hierarchy: Dict[UserRole, int] = {
- "admin": 0,
- "user": 1,
- "key_creator": 2,
- "read_only": 3,
- "system_admin": 4,
-}
-
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
@@ -137,35 +132,40 @@ async def get_current_user_from_auth(
headers={"WWW-Authenticate": "Bearer"},
)
-async def check_system_admin(current_user: DBUser = Depends(get_current_user_from_auth)):
+async def get_role_min_system_admin(current_user: DBUser = Depends(get_current_user_from_auth)):
"""Check if the current user is a system admin."""
- if not current_user.is_admin:
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail="Not authorized to perform this action"
- )
-
-def get_user_role(minimum_role: UserRole, current_user: DBUser):
- if current_user.is_admin:
- return "system_admin"
- elif user_role_hierarchy[current_user.role] > user_role_hierarchy[minimum_role]:
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN,
- detail="Not authorized to perform this action"
- )
- return current_user.role
+ dependency = require_system_admin()
+ return dependency.check_access(current_user)
async def get_role_min_team_admin(current_user: DBUser = Depends(get_current_user_from_auth)):
- return get_user_role("admin", current_user)
+ """Require team admin role or higher."""
+ dependency = require_team_admin()
+ return dependency.check_access(current_user)
+
+async def get_role_min_specific_team_admin(current_user: DBUser = Depends(get_current_user_from_auth), team_id: int = None):
+ """Check if user is admin of specific team."""
+ dependency = require_team_admin()
+ role = dependency.check_access(current_user)
-async def check_specific_team_admin(current_user: DBUser = Depends(get_current_user_from_auth), team_id: int = None):
- get_user_role("admin", current_user)
- # system administrators will fail the team check
+ # Additional team-specific check
if not current_user.is_admin and not current_user.team_id == team_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to perform this action"
)
+ return role
async def get_role_min_key_creator(current_user: DBUser = Depends(get_current_user_from_auth)):
- return get_user_role("key_creator", current_user)
+ """Require key creator role or higher."""
+ dependency = require_key_creator_or_higher()
+ return dependency.check_access(current_user)
+
+async def get_private_ai_access(current_user: DBUser = Depends(get_current_user_from_auth)):
+ """Require access to private AI operations - allows system users or team key creators."""
+ dependency = require_private_ai_access()
+ return dependency.check_access(current_user)
+
+async def check_sales_or_higher(current_user: DBUser = Depends(get_current_user_from_auth)):
+ """Check if the current user is a sales user or system admin."""
+ dependency = require_sales_or_higher()
+ return dependency.check_access(current_user)
diff --git a/app/schemas/models.py b/app/schemas/models.py
index 2f95186..a74651d 100644
--- a/app/schemas/models.py
+++ b/app/schemas/models.py
@@ -362,4 +362,28 @@ class SubscriptionResponse(BaseModel):
product_id: str
team_id: int
created_at: datetime
+ model_config = ConfigDict(from_attributes=True)
+
+# Sales Dashboard schemas
+class SalesProduct(BaseModel):
+ id: str
+ name: str
+ active: bool
+ model_config = ConfigDict(from_attributes=True)
+
+class SalesTeam(BaseModel):
+ id: int
+ name: str
+ admin_email: str
+ created_at: datetime
+ last_payment: Optional[datetime] = None
+ is_always_free: bool
+ products: List[SalesProduct]
+ regions: List[str]
+ total_spend: float
+ trial_status: str
+ model_config = ConfigDict(from_attributes=True)
+
+class SalesTeamsResponse(BaseModel):
+ teams: List[SalesTeam]
model_config = ConfigDict(from_attributes=True)
\ No newline at end of file
diff --git a/frontend/src/app/admin/layout.tsx b/frontend/src/app/admin/layout.tsx
index 5a355cf..544a090 100644
--- a/frontend/src/app/admin/layout.tsx
+++ b/frontend/src/app/admin/layout.tsx
@@ -1,11 +1,15 @@
'use client';
import { ReactNode } from 'react';
+import { usePathname } from 'next/navigation';
export default function AdminLayout({ children }: { children: ReactNode }) {
+ const pathname = usePathname();
+ const isSalesDashboard = pathname === '/admin/sales-dashboard';
+
return (
-
diff --git a/frontend/src/app/admin/page.tsx b/frontend/src/app/admin/page.tsx
index 2231ede..2ae7507 100644
--- a/frontend/src/app/admin/page.tsx
+++ b/frontend/src/app/admin/page.tsx
@@ -4,7 +4,7 @@ import { useEffect } from 'react';
import { useRouter } from 'next/navigation';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
-import { Users, Globe, Key } from 'lucide-react';
+import { Users, Globe, Key, DollarSign } from 'lucide-react';
export default function AdminPage() {
const router = useRouter();
@@ -54,6 +54,19 @@ export default function AdminPage() {
View Keys
+
+
router.push('/admin/sales-dashboard')}>
+
+
+
+ Sales Dashboard
+
+ Monitor team performance and revenue metrics
+
+
+ View Dashboard
+
+
);
}
\ No newline at end of file
diff --git a/frontend/src/app/admin/sales-dashboard/page.tsx b/frontend/src/app/admin/sales-dashboard/page.tsx
new file mode 100644
index 0000000..269e796
--- /dev/null
+++ b/frontend/src/app/admin/sales-dashboard/page.tsx
@@ -0,0 +1,843 @@
+'use client';
+
+import { useState, useMemo, useCallback } from 'react';
+import { useQuery } from '@tanstack/react-query';
+import {
+ Table,
+ TableBody,
+ TableCell,
+ TableHead,
+ TableHeader,
+ TableRow,
+ TablePagination,
+} from '@/components/ui/table';
+
+import { Loader2, ChevronUp, ChevronDown, ChevronsUpDown, DollarSign, Calendar, Users, Globe, Package, Plus, X } from 'lucide-react';
+import { get } from '@/utils/api';
+import { Badge } from '@/components/ui/badge';
+import { Button } from '@/components/ui/button';
+import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
+import { Input } from '@/components/ui/input';
+
+interface Team {
+ id: number;
+ name: string;
+ admin_email: string;
+ created_at: string;
+ last_payment?: string;
+ is_always_free: boolean;
+ products: Product[];
+ regions: string[];
+ total_spend: number;
+ trial_status: string;
+}
+
+interface Product {
+ id: string;
+ name: string;
+ active: boolean;
+}
+
+
+
+type SortField = 'admin_email' | 'name' | 'created_at' | 'last_payment' | 'products' | 'trial_status' | 'regions' | 'total_spend' | null;
+type SortDirection = 'asc' | 'desc';
+
+interface Filter {
+ id: string;
+ column: string;
+ value: string;
+ operator: 'contains' | 'equals' | 'starts_with' | 'ends_with';
+}
+
+export default function SalesDashboardPage() {
+
+ // Filter and sort state
+ const [filters, setFilters] = useState([]);
+ const [sortField, setSortField] = useState(null);
+ const [sortDirection, setSortDirection] = useState('asc');
+
+
+ // Queries
+ const { data: teams = [], isLoading: isLoadingTeams } = useQuery({
+ queryKey: ['sales-teams'],
+ queryFn: async () => {
+ const response = await get('/teams/sales/list-teams');
+ const data = await response.json();
+ return data.teams;
+ },
+ });
+
+
+
+
+ // Calculate trial time remaining
+ const getTrialTimeRemaining = useCallback((team: Team): string => {
+ return team.trial_status;
+ }, []);
+
+ // Get unique regions for a team
+ const getTeamRegions = useCallback((team: Team): string[] => {
+ return team.regions;
+ }, []);
+
+ // Get total spend for a team
+ const getTeamTotalSpend = useCallback((team: Team): number => {
+ return team.total_spend;
+ }, []);
+
+ // Get all team spend values for calculating min/max
+ const allTeamSpends = useMemo(() => {
+ return teams.map(team => ({
+ teamId: team.id,
+ spend: getTeamTotalSpend(team)
+ })).filter(item => item.spend > 0); // Only non-zero values for min calculation
+ }, [teams, getTeamTotalSpend]);
+
+ // Calculate min and max spend values
+ const spendStats = useMemo(() => {
+ if (allTeamSpends.length === 0) {
+ return { minSpend: 0, maxSpend: 0 };
+ }
+
+ const spends = allTeamSpends.map(item => item.spend);
+ return {
+ minSpend: Math.min(...spends),
+ maxSpend: Math.max(...spends)
+ };
+ }, [allTeamSpends]);
+
+
+ // Get color for spend value based on gradient rules
+ const getSpendColor = useCallback((spend: number): string => {
+ if (spend === 0) {
+ return '#6b7280'; // Dark grey for $0.00
+ }
+
+ if (spend === spendStats.maxSpend) {
+ return '#166534'; // Dark green for highest value
+ }
+
+ if (spend === spendStats.minSpend) {
+ return '#000000'; // Black for lowest non-zero value
+ }
+
+ // Gradient for values between min and max
+ if (spendStats.maxSpend === spendStats.minSpend) {
+ return '#166534'; // If all values are the same, use dark green
+ }
+
+ const ratio = (spend - spendStats.minSpend) / (spendStats.maxSpend - spendStats.minSpend);
+ const red = Math.round(22 + (ratio * 0)); // Start from dark green (22, 101, 52)
+ const green = Math.round(101 + (ratio * 53)); // End at darker green (22, 154, 52)
+ const blue = Math.round(52 + (ratio * 0));
+
+ return `rgb(${red}, ${green}, ${blue})`;
+ }, [spendStats]);
+
+ // Add a new filter
+ const addFilter = () => {
+ const newFilter: Filter = {
+ id: Date.now().toString(),
+ column: 'admin_email',
+ value: '',
+ operator: 'contains'
+ };
+ setFilters([...filters, newFilter]);
+ };
+
+ // Update filter when column changes to set appropriate default value
+ const handleColumnChange = (filterId: string, column: string) => {
+ const columnInfo = filterColumns.find(col => col.value === column);
+ let defaultValue = '';
+ let defaultOperator: Filter['operator'] = 'contains';
+
+ if (columnInfo?.type === 'select') {
+ const options = getFilterOptions(column);
+ if (options.length > 0) {
+ defaultValue = options[0].value;
+ }
+ // Regions use 'contains' since teams can have multiple regions
+ if (column === 'regions') {
+ defaultOperator = 'contains';
+ } else {
+ defaultOperator = 'equals';
+ }
+ } else if (columnInfo?.type === 'number') {
+ defaultOperator = 'equals';
+ }
+
+ updateFilter(filterId, { column, value: defaultValue, operator: defaultOperator });
+ };
+
+ // Remove a filter
+ const removeFilter = (filterId: string) => {
+ setFilters(filters.filter(f => f.id !== filterId));
+ };
+
+ // Update a filter
+ const updateFilter = (filterId: string, updates: Partial) => {
+ setFilters(filters.map(f =>
+ f.id === filterId ? { ...f, ...updates } : f
+ ));
+ };
+
+ // Clear all filters
+ const clearAllFilters = () => {
+ setFilters([]);
+ setSortField(null);
+ setSortDirection('asc');
+ };
+
+
+ // Available filter columns
+ const filterColumns = [
+ { value: 'admin_email', label: 'Team Email', type: 'text' },
+ { value: 'name', label: 'Team Name', type: 'text' },
+ { value: 'products', label: 'Products', type: 'select' },
+ { value: 'trial_status', label: 'Trial Status', type: 'select' },
+ { value: 'regions', label: 'Regions', type: 'select' },
+ ];
+
+ // Filter operators
+ const filterOperators = [
+ { value: 'contains', label: 'Contains' },
+ { value: 'equals', label: 'Equals' },
+ { value: 'starts_with', label: 'Starts with' },
+ { value: 'ends_with', label: 'Ends with' },
+ ];
+
+ // Get operators for specific column types
+ const getOperatorsForColumn = (column: string) => {
+ const columnInfo = filterColumns.find(col => col.value === column);
+
+ if (columnInfo?.type === 'select') {
+ // Regions can use 'contains' since teams can have multiple regions
+ if (column === 'regions') {
+ return [
+ { value: 'equals', label: 'Equals' },
+ { value: 'contains', label: 'Contains' }
+ ];
+ }
+ return [{ value: 'equals', label: 'Equals' }];
+ }
+
+ if (columnInfo?.type === 'number') {
+ return [
+ { value: 'equals', label: 'Equals' },
+ { value: 'contains', label: 'Contains' }
+ ];
+ }
+
+ // Text fields get all operators
+ return filterOperators;
+ };
+
+ // Get options for select-type filters
+ const getFilterOptions = (column: string) => {
+ switch (column) {
+ case 'trial_status':
+ return [
+ { value: 'Active Product', label: 'Active Product' },
+ { value: 'Always Free', label: 'Always Free' },
+ { value: 'In Progress', label: 'In Progress' },
+ { value: 'Expired', label: 'Expired' },
+ ];
+ case 'regions':
+ // Get unique regions from all teams
+ const allRegions = new Set();
+ teams.forEach(team => {
+ const teamRegions = getTeamRegions(team);
+ teamRegions.forEach(region => allRegions.add(region));
+ });
+ return [
+ { value: 'No Region', label: 'No Region' },
+ ...Array.from(allRegions).sort().map(region => ({
+ value: region,
+ label: region
+ }))
+ ];
+ case 'products':
+ const allProducts = new Set();
+ teams.forEach(team => {
+ const teamProducts = team.products || [];
+ teamProducts.forEach(product => allProducts.add(product.name));
+ });
+ return [
+ { value: 'No Product', label: 'No Product' },
+ ...Array.from(allProducts).sort().map(productName => ({
+ value: productName,
+ label: productName
+ }))
+ ];
+ default:
+ return [];
+ }
+ };
+
+ // Get filter input component based on column type
+ const getFilterInput = (filter: Filter) => {
+ const column = filterColumns.find(col => col.value === filter.column);
+
+ if (column?.type === 'select') {
+ const options = getFilterOptions(filter.column);
+ return (
+ updateFilter(filter.id, { value })}
+ >
+
+
+
+
+ {options.map((option) => (
+
+ {option.label}
+
+ ))}
+
+
+ );
+ }
+
+ if (column?.type === 'number') {
+ return (
+ updateFilter(filter.id, { value: e.target.value })}
+ className="flex-1"
+ />
+ );
+ }
+
+ // Default to text input
+ return (
+ updateFilter(filter.id, { value: e.target.value })}
+ className="flex-1"
+ />
+ );
+ };
+
+ // Apply filters to teams
+ const applyFilters = useCallback((teams: Team[]) => {
+ if (filters.length === 0) return teams;
+
+ return teams.filter(team => {
+ return filters.every(filter => {
+ let teamValue: string | number;
+
+ switch (filter.column) {
+ case 'admin_email':
+ teamValue = team.admin_email.toLowerCase();
+ break;
+ case 'name':
+ teamValue = team.name.toLowerCase();
+ break;
+ case 'products':
+ const teamProducts = team.products || [];
+ teamValue = teamProducts.length > 0 ? teamProducts.map(p => p.name).join(', ') : 'No Product';
+ break;
+ case 'trial_status':
+ const trialStatus = getTrialTimeRemaining(team);
+ if (trialStatus.includes('days left')) {
+ teamValue = 'In Progress';
+ } else {
+ teamValue = trialStatus;
+ }
+ break;
+ case 'regions':
+ const teamRegions = getTeamRegions(team);
+ teamValue = teamRegions.length > 0 ? teamRegions.join(', ') : 'No Region';
+ break;
+ default:
+ return true;
+ }
+
+ const filterValue = filter.value.toLowerCase();
+
+ if (typeof teamValue === 'number') {
+ // Handle numeric comparisons
+ const numValue = parseFloat(filterValue);
+ if (isNaN(numValue)) return true;
+ return teamValue === numValue;
+ }
+
+ // Handle string comparisons
+ switch (filter.operator) {
+ case 'contains':
+ return teamValue.toLowerCase().includes(filterValue);
+ case 'equals':
+ return teamValue.toLowerCase() === filterValue;
+ case 'starts_with':
+ return teamValue.toLowerCase().startsWith(filterValue);
+ case 'ends_with':
+ return teamValue.toLowerCase().endsWith(filterValue);
+ default:
+ return true;
+ }
+ });
+ });
+ }, [filters, getTrialTimeRemaining, getTeamRegions]);
+
+ // Filtered and sorted teams
+ const filteredAndSortedTeams = useMemo(() => {
+ const filtered = applyFilters(teams);
+
+ if (sortField) {
+ filtered.sort((a, b) => {
+ let aValue: string | number;
+ let bValue: string | number;
+
+ switch (sortField) {
+ case 'admin_email':
+ aValue = a.admin_email.toLowerCase();
+ bValue = b.admin_email.toLowerCase();
+ break;
+ case 'name':
+ aValue = a.name.toLowerCase();
+ bValue = b.name.toLowerCase();
+ break;
+ case 'created_at':
+ aValue = new Date(a.created_at).getTime();
+ bValue = new Date(b.created_at).getTime();
+ break;
+ case 'last_payment':
+ aValue = a.last_payment ? new Date(a.last_payment).getTime() : 0;
+ bValue = b.last_payment ? new Date(b.last_payment).getTime() : 0;
+ break;
+ case 'products':
+ aValue = (a.products || []).length;
+ bValue = (b.products || []).length;
+ break;
+ case 'trial_status':
+ aValue = getTrialTimeRemaining(a);
+ bValue = getTrialTimeRemaining(b);
+ break;
+ case 'regions':
+ aValue = getTeamRegions(a).length;
+ bValue = getTeamRegions(b).length;
+ break;
+ case 'total_spend':
+ aValue = getTeamTotalSpend(a);
+ bValue = getTeamTotalSpend(b);
+ break;
+ default:
+ return 0;
+ }
+
+ if (sortDirection === 'asc') {
+ return aValue < bValue ? -1 : aValue > bValue ? 1 : 0;
+ } else {
+ return aValue > bValue ? -1 : aValue < bValue ? 1 : 0;
+ }
+ });
+ }
+
+ return filtered;
+ }, [teams, sortField, sortDirection, getTrialTimeRemaining, getTeamRegions, getTeamTotalSpend, applyFilters]);
+
+ const hasActiveFilters = filters.length > 0;
+
+ // Manual pagination to ensure it updates with local state changes
+ const [currentPage, setCurrentPage] = useState(1);
+ const [pageSize, setPageSize] = useState(10);
+
+ const totalItems = filteredAndSortedTeams.length;
+ const totalPages = Math.ceil(totalItems / pageSize);
+
+ const startIndex = (currentPage - 1) * pageSize;
+ const endIndex = startIndex + pageSize;
+ const paginatedData = filteredAndSortedTeams.slice(startIndex, endIndex);
+
+ const goToPage = (page: number) => setCurrentPage(page);
+ const changePageSize = (newPageSize: number) => {
+ setPageSize(newPageSize);
+ setCurrentPage(1); // Reset to first page when changing page size
+ };
+
+ // Handle sorting
+ const handleSort = (field: SortField) => {
+ if (sortField === field) {
+ setSortDirection(sortDirection === 'asc' ? 'desc' : 'asc');
+ } else {
+ setSortField(field);
+ setSortDirection('asc');
+ }
+ };
+
+ // Get sort icon
+ const getSortIcon = (field: SortField) => {
+ if (sortField !== field) {
+ return ;
+ }
+ return sortDirection === 'asc' ? : ;
+ };
+
+ // Format date
+ const formatDate = (dateString: string | undefined): string => {
+ if (!dateString) return 'Never';
+ return new Date(dateString).toLocaleDateString();
+ };
+
+ // Format currency
+ const formatCurrency = (amount: number): string => {
+ return new Intl.NumberFormat('en-US', {
+ style: 'currency',
+ currency: 'USD',
+ minimumFractionDigits: 4,
+ maximumFractionDigits: 4,
+ }).format(amount);
+ };
+
+ if (isLoadingTeams) {
+ return (
+
+
+
+ Loading teams...
+
+
+ );
+ }
+
+ return (
+
+
+
+
Sales Dashboard
+
+ Monitor team performance, subscriptions, and revenue metrics
+
+
+
+
+
+ {teams.length}
+
+
Total Teams
+
+
+
+
+ {/* Filters Section */}
+
+
+
Filters
+
+ {hasActiveFilters && (
+
+ Clear All
+
+ )}
+
+
+ Add Filter
+
+
+
+
+ {filters.length > 0 && (
+
+ {filters.map((filter) => (
+
+ handleColumnChange(filter.id, value)}
+ >
+
+
+
+
+ {filterColumns.map((column) => (
+
+ {column.label}
+
+ ))}
+
+
+
+ updateFilter(filter.id, { operator: value as Filter['operator'] })}
+ >
+
+
+
+
+ {getOperatorsForColumn(filter.column).map((op) => (
+
+ {op.label}
+
+ ))}
+
+
+
+ {getFilterInput(filter)}
+
+ removeFilter(filter.id)}
+ className="text-muted-foreground hover:text-destructive"
+ >
+
+
+
+ ))}
+
+ )}
+
+ {hasActiveFilters && (
+
+ Showing {filteredAndSortedTeams.length} of {teams.length} teams
+
+ )}
+
+
+
+
+
+
+
+ handleSort('admin_email')}
+ >
+
+
+ Team Email
+ {getSortIcon('admin_email')}
+
+
+ handleSort('name')}
+ >
+
+ Team Name
+ {getSortIcon('name')}
+
+
+ handleSort('created_at')}
+ >
+
+
+ Team Create Date
+ {getSortIcon('created_at')}
+
+
+ handleSort('last_payment')}
+ >
+
+
+ Last Payment
+ {getSortIcon('last_payment')}
+
+
+ handleSort('products')}
+ >
+
+
+ Products
+ {getSortIcon('products')}
+
+
+ handleSort('trial_status')}
+ >
+
+
+ Trial Status
+ {getSortIcon('trial_status')}
+
+
+ handleSort('regions')}
+ >
+
+
+ Regions
+ {getSortIcon('regions')}
+
+
+ handleSort('total_spend')}
+ >
+
+
+ Total Spend
+ {getSortIcon('total_spend')}
+
+
+
+
+
+
+ {paginatedData.length === 0 ? (
+
+
+ No teams found matching your filters.
+
+
+ ) : (
+ paginatedData.map((team) => (
+
+
+ {team.admin_email}
+
+
+ {team.name}
+
+ ID: {team.id}
+
+
+
+ {formatDate(team.created_at)}
+
+
+ {formatDate(team.last_payment)}
+
+
+ {(() => {
+ const teamProducts = team.products || [];
+ return teamProducts.length > 0 ? (
+
+ {teamProducts.map((product) => (
+
+ {product.name}
+
+ ))}
+
+ ) : (
+ No products
+ );
+ })()}
+
+
+ {(() => {
+ const trialStatus = getTrialTimeRemaining(team);
+ let badgeVariant: "default" | "secondary" | "destructive" | "outline" = "outline";
+ let customStyle = {};
+
+ if (trialStatus === 'Active Product') {
+ badgeVariant = "default";
+ customStyle = { backgroundColor: '#166534', color: 'white' }; // dark green
+ } else if (trialStatus === 'Always Free') {
+ badgeVariant = "secondary";
+ } else if (trialStatus === 'Expired') {
+ badgeVariant = "destructive";
+ customStyle = { backgroundColor: '#991b1b', color: 'white' }; // dark red
+ } else if (trialStatus === 'In Progress') {
+ badgeVariant = "outline";
+ customStyle = { backgroundColor: '#fef3c7', color: '#92400e' }; // amber/yellow
+ } else if (trialStatus.includes('days left')) {
+ // Extract the number of days
+ const daysMatch = trialStatus.match(/(\d+)/);
+ if (daysMatch) {
+ const days = parseInt(daysMatch[1], 10);
+ // Calculate color gradient: 30 days = green, 0 days = red
+ // Use a 30-day scale instead of 7-day for better gradient
+ const ratio = Math.max(0, Math.min(1, days / 30));
+
+ // Green to red gradient: green (34, 197, 94) to red (239, 68, 68)
+ const red = Math.round(34 + (205 * (1 - ratio))); // 34-239 range
+ const green = Math.round(197 - (129 * (1 - ratio))); // 197-68 range
+ const blue = Math.round(94 - (26 * (1 - ratio))); // 94-68 range
+
+ customStyle = {
+ backgroundColor: `rgb(${red}, ${green}, ${blue})`,
+ color: 'white',
+ fontWeight: 'bold'
+ };
+ }
+ }
+
+ return (
+
+ {trialStatus}
+
+ );
+ })()}
+
+
+ {(() => {
+ const regions = getTeamRegions(team);
+ if (regions.length === 0) {
+ return No regions ;
+ }
+ return (
+
+ {regions.map((region) => (
+
+ {region}
+
+ ))}
+
+ );
+ })()}
+
+
+ {(() => {
+ const totalSpend = getTeamTotalSpend(team);
+ const spendColor = getSpendColor(totalSpend);
+ return (
+
+ {formatCurrency(totalSpend)}
+
+ );
+ })()}
+
+
+
+ ))
+ )}
+
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/app/admin/users/page.tsx b/frontend/src/app/admin/users/page.tsx
index e40422c..237e68c 100644
--- a/frontend/src/app/admin/users/page.tsx
+++ b/frontend/src/app/admin/users/page.tsx
@@ -81,6 +81,7 @@ const USER_ROLES = [
{ value: 'admin', label: 'Admin' },
{ value: 'key_creator', label: 'Key Creator' },
{ value: 'read_only', label: 'Read Only' },
+ { value: 'sales', label: 'Sales' },
];
type SortField = 'email' | 'team_name' | 'role' | null;
@@ -406,6 +407,15 @@ export default function UsersPage() {
void fetchUsers();
}, [fetchUsers]);
+ // Update role when switching between system and team user types
+ useEffect(() => {
+ if (isSystemUser) {
+ setNewUserRole('admin'); // Default to admin for system users
+ } else {
+ setNewUserRole('read_only'); // Default to read_only for team users
+ }
+ }, [isSystemUser]);
+
if (isLoadingUsers) {
return (
@@ -452,6 +462,9 @@ export default function UsersPage() {
if (newUserTeamId) {
userData.team_id = newUserTeamId;
}
+ } else {
+ // For system users, also set the role
+ userData.role = newUserRole;
}
createUserMutation.mutate(userData);
@@ -545,7 +558,7 @@ export default function UsersPage() {
- Role
+ Team Role
- {USER_ROLES.map((role) => (
+ {USER_ROLES.filter(role => role.value !== 'sales').map((role) => (
{role.label}
@@ -564,6 +577,23 @@ export default function UsersPage() {
>
)}
+ {isSystemUser && (
+
+ System Role
+
+
+
+
+
+ Admin
+ Sales
+
+
+
+ )}
{
if (user) {
- router.replace('/private-ai-keys');
+ // Redirect sales users to their dashboard
+ if (user.role === 'sales') {
+ router.replace('/sales');
+ } else {
+ router.replace('/private-ai-keys');
+ }
} else {
router.replace('/auth/login');
}
diff --git a/frontend/src/app/private-ai-keys/page.tsx b/frontend/src/app/private-ai-keys/page.tsx
index c8fe9ba..c9f37d5 100644
--- a/frontend/src/app/private-ai-keys/page.tsx
+++ b/frontend/src/app/private-ai-keys/page.tsx
@@ -99,11 +99,20 @@ export default function DashboardPage() {
// Create key mutation
const createKeyMutation = useMutation({
- mutationFn: async ({ region_id, name, key_type }: { region_id: number, name: string, key_type: 'full' | 'llm' | 'vector' }) => {
+ mutationFn: async ({ region_id, name, key_type, owner_id, team_id }: {
+ region_id: number,
+ name: string,
+ key_type: 'full' | 'llm' | 'vector',
+ owner_id?: number,
+ team_id?: number
+ }) => {
const endpoint = key_type === 'full' ? '/private-ai-keys' :
key_type === 'llm' ? '/private-ai-keys/token' :
'/private-ai-keys/vector-db';
- const response = await post(endpoint, { region_id, name });
+ const payload: { region_id: number; name: string; owner_id?: number; team_id?: number } = { region_id, name };
+ if (owner_id) payload.owner_id = owner_id;
+ if (team_id) payload.team_id = team_id;
+ const response = await post(endpoint, payload);
const data = await response.json();
return data;
},
@@ -136,7 +145,9 @@ export default function DashboardPage() {
createKeyMutation.mutate({
region_id: data.region_id,
name: data.name,
- key_type: data.key_type
+ key_type: data.key_type,
+ owner_id: data.owner_id,
+ team_id: data.team_id
});
};
diff --git a/frontend/src/app/sales/page.tsx b/frontend/src/app/sales/page.tsx
new file mode 100644
index 0000000..ea4f8ba
--- /dev/null
+++ b/frontend/src/app/sales/page.tsx
@@ -0,0 +1,843 @@
+'use client';
+
+import { useState, useMemo, useCallback } from 'react';
+import { useQuery } from '@tanstack/react-query';
+import {
+ Table,
+ TableBody,
+ TableCell,
+ TableHead,
+ TableHeader,
+ TableRow,
+ TablePagination,
+} from '@/components/ui/table';
+
+import { Loader2, ChevronUp, ChevronDown, ChevronsUpDown, DollarSign, Calendar, Users, Globe, Package, Plus, X } from 'lucide-react';
+import { get } from '@/utils/api';
+import { Badge } from '@/components/ui/badge';
+import { Button } from '@/components/ui/button';
+import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
+import { Input } from '@/components/ui/input';
+
+interface Team {
+ id: number;
+ name: string;
+ admin_email: string;
+ created_at: string;
+ last_payment?: string;
+ is_always_free: boolean;
+ products: Product[];
+ regions: string[];
+ total_spend: number;
+ trial_status: string;
+}
+
+interface Product {
+ id: string;
+ name: string;
+ active: boolean;
+}
+
+
+
+type SortField = 'admin_email' | 'name' | 'created_at' | 'last_payment' | 'products' | 'trial_status' | 'regions' | 'total_spend' | null;
+type SortDirection = 'asc' | 'desc';
+
+interface Filter {
+ id: string;
+ column: string;
+ value: string;
+ operator: 'contains' | 'equals' | 'starts_with' | 'ends_with';
+}
+
+export default function SalesPage() {
+
+ // Filter and sort state
+ const [filters, setFilters] = useState([]);
+ const [sortField, setSortField] = useState(null);
+ const [sortDirection, setSortDirection] = useState('asc');
+
+
+ // Queries
+ const { data: teams = [], isLoading: isLoadingTeams } = useQuery({
+ queryKey: ['sales-teams'],
+ queryFn: async () => {
+ const response = await get('/teams/sales/list-teams');
+ const data = await response.json();
+ return data.teams;
+ },
+ });
+
+
+
+
+ // Calculate trial time remaining
+ const getTrialTimeRemaining = useCallback((team: Team): string => {
+ return team.trial_status;
+ }, []);
+
+ // Get unique regions for a team
+ const getTeamRegions = useCallback((team: Team): string[] => {
+ return team.regions;
+ }, []);
+
+ // Get total spend for a team
+ const getTeamTotalSpend = useCallback((team: Team): number => {
+ return team.total_spend;
+ }, []);
+
+ // Get all team spend values for calculating min/max
+ const allTeamSpends = useMemo(() => {
+ return teams.map(team => ({
+ teamId: team.id,
+ spend: getTeamTotalSpend(team)
+ })).filter(item => item.spend > 0); // Only non-zero values for min calculation
+ }, [teams, getTeamTotalSpend]);
+
+ // Calculate min and max spend values
+ const spendStats = useMemo(() => {
+ if (allTeamSpends.length === 0) {
+ return { minSpend: 0, maxSpend: 0 };
+ }
+
+ const spends = allTeamSpends.map(item => item.spend);
+ return {
+ minSpend: Math.min(...spends),
+ maxSpend: Math.max(...spends)
+ };
+ }, [allTeamSpends]);
+
+
+ // Get color for spend value based on gradient rules
+ const getSpendColor = useCallback((spend: number): string => {
+ if (spend === 0) {
+ return '#6b7280'; // Dark grey for $0.00
+ }
+
+ if (spend === spendStats.maxSpend) {
+ return '#166534'; // Dark green for highest value
+ }
+
+ if (spend === spendStats.minSpend) {
+ return '#000000'; // Black for lowest non-zero value
+ }
+
+ // Gradient for values between min and max
+ if (spendStats.maxSpend === spendStats.minSpend) {
+ return '#166534'; // If all values are the same, use dark green
+ }
+
+ const ratio = (spend - spendStats.minSpend) / (spendStats.maxSpend - spendStats.minSpend);
+ const red = Math.round(22 + (ratio * 0)); // Start from dark green (22, 101, 52)
+ const green = Math.round(101 + (ratio * 53)); // End at darker green (22, 154, 52)
+ const blue = Math.round(52 + (ratio * 0));
+
+ return `rgb(${red}, ${green}, ${blue})`;
+ }, [spendStats]);
+
+ // Add a new filter
+ const addFilter = () => {
+ const newFilter: Filter = {
+ id: Date.now().toString(),
+ column: 'admin_email',
+ value: '',
+ operator: 'contains'
+ };
+ setFilters([...filters, newFilter]);
+ };
+
+ // Update filter when column changes to set appropriate default value
+ const handleColumnChange = (filterId: string, column: string) => {
+ const columnInfo = filterColumns.find(col => col.value === column);
+ let defaultValue = '';
+ let defaultOperator: Filter['operator'] = 'contains';
+
+ if (columnInfo?.type === 'select') {
+ const options = getFilterOptions(column);
+ if (options.length > 0) {
+ defaultValue = options[0].value;
+ }
+ // Regions use 'contains' since teams can have multiple regions
+ if (column === 'regions') {
+ defaultOperator = 'contains';
+ } else {
+ defaultOperator = 'equals';
+ }
+ } else if (columnInfo?.type === 'number') {
+ defaultOperator = 'equals';
+ }
+
+ updateFilter(filterId, { column, value: defaultValue, operator: defaultOperator });
+ };
+
+ // Remove a filter
+ const removeFilter = (filterId: string) => {
+ setFilters(filters.filter(f => f.id !== filterId));
+ };
+
+ // Update a filter
+ const updateFilter = (filterId: string, updates: Partial) => {
+ setFilters(filters.map(f =>
+ f.id === filterId ? { ...f, ...updates } : f
+ ));
+ };
+
+ // Clear all filters
+ const clearAllFilters = () => {
+ setFilters([]);
+ setSortField(null);
+ setSortDirection('asc');
+ };
+
+
+ // Available filter columns
+ const filterColumns = [
+ { value: 'admin_email', label: 'Team Email', type: 'text' },
+ { value: 'name', label: 'Team Name', type: 'text' },
+ { value: 'products', label: 'Products', type: 'select' },
+ { value: 'trial_status', label: 'Trial Status', type: 'select' },
+ { value: 'regions', label: 'Regions', type: 'select' },
+ ];
+
+ // Filter operators
+ const filterOperators = [
+ { value: 'contains', label: 'Contains' },
+ { value: 'equals', label: 'Equals' },
+ { value: 'starts_with', label: 'Starts with' },
+ { value: 'ends_with', label: 'Ends with' },
+ ];
+
+ // Get operators for specific column types
+ const getOperatorsForColumn = (column: string) => {
+ const columnInfo = filterColumns.find(col => col.value === column);
+
+ if (columnInfo?.type === 'select') {
+ // Regions can use 'contains' since teams can have multiple regions
+ if (column === 'regions') {
+ return [
+ { value: 'equals', label: 'Equals' },
+ { value: 'contains', label: 'Contains' }
+ ];
+ }
+ return [{ value: 'equals', label: 'Equals' }];
+ }
+
+ if (columnInfo?.type === 'number') {
+ return [
+ { value: 'equals', label: 'Equals' },
+ { value: 'contains', label: 'Contains' }
+ ];
+ }
+
+ // Text fields get all operators
+ return filterOperators;
+ };
+
+ // Get options for select-type filters
+ const getFilterOptions = (column: string) => {
+ switch (column) {
+ case 'trial_status':
+ return [
+ { value: 'Active Product', label: 'Active Product' },
+ { value: 'Always Free', label: 'Always Free' },
+ { value: 'In Progress', label: 'In Progress' },
+ { value: 'Expired', label: 'Expired' },
+ ];
+ case 'regions':
+ // Get unique regions from all teams
+ const allRegions = new Set();
+ teams.forEach(team => {
+ const teamRegions = getTeamRegions(team);
+ teamRegions.forEach(region => allRegions.add(region));
+ });
+ return [
+ { value: 'No Region', label: 'No Region' },
+ ...Array.from(allRegions).sort().map(region => ({
+ value: region,
+ label: region
+ }))
+ ];
+ case 'products':
+ const allProducts = new Set();
+ teams.forEach(team => {
+ const teamProducts = team.products || [];
+ teamProducts.forEach(product => allProducts.add(product.name));
+ });
+ return [
+ { value: 'No Product', label: 'No Product' },
+ ...Array.from(allProducts).sort().map(productName => ({
+ value: productName,
+ label: productName
+ }))
+ ];
+ default:
+ return [];
+ }
+ };
+
+ // Get filter input component based on column type
+ const getFilterInput = (filter: Filter) => {
+ const column = filterColumns.find(col => col.value === filter.column);
+
+ if (column?.type === 'select') {
+ const options = getFilterOptions(filter.column);
+ return (
+ updateFilter(filter.id, { value })}
+ >
+
+
+
+
+ {options.map((option) => (
+
+ {option.label}
+
+ ))}
+
+
+ );
+ }
+
+ if (column?.type === 'number') {
+ return (
+ updateFilter(filter.id, { value: e.target.value })}
+ className="flex-1"
+ />
+ );
+ }
+
+ // Default to text input
+ return (
+ updateFilter(filter.id, { value: e.target.value })}
+ className="flex-1"
+ />
+ );
+ };
+
+ // Apply filters to teams
+ const applyFilters = useCallback((teams: Team[]) => {
+ if (filters.length === 0) return teams;
+
+ return teams.filter(team => {
+ return filters.every(filter => {
+ let teamValue: string | number;
+
+ switch (filter.column) {
+ case 'admin_email':
+ teamValue = team.admin_email.toLowerCase();
+ break;
+ case 'name':
+ teamValue = team.name.toLowerCase();
+ break;
+ case 'products':
+ const teamProducts = team.products || [];
+ teamValue = teamProducts.length > 0 ? teamProducts.map(p => p.name).join(', ') : 'No Product';
+ break;
+ case 'trial_status':
+ const trialStatus = getTrialTimeRemaining(team);
+ if (trialStatus.includes('days left')) {
+ teamValue = 'In Progress';
+ } else {
+ teamValue = trialStatus;
+ }
+ break;
+ case 'regions':
+ const teamRegions = getTeamRegions(team);
+ teamValue = teamRegions.length > 0 ? teamRegions.join(', ') : 'No Region';
+ break;
+ default:
+ return true;
+ }
+
+ const filterValue = filter.value.toLowerCase();
+
+ if (typeof teamValue === 'number') {
+ // Handle numeric comparisons
+ const numValue = parseFloat(filterValue);
+ if (isNaN(numValue)) return true;
+ return teamValue === numValue;
+ }
+
+ // Handle string comparisons
+ switch (filter.operator) {
+ case 'contains':
+ return teamValue.toLowerCase().includes(filterValue);
+ case 'equals':
+ return teamValue.toLowerCase() === filterValue;
+ case 'starts_with':
+ return teamValue.toLowerCase().startsWith(filterValue);
+ case 'ends_with':
+ return teamValue.toLowerCase().endsWith(filterValue);
+ default:
+ return true;
+ }
+ });
+ });
+ }, [filters, getTrialTimeRemaining, getTeamRegions]);
+
+ // Filtered and sorted teams
+ const filteredAndSortedTeams = useMemo(() => {
+ const filtered = applyFilters(teams);
+
+ if (sortField) {
+ filtered.sort((a, b) => {
+ let aValue: string | number;
+ let bValue: string | number;
+
+ switch (sortField) {
+ case 'admin_email':
+ aValue = a.admin_email.toLowerCase();
+ bValue = b.admin_email.toLowerCase();
+ break;
+ case 'name':
+ aValue = a.name.toLowerCase();
+ bValue = b.name.toLowerCase();
+ break;
+ case 'created_at':
+ aValue = new Date(a.created_at).getTime();
+ bValue = new Date(b.created_at).getTime();
+ break;
+ case 'last_payment':
+ aValue = a.last_payment ? new Date(a.last_payment).getTime() : 0;
+ bValue = b.last_payment ? new Date(b.last_payment).getTime() : 0;
+ break;
+ case 'products':
+ aValue = (a.products || []).length;
+ bValue = (b.products || []).length;
+ break;
+ case 'trial_status':
+ aValue = getTrialTimeRemaining(a);
+ bValue = getTrialTimeRemaining(b);
+ break;
+ case 'regions':
+ aValue = getTeamRegions(a).length;
+ bValue = getTeamRegions(b).length;
+ break;
+ case 'total_spend':
+ aValue = getTeamTotalSpend(a);
+ bValue = getTeamTotalSpend(b);
+ break;
+ default:
+ return 0;
+ }
+
+ if (sortDirection === 'asc') {
+ return aValue < bValue ? -1 : aValue > bValue ? 1 : 0;
+ } else {
+ return aValue > bValue ? -1 : aValue < bValue ? 1 : 0;
+ }
+ });
+ }
+
+ return filtered;
+ }, [teams, sortField, sortDirection, getTrialTimeRemaining, getTeamRegions, getTeamTotalSpend, applyFilters]);
+
+ const hasActiveFilters = filters.length > 0;
+
+ // Manual pagination to ensure it updates with local state changes
+ const [currentPage, setCurrentPage] = useState(1);
+ const [pageSize, setPageSize] = useState(10);
+
+ const totalItems = filteredAndSortedTeams.length;
+ const totalPages = Math.ceil(totalItems / pageSize);
+
+ const startIndex = (currentPage - 1) * pageSize;
+ const endIndex = startIndex + pageSize;
+ const paginatedData = filteredAndSortedTeams.slice(startIndex, endIndex);
+
+ const goToPage = (page: number) => setCurrentPage(page);
+ const changePageSize = (newPageSize: number) => {
+ setPageSize(newPageSize);
+ setCurrentPage(1); // Reset to first page when changing page size
+ };
+
+ // Handle sorting
+ const handleSort = (field: SortField) => {
+ if (sortField === field) {
+ setSortDirection(sortDirection === 'asc' ? 'desc' : 'asc');
+ } else {
+ setSortField(field);
+ setSortDirection('asc');
+ }
+ };
+
+ // Get sort icon
+ const getSortIcon = (field: SortField) => {
+ if (sortField !== field) {
+ return ;
+ }
+ return sortDirection === 'asc' ? : ;
+ };
+
+ // Format date
+ const formatDate = (dateString: string | undefined): string => {
+ if (!dateString) return 'Never';
+ return new Date(dateString).toLocaleDateString();
+ };
+
+ // Format currency
+ const formatCurrency = (amount: number): string => {
+ return new Intl.NumberFormat('en-US', {
+ style: 'currency',
+ currency: 'USD',
+ minimumFractionDigits: 4,
+ maximumFractionDigits: 4,
+ }).format(amount);
+ };
+
+ if (isLoadingTeams) {
+ return (
+
+
+
+ Loading teams...
+
+
+ );
+ }
+
+ return (
+
+
+
+
Sales Dashboard
+
+ Monitor team performance, subscriptions, and revenue metrics
+
+
+
+
+
+ {teams.length}
+
+
Total Teams
+
+
+
+
+ {/* Filters Section */}
+
+
+
Filters
+
+ {hasActiveFilters && (
+
+ Clear All
+
+ )}
+
+
+ Add Filter
+
+
+
+
+ {filters.length > 0 && (
+
+ {filters.map((filter) => (
+
+ handleColumnChange(filter.id, value)}
+ >
+
+
+
+
+ {filterColumns.map((column) => (
+
+ {column.label}
+
+ ))}
+
+
+
+ updateFilter(filter.id, { operator: value as Filter['operator'] })}
+ >
+
+
+
+
+ {getOperatorsForColumn(filter.column).map((op) => (
+
+ {op.label}
+
+ ))}
+
+
+
+ {getFilterInput(filter)}
+
+ removeFilter(filter.id)}
+ className="text-muted-foreground hover:text-destructive"
+ >
+
+
+
+ ))}
+
+ )}
+
+ {hasActiveFilters && (
+
+ Showing {filteredAndSortedTeams.length} of {teams.length} teams
+
+ )}
+
+
+
+
+
+
+
+ handleSort('admin_email')}
+ >
+
+
+ Team Email
+ {getSortIcon('admin_email')}
+
+
+ handleSort('name')}
+ >
+
+ Team Name
+ {getSortIcon('name')}
+
+
+ handleSort('created_at')}
+ >
+
+
+ Team Create Date
+ {getSortIcon('created_at')}
+
+
+ handleSort('last_payment')}
+ >
+
+
+ Last Payment
+ {getSortIcon('last_payment')}
+
+
+ handleSort('products')}
+ >
+
+
+ Products
+ {getSortIcon('products')}
+
+
+ handleSort('trial_status')}
+ >
+
+
+ Trial Status
+ {getSortIcon('trial_status')}
+
+
+ handleSort('regions')}
+ >
+
+
+ Regions
+ {getSortIcon('regions')}
+
+
+ handleSort('total_spend')}
+ >
+
+
+ Total Spend
+ {getSortIcon('total_spend')}
+
+
+
+
+
+
+ {paginatedData.length === 0 ? (
+
+
+ No teams found matching your filters.
+
+
+ ) : (
+ paginatedData.map((team) => (
+
+
+ {team.admin_email}
+
+
+ {team.name}
+
+ ID: {team.id}
+
+
+
+ {formatDate(team.created_at)}
+
+
+ {formatDate(team.last_payment)}
+
+
+ {(() => {
+ const teamProducts = team.products || [];
+ return teamProducts.length > 0 ? (
+
+ {teamProducts.map((product) => (
+
+ {product.name}
+
+ ))}
+
+ ) : (
+ No products
+ );
+ })()}
+
+
+ {(() => {
+ const trialStatus = getTrialTimeRemaining(team);
+ let badgeVariant: "default" | "secondary" | "destructive" | "outline" = "outline";
+ let customStyle = {};
+
+ if (trialStatus === 'Active Product') {
+ badgeVariant = "default";
+ customStyle = { backgroundColor: '#166534', color: 'white' }; // dark green
+ } else if (trialStatus === 'Always Free') {
+ badgeVariant = "secondary";
+ } else if (trialStatus === 'Expired') {
+ badgeVariant = "destructive";
+ customStyle = { backgroundColor: '#991b1b', color: 'white' }; // dark red
+ } else if (trialStatus === 'In Progress') {
+ badgeVariant = "outline";
+ customStyle = { backgroundColor: '#fef3c7', color: '#92400e' }; // amber/yellow
+ } else if (trialStatus.includes('days left')) {
+ // Extract the number of days
+ const daysMatch = trialStatus.match(/(\d+)/);
+ if (daysMatch) {
+ const days = parseInt(daysMatch[1], 10);
+ // Calculate color gradient: 30 days = green, 0 days = red
+ // Use a 30-day scale instead of 7-day for better gradient
+ const ratio = Math.max(0, Math.min(1, days / 30));
+
+ // Green to red gradient: green (34, 197, 94) to red (239, 68, 68)
+ const red = Math.round(34 + (205 * (1 - ratio))); // 34-239 range
+ const green = Math.round(197 - (129 * (1 - ratio))); // 197-68 range
+ const blue = Math.round(94 - (26 * (1 - ratio))); // 94-68 range
+
+ customStyle = {
+ backgroundColor: `rgb(${red}, ${green}, ${blue})`,
+ color: 'white',
+ fontWeight: 'bold'
+ };
+ }
+ }
+
+ return (
+
+ {trialStatus}
+
+ );
+ })()}
+
+
+ {(() => {
+ const regions = getTeamRegions(team);
+ if (regions.length === 0) {
+ return No regions ;
+ }
+ return (
+
+ {regions.map((region) => (
+
+ {region}
+
+ ))}
+
+ );
+ })()}
+
+
+ {(() => {
+ const totalSpend = getTeamTotalSpend(team);
+ const spendColor = getSpendColor(totalSpend);
+ return (
+
+ {formatCurrency(totalSpend)}
+
+ );
+ })()}
+
+
+
+ ))
+ )}
+
+
+
+
+
+
+
+ );
+}
diff --git a/frontend/src/components/auth/login-form.tsx b/frontend/src/components/auth/login-form.tsx
index 937ed05..3b1d77a 100644
--- a/frontend/src/components/auth/login-form.tsx
+++ b/frontend/src/components/auth/login-form.tsx
@@ -80,7 +80,12 @@ export function LoginForm() {
});
router.refresh();
- router.push('/private-ai-keys');
+ // Redirect based on user role
+ if (profileData.role === 'sales') {
+ router.push('/sales');
+ } else {
+ router.push('/private-ai-keys');
+ }
} catch (profileError) {
console.error('Failed to fetch user profile:', profileError);
setError('Successfully logged in but failed to fetch user profile. Please refresh the page.');
diff --git a/frontend/src/components/auth/passwordless-login-form.tsx b/frontend/src/components/auth/passwordless-login-form.tsx
index 38d1e7a..5b5c03b 100644
--- a/frontend/src/components/auth/passwordless-login-form.tsx
+++ b/frontend/src/components/auth/passwordless-login-form.tsx
@@ -173,7 +173,12 @@ export function PasswordlessLoginForm({ onSwitchToPassword }: PasswordlessLoginF
});
router.refresh();
- router.push('/private-ai-keys');
+ // Redirect based on user role
+ if (profileData.role === 'sales') {
+ router.push('/sales');
+ } else {
+ router.push('/private-ai-keys');
+ }
} catch (profileError) {
console.error('Failed to fetch user profile:', profileError);
setError('Successfully signed in but failed to fetch user profile. Please refresh the page.');
diff --git a/frontend/src/components/sidebar-layout.tsx b/frontend/src/components/sidebar-layout.tsx
index 7f6df87..220f2b0 100644
--- a/frontend/src/components/sidebar-layout.tsx
+++ b/frontend/src/components/sidebar-layout.tsx
@@ -11,7 +11,8 @@ import {
PanelLeftClose,
PanelLeft,
Users2,
- Package
+ Package,
+ DollarSign
} from 'lucide-react';
import { Sidebar, SidebarProvider } from '@/components/ui/sidebar';
import { NavUser } from '@/components/nav-user';
@@ -49,6 +50,7 @@ const navigation = [
{ name: 'Products', href: '/admin/products', icon: },
{ name: 'Private AI Keys', href: '/admin/private-ai-keys', icon: },
{ name: 'Audit Logs', href: '/admin/audit-logs', icon: },
+ { name: 'Sales Dashboard', href: '/admin/sales-dashboard', icon: },
],
},
{
@@ -63,6 +65,10 @@ const navigation = [
},
];
+const salesNavigation = [
+ { name: 'Sales Dashboard', href: '/sales', icon: },
+];
+
function NavMain({ navigation, pathname, collapsed }: { navigation: NavItem[]; pathname: string; collapsed: boolean }) {
const [expandedItems, setExpandedItems] = useState(['/admin', '/team-admin']);
@@ -190,12 +196,14 @@ export function SidebarLayout({
return <>{children}>;
}
- // Filter out admin navigation for non-admin users
- const filteredNavigation = navigation.filter((item) => {
- if (item.name === 'Admin' && !user?.is_admin) return false;
- if (item.name === 'Team Admin' && !isTeamAdmin(user)) return false;
- return true;
- });
+ // Use sales navigation for sales users, otherwise filter regular navigation
+ const filteredNavigation = user?.role === 'sales'
+ ? salesNavigation
+ : navigation.filter((item) => {
+ if (item.name === 'Admin' && !user?.is_admin) return false;
+ if (item.name === 'Team Admin' && !isTeamAdmin(user)) return false;
+ return true;
+ });
return (
@@ -204,7 +212,7 @@ export function SidebarLayout({
{!collapsed && (
-
+
-
+
{/* Main content */}
-
+
{children}
diff --git a/scripts/add_test_data.py b/scripts/add_test_data.py
new file mode 100644
index 0000000..b4041e7
--- /dev/null
+++ b/scripts/add_test_data.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+from datetime import datetime, timedelta, UTC
+import random
+
+# Add the parent directory to the Python path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from sqlalchemy.orm import sessionmaker, Session
+from app.db.database import engine
+from app.db.models import DBTeam, DBUser, DBProduct, DBTeamProduct
+from app.core.security import get_password_hash
+
+def create_test_data():
+ """Create test data for teams, users, and products"""
+
+ # Create database session
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+ db = SessionLocal()
+
+ try:
+ print("Creating test data...")
+
+ # Check for existing test data on a case-by-case basis
+ existing_teams = {}
+ for team_name in [
+ "Test Team 1",
+ "Test Team 2 - Always Free",
+ "Test Team 3 - With Product",
+ "Test Team 4 - With Payment History",
+ "Test Team 5 - No Products"
+ ]:
+ existing_team = db.query(DBTeam).filter(DBTeam.name == team_name).first()
+ if existing_team:
+ existing_teams[team_name] = existing_team
+ print(f"⚠️ {team_name} already exists (ID: {existing_team.id})")
+ else:
+ print(f"✅ {team_name} will be created")
+
+ # Check if products exist, create one if none exist
+ existing_products = db.query(DBProduct).all()
+ if not existing_products:
+ print("No products found, creating a sample product...")
+ sample_product = DBProduct(
+ id="prod_test_sample",
+ name="Test Sample Product",
+ user_count=10,
+ keys_per_user=5,
+ total_key_count=50,
+ service_key_count=10,
+ max_budget_per_key=100.0,
+ rpm_per_key=1000,
+ vector_db_count=5,
+ vector_db_storage=500,
+ renewal_period_days=30,
+ active=True,
+ created_at=datetime.now(UTC)
+ )
+ db.add(sample_product)
+ db.commit()
+ db.refresh(sample_product)
+ print(f"Created sample product: {sample_product.name} (ID: {sample_product.id})")
+ existing_products = [sample_product]
+
+ # Pick a random product for teams that need products
+ selected_product = random.choice(existing_products)
+ print(f"Selected product for teams: {selected_product.name} (ID: {selected_product.id})")
+
+ # 1. Team with one user, created 32 days ago
+ if "Test Team 1" not in existing_teams:
+ print("\n1. Creating team with one user (created 32 days ago)...")
+ team1 = DBTeam(
+ name="Test Team 1",
+ admin_email="admin1@test.com",
+ is_active=True,
+ is_always_free=False,
+ created_at=datetime.now(UTC) - timedelta(days=32)
+ )
+ db.add(team1)
+ db.commit()
+ db.refresh(team1)
+
+ user1 = DBUser(
+ email="user1@test.com",
+ hashed_password=get_password_hash("testpassword123"),
+ is_active=True,
+ is_admin=False,
+ role="user",
+ team_id=team1.id,
+ created_at=datetime.now(UTC) - timedelta(days=32)
+ )
+ db.add(user1)
+ db.commit()
+ print(f" Created team: {team1.name} (ID: {team1.id})")
+ print(f" Created user: {user1.email} (ID: {user1.id})")
+ else:
+ team1 = existing_teams["Test Team 1"]
+ print(f"\n1. Team 1 already exists: {team1.name} (ID: {team1.id})")
+
+ # 2. Team with one user, always_free=True, created 20 days ago
+ if "Test Team 2 - Always Free" not in existing_teams:
+ print("\n2. Creating team with one user, always_free=True (created 20 days ago)...")
+ team2 = DBTeam(
+ name="Test Team 2 - Always Free",
+ admin_email="admin2@test.com",
+ is_active=True,
+ is_always_free=True,
+ created_at=datetime.now(UTC) - timedelta(days=20)
+ )
+ db.add(team2)
+ db.commit()
+ db.refresh(team2)
+
+ user2 = DBUser(
+ email="user2@test.com",
+ hashed_password=get_password_hash("testpassword123"),
+ is_active=True,
+ is_admin=False,
+ role="user",
+ team_id=team2.id,
+ created_at=datetime.now(UTC) - timedelta(days=20)
+ )
+ db.add(user2)
+ db.commit()
+ print(f" Created team: {team2.name} (ID: {team2.id}) - always_free: {team2.is_always_free}")
+ print(f" Created user: {user2.email} (ID: {user2.id})")
+ else:
+ team2 = existing_teams["Test Team 2 - Always Free"]
+ print(f"\n2. Team 2 already exists: {team2.name} (ID: {team2.id})")
+
+ # 3. Team with one user and product association
+ if "Test Team 3 - With Product" not in existing_teams:
+ print("\n3. Creating team with one user and product association...")
+ team3 = DBTeam(
+ name="Test Team 3 - With Product",
+ admin_email="admin3@test.com",
+ is_active=True,
+ is_always_free=False,
+ created_at=datetime.now(UTC)
+ )
+ db.add(team3)
+ db.commit()
+ db.refresh(team3)
+
+ user3 = DBUser(
+ email="user3@test.com",
+ hashed_password=get_password_hash("testpassword123"),
+ is_active=True,
+ is_admin=False,
+ role="user",
+ team_id=team3.id,
+ created_at=datetime.now(UTC)
+ )
+ db.add(user3)
+ db.commit()
+
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=team3.id,
+ product_id=selected_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ print(f" Created team: {team3.name} (ID: {team3.id})")
+ print(f" Created user: {user3.email} (ID: {user3.id})")
+ print(f" Associated with product: {selected_product.name} (ID: {selected_product.id})")
+ else:
+ team3 = existing_teams["Test Team 3 - With Product"]
+ print(f"\n3. Team 3 already exists: {team3.name} (ID: {team3.id})")
+
+ # 4. Team with one user, created 40 days ago, with payment 35 days ago, and product association
+ if "Test Team 4 - With Payment History" not in existing_teams:
+ print("\n4. Creating team with one user, payment history, and product association...")
+ team4 = DBTeam(
+ name="Test Team 4 - With Payment History",
+ admin_email="admin4@test.com",
+ is_active=True,
+ is_always_free=False,
+ created_at=datetime.now(UTC) - timedelta(days=40),
+ last_payment=datetime.now(UTC) - timedelta(days=35)
+ )
+ db.add(team4)
+ db.commit()
+ db.refresh(team4)
+
+ user4 = DBUser(
+ email="user4@test.com",
+ hashed_password=get_password_hash("testpassword123"),
+ is_active=True,
+ is_admin=False,
+ role="user",
+ team_id=team4.id,
+ created_at=datetime.now(UTC) - timedelta(days=40)
+ )
+ db.add(user4)
+ db.commit()
+
+ # Create team-product association
+ team_product4 = DBTeamProduct(
+ team_id=team4.id,
+ product_id=selected_product.id
+ )
+ db.add(team_product4)
+ db.commit()
+
+ print(f" Created team: {team4.name} (ID: {team4.id})")
+ print(f" Created user: {user4.email} (ID: {user4.id})")
+ print(f" Payment made: {team4.last_payment.strftime('%Y-%m-%d')}")
+ print(f" Associated with product: {selected_product.name} (ID: {selected_product.id})")
+ else:
+ team4 = existing_teams["Test Team 4 - With Payment History"]
+ print(f"\n4. Team 4 already exists: {team4.name} (ID: {team4.id})")
+
+ # 5. Team with one user, no products, created 20 days ago
+ if "Test Team 5 - No Products" not in existing_teams:
+ print("\n5. Creating team with one user, no products (created 20 days ago)...")
+ team5 = DBTeam(
+ name="Test Team 5 - No Products",
+ admin_email="admin5@test.com",
+ is_active=True,
+ is_always_free=False,
+ created_at=datetime.now(UTC) - timedelta(days=20)
+ )
+ db.add(team5)
+ db.commit()
+ db.refresh(team5)
+
+ user5 = DBUser(
+ email="user5@test.com",
+ hashed_password=get_password_hash("testpassword123"),
+ is_active=True,
+ is_admin=False,
+ role="user",
+ team_id=team5.id,
+ created_at=datetime.now(UTC) - timedelta(days=20)
+ )
+ db.add(user5)
+ db.commit()
+
+ print(f" Created team: {team5.name} (ID: {team5.id})")
+ print(f" Created user: {user5.email} (ID: {user5.id})")
+ print(f" No products associated")
+ else:
+ team5 = existing_teams["Test Team 5 - No Products"]
+ print(f"\n5. Team 5 already exists: {team5.name} (ID: {team5.id})")
+
+ print("\n✅ Test data created successfully!")
+ print(f"\nSummary:")
+ print(f"- Team 1: {team1.name} (created 32 days ago)")
+ print(f"- Team 2: {team2.name} (always_free=True, created 20 days ago)")
+ print(f"- Team 3: {team3.name} (with product association)")
+ print(f"- Team 4: {team4.name} (payment history, product association, created 40 days ago)")
+ print(f"- Team 5: {team5.name} (no products, created 20 days ago)")
+ print(f"- Total users created: 5")
+ print(f"- Product used: {selected_product.name}")
+
+ except Exception as e:
+ print(f"❌ Error creating test data: {str(e)}")
+ db.rollback()
+ raise
+ finally:
+ db.close()
+
+def main():
+ """Main function to run the script"""
+ try:
+ create_test_data()
+ except Exception as e:
+ print(f"Script failed: {str(e)}")
+ sys.exit(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/conftest.py b/tests/conftest.py
index 9d8916a..40cc50f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -143,7 +143,7 @@ def test_team_user(db, test_team):
hashed_password=get_password_hash("password123"),
is_active=True,
is_admin=False,
- role="user",
+ role="key_creator",
team_id=test_team.id,
created_at=datetime.now(UTC)
)
diff --git a/tests/test_rbac.py b/tests/test_rbac.py
new file mode 100644
index 0000000..3287248
--- /dev/null
+++ b/tests/test_rbac.py
@@ -0,0 +1,235 @@
+import pytest
+from fastapi import HTTPException
+from app.core.rbac import RBACDependency, require_system_admin, require_team_admin, require_key_creator_or_higher, require_read_only_or_higher, require_sales_or_higher
+from app.core.roles import UserRole
+from app.db.models import DBUser
+
+class TestRBACDependency:
+ """Test RBAC dependency functionality"""
+
+ def test_system_admin_access(self):
+ """
+ Given a system admin user
+ When checking access to system admin endpoint
+ Then access should be granted and return system_admin role
+ """
+ user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None)
+ dependency = require_system_admin()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.SYSTEM_ADMIN
+
+ def test_team_admin_access(self):
+ """
+ Given a team admin user
+ When checking access to team admin endpoint
+ Then access should be granted and return admin role
+ """
+ user = DBUser(id=1, email="teamadmin@test.com", is_admin=False, team_id=1, role="admin")
+ dependency = require_team_admin()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.TEAM_ADMIN
+
+ def test_key_creator_access(self):
+ """
+ Given a key creator user
+ When checking access to key creator endpoint
+ Then access should be granted and return key_creator role
+ """
+ user = DBUser(id=1, email="keycreator@test.com", is_admin=False, team_id=1, role="key_creator")
+ dependency = require_key_creator_or_higher()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.KEY_CREATOR
+
+ def test_read_only_access(self):
+ """
+ Given a read only user
+ When checking access to read only endpoint
+ Then access should be granted and return read_only role
+ """
+ user = DBUser(id=1, email="readonly@test.com", is_admin=False, team_id=1, role="read_only")
+ dependency = require_read_only_or_higher()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.READ_ONLY
+
+ def test_system_user_cannot_be_team_member(self):
+ """
+ Given a system admin user with team_id set
+ When checking access to any endpoint
+ Then access should be denied due to invalid user type
+ """
+ user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=1, role=None)
+ dependency = require_system_admin()
+
+ with pytest.raises(HTTPException) as exc_info:
+ dependency.check_access(user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+ def test_team_user_must_be_team_member(self):
+ """
+ Given a team user without team_id
+ When checking access to team endpoint
+ Then access should be denied due to invalid user type
+ """
+ user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="admin")
+ dependency = require_team_admin()
+
+ with pytest.raises(HTTPException) as exc_info:
+ dependency.check_access(user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+ def test_insufficient_permissions(self):
+ """
+ Given a read only user
+ When checking access to team admin endpoint
+ Then access should be denied due to insufficient permissions
+ """
+ user = DBUser(id=1, email="readonly@test.com", is_admin=False, team_id=1, role="read_only")
+ dependency = require_team_admin()
+
+ with pytest.raises(HTTPException) as exc_info:
+ dependency.check_access(user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+ def test_team_membership_required(self):
+ """
+ Given a system admin
+ When checking access to team endpoint
+ Then access should be granted because system admins can do anything
+ """
+ user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None)
+ dependency = require_team_admin()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.SYSTEM_ADMIN
+
+ def test_sales_user_access(self):
+ """
+ Given a sales user
+ When checking access to sales endpoint
+ Then access should be granted and return sales role
+ """
+ user = DBUser(id=1, email="sales@test.com", is_admin=False, team_id=None, role="sales")
+ dependency = require_sales_or_higher()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.SALES
+
+ def test_system_admin_access_sales_endpoint(self):
+ """
+ Given a system admin user
+ When checking access to sales endpoint
+ Then access should be granted and return system_admin role
+ """
+ user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None)
+ dependency = require_sales_or_higher()
+
+ result = dependency.check_access(user)
+ assert result == UserRole.SYSTEM_ADMIN
+
+ def test_regular_user_denied_sales_access(self):
+ """
+ Given a regular system user
+ When checking access to sales endpoint
+ Then access should be denied due to insufficient permissions
+ """
+ user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="user")
+ dependency = require_sales_or_higher()
+
+ with pytest.raises(HTTPException) as exc_info:
+ dependency.check_access(user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+ def test_team_user_denied_sales_access(self):
+ """
+ Given a team user
+ When checking access to sales endpoint
+ Then access should be denied due to invalid user type
+ """
+ user = DBUser(id=1, email="teamuser@test.com", is_admin=False, team_id=1, role="admin")
+ dependency = require_sales_or_higher()
+
+ with pytest.raises(HTTPException) as exc_info:
+ dependency.check_access(user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+class TestUserRole:
+ """Test UserRole class functionality"""
+
+ def test_system_roles(self):
+ """
+ Given various roles
+ When checking if they are system roles
+ Then only system roles should return True
+ """
+ assert UserRole.is_system_role("system_admin")
+ assert UserRole.is_system_role("user")
+ assert UserRole.is_system_role("sales")
+ assert not UserRole.is_system_role("admin")
+ assert not UserRole.is_system_role("key_creator")
+ assert not UserRole.is_system_role("read_only")
+
+ def test_team_roles(self):
+ """
+ Given various roles
+ When checking if they are team roles
+ Then only team roles should return True
+ """
+ assert UserRole.is_team_role("admin")
+ assert UserRole.is_team_role("key_creator")
+ assert UserRole.is_team_role("read_only")
+ assert not UserRole.is_team_role("system_admin")
+ assert not UserRole.is_team_role("user")
+ assert not UserRole.is_team_role("sales")
+
+ def test_get_system_roles(self):
+ """
+ Given the UserRole class
+ When getting system roles
+ Then it should return all valid system roles
+ """
+ roles = UserRole.get_system_roles()
+ assert "system_admin" in roles
+ assert "user" in roles
+ assert "sales" in roles
+ assert len(roles) == 3
+
+ def test_get_team_roles(self):
+ """
+ Given the UserRole class
+ When getting team roles
+ Then it should return all valid team roles
+ """
+ roles = UserRole.get_team_roles()
+ assert "admin" in roles
+ assert "key_creator" in roles
+ assert "read_only" in roles
+ assert len(roles) == 3
+
+ def test_get_all_roles(self):
+ """
+ Given the UserRole class
+ When getting all roles
+ Then it should return both system and team roles
+ """
+ roles = UserRole.get_all_roles()
+ assert len(roles) == 6
+ assert "system_admin" in roles
+ assert "user" in roles
+ assert "sales" in roles
+ assert "admin" in roles
+ assert "key_creator" in roles
+ assert "read_only" in roles
diff --git a/tests/test_sales_api.py b/tests/test_sales_api.py
new file mode 100644
index 0000000..ed29291
--- /dev/null
+++ b/tests/test_sales_api.py
@@ -0,0 +1,537 @@
+"""
+Tests for the consolidated sales API endpoint.
+"""
+import pytest
+from datetime import datetime, UTC, timedelta
+from unittest.mock import patch, AsyncMock
+from sqlalchemy.orm import Session
+from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBRegion, DBUser
+
+
+@pytest.fixture
+def test_ai_key(db: Session, test_team: DBTeam, test_region: DBRegion) -> DBPrivateAIKey:
+ """Create a test AI key."""
+ ai_key = DBPrivateAIKey(
+ database_name="test-db",
+ name="Test Key",
+ database_host="test-host",
+ database_username="test-user",
+ database_password="test-pass",
+ litellm_token="test-token",
+ litellm_api_url="https://test-litellm.com",
+ owner_id=None,
+ team_id=test_team.id,
+ region_id=test_region.id
+ )
+ db.add(ai_key)
+ db.commit()
+ db.refresh(ai_key)
+ return ai_key
+
+
+@pytest.fixture
+def test_always_free_team(db: Session) -> DBTeam:
+ """Create a test always-free team."""
+ team = DBTeam(
+ name="Always Free Team",
+ admin_email="free@test.com",
+ is_active=True,
+ is_always_free=True,
+ created_at=datetime.now(UTC) - timedelta(days=45), # 45 days ago
+ last_payment=None
+ )
+ db.add(team)
+ db.commit()
+ db.refresh(team)
+ return team
+
+
+@pytest.fixture
+def test_paid_team(db: Session) -> DBTeam:
+ """Create a test team with payment history."""
+ team = DBTeam(
+ name="Paid Team",
+ admin_email="paid@test.com",
+ is_active=True,
+ is_always_free=False,
+ created_at=datetime.now(UTC) - timedelta(days=60), # 60 days ago
+ last_payment=datetime.now(UTC) - timedelta(days=5) # 5 days ago
+ )
+ db.add(team)
+ db.commit()
+ db.refresh(team)
+ return team
+
+
+@pytest.fixture
+def mock_litellm_response():
+ """Mock LiteLLM API response."""
+ return {
+ "info": {
+ "spend": 25.50,
+ "expires": "2024-12-31T23:59:59Z",
+ "created_at": "2024-01-01T00:00:00Z",
+ "updated_at": "2024-01-02T00:00:00Z",
+ "max_budget": 100.0,
+ "budget_duration": "monthly",
+ "budget_reset_at": "2024-02-01T00:00:00Z"
+ }
+ }
+
+
+@pytest.fixture
+def test_user_owned_ai_key(db: Session, test_team_user: DBUser, test_region: DBRegion) -> DBPrivateAIKey:
+ """Create a test AI key owned by a team member."""
+ ai_key = DBPrivateAIKey(
+ database_name="user-db",
+ name="User Key",
+ database_host="user-host",
+ database_username="user-user",
+ database_password="user-pass",
+ litellm_token="user-token",
+ litellm_api_url="https://user-litellm.com",
+ owner_id=test_team_user.id,
+ team_id=None, # Not team-owned, user-owned
+ region_id=test_region.id
+ )
+ db.add(ai_key)
+ db.commit()
+ db.refresh(ai_key)
+ return ai_key
+
+
+def test_list_teams_for_sales_requires_admin(client, test_team):
+ """Test that only system admins can access the sales endpoint."""
+ response = client.get("/teams/sales/list-teams")
+ assert response.status_code == 401 # Unauthorized
+
+
+def test_list_teams_for_sales_success(client, admin_token, test_team, test_product,
+ test_region, test_ai_key, mock_litellm_response, db):
+ """Test successful retrieval of sales data."""
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = mock_litellm_response
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert "teams" in data
+ assert len(data["teams"]) == 1
+
+ team_data = data["teams"][0]
+ assert team_data["id"] == test_team.id
+ assert team_data["name"] == test_team.name
+ assert team_data["admin_email"] == test_team.admin_email
+ assert team_data["is_always_free"] == False
+ assert len(team_data["products"]) == 1
+ assert team_data["products"][0]["id"] == test_product.id
+ assert team_data["products"][0]["name"] == test_product.name
+ assert team_data["products"][0]["active"] == True
+ assert len(team_data["regions"]) == 1
+ assert team_data["regions"][0] == test_region.name
+ assert team_data["total_spend"] == 25.50
+ assert team_data["trial_status"] == "Active Product"
+
+
+def test_always_free_team_trial_status(client, admin_token, test_always_free_team,
+ test_product, test_region, test_ai_key, db):
+ """Test that always-free teams show correct trial status."""
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_always_free_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ assert team_data["trial_status"] == "Always Free"
+
+
+def test_paid_team_trial_status(client, admin_token, test_paid_team,
+ test_product, test_region, test_ai_key, db):
+ """Test that teams with payment history show correct trial status."""
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_paid_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert "teams" in data
+ team_data = data["teams"][0]
+ assert team_data["trial_status"] == "Active Product"
+
+
+def test_team_without_products(client, admin_token, test_team, test_region, test_ai_key, db):
+ """Test team without any products shows trial status based on creation date."""
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert "teams" in data
+ team_data = data["teams"][0]
+ # Team without products should show trial status based on creation date
+ assert team_data["trial_status"] == "30 days left"
+ assert len(team_data["products"]) == 0
+
+
+def test_team_with_multiple_ai_keys(client, admin_token, test_team, test_product,
+ test_region, test_ai_key, db):
+ """Test team with multiple AI keys aggregates spend correctly."""
+ # Create second AI key
+ ai_key2 = DBPrivateAIKey(
+ database_name="test-db-2",
+ name="Test Key 2",
+ database_host="test-host",
+ database_username="test-user",
+ database_password="test-pass",
+ litellm_token="test-token-2",
+ litellm_api_url="https://test-litellm.com",
+ owner_id=None,
+ team_id=test_team.id,
+ region_id=test_region.id
+ )
+ db.add(ai_key2)
+ db.commit()
+
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ # Mock LiteLLM service with different spend values
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.side_effect = [
+ {"info": {"spend": 25.50}},
+ {"info": {"spend": 15.25}}
+ ]
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ assert team_data["total_spend"] == 40.75 # 25.50 + 15.25
+
+
+def test_litellm_service_error_handling(client, admin_token, test_team, test_product,
+ test_region, test_ai_key, db):
+ """Test that LiteLLM service errors don't break the entire response."""
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ # Mock LiteLLM service to raise an exception
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.side_effect = Exception("LiteLLM service error")
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ # Should still succeed but with 0 spend
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ assert team_data["total_spend"] == 0.0
+
+
+def test_expired_trial_status(client, admin_token, test_team, test_region, test_ai_key, db):
+ """Test that expired trials show correct status."""
+ # Update team to be older than 30 days (no products, so should show expired)
+ test_team.created_at = datetime.now(UTC) - timedelta(days=35)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ assert team_data["trial_status"] == "Expired"
+
+
+def test_list_teams_for_sales_includes_user_owned_keys(client, admin_token, test_team, test_product,
+ test_region, test_ai_key, test_team_user,
+ test_user_owned_ai_key, mock_litellm_response, db):
+ """Test that sales data includes both team-owned and user-owned AI keys."""
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+ db.commit()
+
+ # Mock LiteLLM service to return different spend for each key
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ # Return different spend values for team key vs user key
+ mock_get_info.side_effect = [
+ {"info": {"spend": 25.50}}, # Team key spend
+ {"info": {"spend": 15.25}} # User key spend
+ ]
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert "teams" in data
+ assert len(data["teams"]) == 1
+
+ team_data = data["teams"][0]
+ assert team_data["id"] == test_team.id
+
+ # Should include both keys in total spend (25.50 + 15.25 = 40.75)
+ assert team_data["total_spend"] == 40.75
+
+ # Should include region from both keys
+ assert len(team_data["regions"]) == 1
+ assert team_data["regions"][0] == test_region.name
+
+
+def test_team_with_15_days_remaining(client, admin_token, test_team, test_region, test_ai_key, db):
+ """Test that teams with 15 days remaining show correct format."""
+ # Update team to be 15 days old
+ test_team.created_at = datetime.now(UTC) - timedelta(days=15)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ # Should show exact days remaining, not "In Progress"
+ assert team_data["trial_status"] == "15 days left"
+
+
+def test_team_with_7_days_remaining(client, admin_token, test_team, test_region, test_ai_key, db):
+ """Test that teams with 7 days remaining show correct format."""
+ # Update team to be 23 days old (7 days remaining)
+ test_team.created_at = datetime.now(UTC) - timedelta(days=23)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ # Should show exact days remaining
+ assert team_data["trial_status"] == "7 days left"
+
+
+def test_team_with_last_payment_days_calculation(client, admin_token, test_team, test_region, test_ai_key, db):
+ """Test that teams with last_payment calculate days remaining from payment date."""
+ # Set last_payment to 10 days ago (so 20 days remaining from payment)
+ test_team.last_payment = datetime.now(UTC) - timedelta(days=10)
+ db.commit()
+
+ # Mock LiteLLM service
+ with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info:
+ mock_get_info.return_value = {"info": {"spend": 0.0}}
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+ # Should calculate from last_payment date (30 - 10 = 20 days remaining)
+ assert team_data["trial_status"] == "20 days left"
+
+
+@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock)
+def test_list_teams_for_sales_unreachable_endpoints_logging(mock_get_info, client, admin_token, test_team, test_product,
+ test_region, test_ai_key, db):
+ """
+ Given a team with AI keys in a region where LiteLLM endpoint is unreachable
+ When calling list_teams_for_sales
+ Then the function should default spend to 0 and still return successful response
+ """
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+
+ # Create a second AI key in the same region to test multiple keys
+ ai_key2 = DBPrivateAIKey(
+ database_name="test-db-2",
+ name="Test Key 2",
+ database_host="test-host",
+ database_username="test-user",
+ database_password="test-pass",
+ litellm_token="test-token-2",
+ litellm_api_url="https://test-litellm.com",
+ owner_id=None,
+ team_id=test_team.id,
+ region_id=test_region.id
+ )
+ db.add(ai_key2)
+ db.commit()
+
+ # Mock LiteLLM service to raise an exception (simulating unreachable endpoint)
+ mock_get_info.side_effect = Exception("Connection timeout")
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ # Should still succeed despite LiteLLM errors
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+
+ # Spend should default to 0 for failed litellm calls
+ assert team_data["total_spend"] == 0.0
+
+
+@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock)
+def test_list_teams_for_sales_multiple_unreachable_regions(mock_get_info, client, admin_token, test_team, test_product,
+ test_region, test_ai_key, db):
+ """
+ Given a team with AI keys in multiple regions where some endpoints are unreachable
+ When calling list_teams_for_sales
+ Then the function should default spend to 0 and include all regions in response
+ """
+ # Create a second region
+ region2 = DBRegion(
+ name="Test Region 2",
+ postgres_host="test-host-2",
+ postgres_port=5432,
+ postgres_admin_user="test-user-2",
+ postgres_admin_password="test-pass-2",
+ litellm_api_url="https://test-litellm-2.com",
+ litellm_api_key="test-key-2",
+ is_active=True
+ )
+ db.add(region2)
+ db.commit()
+
+ # Create team-product association
+ team_product = DBTeamProduct(
+ team_id=test_team.id,
+ product_id=test_product.id
+ )
+ db.add(team_product)
+
+ # Create AI key in second region
+ ai_key2 = DBPrivateAIKey(
+ database_name="test-db-2",
+ name="Test Key 2",
+ database_host="test-host-2",
+ database_username="test-user-2",
+ database_password="test-pass-2",
+ litellm_token="test-token-2",
+ litellm_api_url="https://test-litellm-2.com",
+ owner_id=None,
+ team_id=test_team.id,
+ region_id=region2.id
+ )
+ db.add(ai_key2)
+ db.commit()
+
+ # Mock LiteLLM service to raise different exceptions for different regions
+ mock_get_info.side_effect = [
+ Exception("Connection timeout"), # First region fails
+ Exception("Service unavailable") # Second region fails
+ ]
+
+ response = client.get(
+ "/teams/sales/list-teams",
+ headers={"Authorization": f"Bearer {admin_token}"}
+ )
+
+ # Should still succeed despite LiteLLM errors
+ assert response.status_code == 200
+ data = response.json()
+ team_data = data["teams"][0]
+
+ # Spend should default to 0 for failed litellm calls
+ assert team_data["total_spend"] == 0.0
+
+ # Should include both regions in the response
+ assert len(team_data["regions"]) == 2
+ assert test_region.name in team_data["regions"]
+ assert region2.name in team_data["regions"]
diff --git a/tests/test_security.py b/tests/test_security.py
new file mode 100644
index 0000000..21178df
--- /dev/null
+++ b/tests/test_security.py
@@ -0,0 +1,91 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+from fastapi import HTTPException
+from app.core.security import check_sales_or_higher, get_role_min_system_admin
+from app.core.roles import UserRole
+from app.db.models import DBUser
+
+
+class TestSecurityFunctions:
+ """Test security function functionality"""
+
+ @pytest.mark.asyncio
+ async def test_check_sales_or_higher_sales_user(self):
+ """
+ Given a sales user
+ When calling check_sales_or_higher
+ Then it should return sales role
+ """
+ user = DBUser(id=1, email="sales@test.com", is_admin=False, team_id=None, role="sales")
+
+ result = await check_sales_or_higher(current_user=user)
+ assert result == UserRole.SALES
+
+ @pytest.mark.asyncio
+ async def test_check_sales_or_higher_system_admin(self):
+ """
+ Given a system admin user
+ When calling check_sales_or_higher
+ Then it should return system_admin role
+ """
+ user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None)
+
+ result = await check_sales_or_higher(current_user=user)
+ assert result == UserRole.SYSTEM_ADMIN
+
+ @pytest.mark.asyncio
+ async def test_check_sales_or_higher_regular_user_denied(self):
+ """
+ Given a regular system user
+ When calling check_sales_or_higher
+ Then it should raise HTTPException with 403 status
+ """
+ user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="user")
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_sales_or_higher(current_user=user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+ @pytest.mark.asyncio
+ async def test_check_sales_or_higher_team_user_denied(self):
+ """
+ Given a team user
+ When calling check_sales_or_higher
+ Then it should raise HTTPException with 403 status
+ """
+ user = DBUser(id=1, email="teamuser@test.com", is_admin=False, team_id=1, role="admin")
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_sales_or_higher(current_user=user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)
+
+ @pytest.mark.asyncio
+ async def test_check_system_admin_system_admin(self):
+ """
+ Given a system admin user
+ When calling check_system_admin
+ Then it should return system_admin role
+ """
+ user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None)
+
+ result = await get_role_min_system_admin(current_user=user)
+ assert result == UserRole.SYSTEM_ADMIN
+
+ @pytest.mark.asyncio
+ async def test_check_system_admin_regular_user_denied(self):
+ """
+ Given a regular system user
+ When calling check_system_admin
+ Then it should raise HTTPException with 403 status
+ """
+ user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="user")
+
+ with pytest.raises(HTTPException) as exc_info:
+ await get_role_min_system_admin(current_user=user)
+
+ assert exc_info.value.status_code == 403
+ assert "Not authorized to perform this action" in str(exc_info.value.detail)