diff --git a/app/api/billing.py b/app/api/billing.py index 981d4ed..f32f6b5 100644 --- a/app/api/billing.py +++ b/app/api/billing.py @@ -4,7 +4,7 @@ import os from datetime import datetime, UTC from app.db.database import get_db -from app.core.security import check_specific_team_admin, check_system_admin +from app.core.security import get_role_min_specific_team_admin, get_role_min_system_admin from app.db.models import DBTeam, DBSystemSecret, DBProduct, DBTeamProduct from app.schemas.models import PricingTableSession, SubscriptionCreate, SubscriptionResponse from app.services.stripe import ( @@ -91,7 +91,7 @@ async def handle_events( detail="Error processing webhook" ) -@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)]) +@router.post("/teams/{team_id}/portal", dependencies=[Depends(get_role_min_specific_team_admin)]) async def get_portal( team_id: int, db: Session = Depends(get_db) @@ -135,7 +135,7 @@ async def get_portal( detail="Error creating portal session" ) -@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession) +@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(get_role_min_specific_team_admin)], response_model=PricingTableSession) async def get_pricing_table_session( team_id: int, db: Session = Depends(get_db) @@ -178,7 +178,7 @@ async def get_pricing_table_session( detail="Error creating customer session" ) -@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(check_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED) +@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(get_role_min_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED) async def create_team_subscription( team_id: int, subscription_data: SubscriptionCreate, diff --git a/app/api/pricing_tables.py b/app/api/pricing_tables.py index d77d30b..be8449b 100644 --- a/app/api/pricing_tables.py +++ b/app/api/pricing_tables.py @@ -5,7 +5,7 @@ from app.db.database import get_db from app.db.models import DBTeam, DBPricingTable -from app.core.security import check_system_admin, get_role_min_team_admin, get_current_user_from_auth +from app.core.security import get_role_min_system_admin, get_role_min_team_admin, get_current_user_from_auth from app.schemas.models import PricingTableCreate, PricingTableResponse, PricingTablesResponse from app.core.config import settings @@ -16,8 +16,8 @@ tags=["pricing-tables"] ) -@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) -@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) +@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)]) +@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)]) async def create_pricing_table( pricing_table: PricingTableCreate, db: Session = Depends(get_db) @@ -113,8 +113,8 @@ async def get_pricing_table( updated_at=pricing_table.updated_at or pricing_table.created_at ) -@router.delete("", dependencies=[Depends(check_system_admin)]) -@router.delete("/", dependencies=[Depends(check_system_admin)]) +@router.delete("", dependencies=[Depends(get_role_min_system_admin)]) +@router.delete("/", dependencies=[Depends(get_role_min_system_admin)]) async def delete_pricing_table( table_type: str, db: Session = Depends(get_db) @@ -146,7 +146,7 @@ async def delete_pricing_table( return {"message": f"Pricing table of type '{table_type}' deleted successfully"} -@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(check_system_admin)]) +@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(get_role_min_system_admin)]) async def get_all_pricing_tables( db: Session = Depends(get_db) ): diff --git a/app/api/private_ai_keys.py b/app/api/private_ai_keys.py index 447d749..bf6eec3 100644 --- a/app/api/private_ai_keys.py +++ b/app/api/private_ai_keys.py @@ -13,9 +13,22 @@ from app.db.postgres import PostgresManager from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam from app.services.litellm import LiteLLMService -from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin +from app.core.security import ( + get_current_user_from_auth, + get_role_min_team_admin, + get_private_ai_access, + get_role_min_system_admin +) +from app.core.roles import UserRole from app.core.config import settings -from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY +from app.core.resource_limits import ( + check_key_limits, + check_vector_db_limits, + get_token_restrictions, + DEFAULT_KEY_DURATION, + DEFAULT_MAX_SPEND, + DEFAULT_RPM_PER_KEY +) router = APIRouter( tags=["private-ai-keys"] @@ -68,7 +81,7 @@ def _validate_permissions_and_get_ownership_info( async def create_vector_db( vector_db: VectorDBCreate, current_user = Depends(get_current_user_from_auth), - user_role: UserRole = Depends(get_role_min_key_creator), + user_role: UserRole = Depends(get_private_ai_access), db: Session = Depends(get_db), store_result: bool = True ): @@ -164,7 +177,7 @@ async def create_vector_db( async def create_private_ai_key( private_ai_key: PrivateAIKeyCreate, current_user = Depends(get_current_user_from_auth), - user_role: UserRole = Depends(get_role_min_key_creator), + user_role: UserRole = Depends(get_private_ai_access), db: Session = Depends(get_db) ): """ @@ -279,7 +292,7 @@ async def create_private_ai_key( async def create_llm_token( private_ai_key: PrivateAIKeyCreate, current_user = Depends(get_current_user_from_auth), - user_role: UserRole = Depends(get_role_min_key_creator), + user_role: UserRole = Depends(get_private_ai_access), db: Session = Depends(get_db), store_result: bool = True ): @@ -464,7 +477,7 @@ async def list_private_ai_keys( private_ai_keys = query.all() return [key.to_dict() for key in private_ai_keys] -@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)]) +@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(get_role_min_system_admin)]) async def get_private_ai_key( key_id: int, current_user = Depends(get_current_user_from_auth), @@ -571,7 +584,7 @@ def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole, async def delete_private_ai_key( key_id: int, current_user = Depends(get_current_user_from_auth), - user_role: UserRole = Depends(get_role_min_key_creator), + user_role: UserRole = Depends(get_private_ai_access), db: Session = Depends(get_db) ): private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db) diff --git a/app/api/products.py b/app/api/products.py index 68475f8..25e7512 100644 --- a/app/api/products.py +++ b/app/api/products.py @@ -5,15 +5,15 @@ from app.db.database import get_db from app.db.models import DBProduct, DBTeamProduct, DBTeam -from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin +from app.core.security import get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin from app.schemas.models import Product, ProductCreate, ProductUpdate router = APIRouter( tags=["products"] ) -@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) -@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)]) +@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)]) +@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)]) async def create_product( product: ProductCreate, db: Session = Depends(get_db) @@ -105,7 +105,7 @@ async def get_product( ) return product -@router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)]) +@router.put("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_system_admin)]) async def update_product( product_id: str, product_update: ProductUpdate, @@ -132,7 +132,7 @@ async def update_product( return product -@router.delete("/{product_id}", dependencies=[Depends(check_system_admin)]) +@router.delete("/{product_id}", dependencies=[Depends(get_role_min_system_admin)]) async def delete_product( product_id: str, db: Session = Depends(get_db) diff --git a/app/api/regions.py b/app/api/regions.py index f2d40d4..6c462e6 100644 --- a/app/api/regions.py +++ b/app/api/regions.py @@ -9,7 +9,7 @@ from app.api.auth import get_current_user_from_auth from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate, TeamSummary from app.db.models import DBRegion, DBPrivateAIKey, DBTeamRegion, DBTeam -from app.core.security import check_system_admin +from app.core.security import get_role_min_system_admin logger = logging.getLogger(__name__) @@ -89,8 +89,8 @@ async def validate_database_connection(host: str, port: int, user: str, password detail=f"Database connection validation failed: {str(e)}" ) -@router.post("", response_model=Region, dependencies=[Depends(check_system_admin)]) -@router.post("/", response_model=Region, dependencies=[Depends(check_system_admin)]) +@router.post("", response_model=Region, dependencies=[Depends(get_role_min_system_admin)]) +@router.post("/", response_model=Region, dependencies=[Depends(get_role_min_system_admin)]) async def create_region( region: RegionCreate, db: Session = Depends(get_db) @@ -158,13 +158,13 @@ async def list_regions( return non_dedicated_regions + team_dedicated_regions -@router.get("/admin", response_model=List[Region], dependencies=[Depends(check_system_admin)]) +@router.get("/admin", response_model=List[Region], dependencies=[Depends(get_role_min_system_admin)]) async def list_admin_regions( db: Session = Depends(get_db) ): return db.query(DBRegion).all() -@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(check_system_admin)]) +@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(get_role_min_system_admin)]) async def get_region( region_id: int, db: Session = Depends(get_db) @@ -177,7 +177,7 @@ async def get_region( ) return region -@router.delete("/{region_id}", dependencies=[Depends(check_system_admin)]) +@router.delete("/{region_id}", dependencies=[Depends(get_role_min_system_admin)]) async def delete_region( region_id: int, db: Session = Depends(get_db) @@ -209,7 +209,7 @@ async def delete_region( ) return {"message": "Region deleted successfully"} -@router.put("/{region_id}", response_model=Region, dependencies=[Depends(check_system_admin)]) +@router.put("/{region_id}", response_model=Region, dependencies=[Depends(get_role_min_system_admin)]) async def update_region( region_id: int, region: RegionUpdate, @@ -251,7 +251,7 @@ async def update_region( ) return db_region -@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)]) +@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)]) async def associate_team_with_region( region_id: int, team_id: int, @@ -311,7 +311,7 @@ async def associate_team_with_region( return {"message": "Team associated with region successfully"} -@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)]) +@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)]) async def disassociate_team_from_region( region_id: int, team_id: int, @@ -345,7 +345,7 @@ async def disassociate_team_from_region( return {"message": "Team disassociated from region successfully"} -@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(check_system_admin)]) +@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(get_role_min_system_admin)]) async def list_teams_for_region( region_id: int, db: Session = Depends(get_db) diff --git a/app/api/teams.py b/app/api/teams.py index 2a5591f..c181f4b 100644 --- a/app/api/teams.py +++ b/app/api/teams.py @@ -6,8 +6,8 @@ import logging from app.db.database import get_db -from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion -from app.core.security import check_system_admin, check_specific_team_admin, get_current_user_from_auth +from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion, DBProduct +from app.core.security import get_role_min_system_admin, get_role_min_specific_team_admin, get_current_user_from_auth, check_sales_or_higher from app.schemas.models import ( Team, TeamCreate, TeamUpdate, TeamWithUsers, TeamMergeRequest, TeamMergeResponse @@ -17,6 +17,7 @@ from app.services.ses import SESService from app.core.worker import get_team_keys_by_region, generate_pricing_url, get_team_admin_email from app.api.private_ai_keys import delete_private_ai_key +from app.schemas.models import SalesTeamsResponse, SalesProduct, SalesTeam logger = logging.getLogger(__name__) @@ -66,8 +67,8 @@ async def register_team( return db_team -@router.get("", response_model=List[Team], dependencies=[Depends(check_system_admin)]) -@router.get("/", response_model=List[Team], dependencies=[Depends(check_system_admin)]) +@router.get("", response_model=List[Team], dependencies=[Depends(get_role_min_system_admin)]) +@router.get("/", response_model=List[Team], dependencies=[Depends(get_role_min_system_admin)]) async def list_teams( db: Session = Depends(get_db) ): @@ -76,7 +77,7 @@ async def list_teams( """ return db.query(DBTeam).all() -@router.get("/{team_id}", response_model=TeamWithUsers, dependencies=[Depends(check_specific_team_admin)]) +@router.get("/{team_id}", response_model=TeamWithUsers, dependencies=[Depends(get_role_min_specific_team_admin)]) async def get_team( team_id: int, db: Session = Depends(get_db) @@ -92,7 +93,7 @@ async def get_team( # Convert directly to TeamWithUsers model return TeamWithUsers.model_validate(db_team) -@router.put("/{team_id}", response_model=Team, dependencies=[Depends(check_specific_team_admin)]) +@router.put("/{team_id}", response_model=Team, dependencies=[Depends(get_role_min_specific_team_admin)]) async def update_team( team_id: int, team_update: TeamUpdate, @@ -143,7 +144,7 @@ async def update_team( return db_team -@router.delete("/{team_id}", dependencies=[Depends(check_system_admin)]) +@router.delete("/{team_id}", dependencies=[Depends(get_role_min_system_admin)]) async def delete_team( team_id: int, db: Session = Depends(get_db) @@ -166,7 +167,7 @@ async def delete_team( return {"message": "Team deleted successfully"} -@router.post("/{team_id}/extend-trial", dependencies=[Depends(check_system_admin)]) +@router.post("/{team_id}/extend-trial", dependencies=[Depends(get_role_min_system_admin)]) async def extend_team_trial( team_id: int, db: Session = Depends(get_db) @@ -230,6 +231,146 @@ async def extend_team_trial( return {"message": "Team trial extended successfully"} +@router.get("/sales/list-teams", response_model=SalesTeamsResponse, dependencies=[Depends(check_sales_or_higher)]) +async def list_teams_for_sales( + db: Session = Depends(get_db) +): + """ + Get consolidated team information for sales dashboard. + Returns all teams with their products, regions, spend data, and trial status. + Accessible by system admin and sales users. + """ + try: + # Track unreachable endpoints for logging at the end (use set to avoid duplicates) + unreachable_endpoints = set() + + # Pre-fetch all regions once to avoid repeated queries + all_regions = db.query(DBRegion).filter(DBRegion.is_active == True).all() + regions_map = {r.id: r for r in all_regions} + + # Pre-create LiteLLM services for each region to avoid re-instantiation + litellm_services = {} + for region in all_regions: + litellm_services[region.id] = LiteLLMService( + api_url=region.litellm_api_url, + api_key=region.litellm_api_key + ) + + # Get all teams with their basic information + teams = db.query(DBTeam).all() + + sales_teams = [] + + for team in teams: + # Get team products + team_products = db.query(DBTeamProduct).join(DBProduct).filter( + DBTeamProduct.team_id == team.id, + DBProduct.active == True + ).all() + + products = [ + SalesProduct( + id=team_product.product.id, + name=team_product.product.name, + active=team_product.product.active + ) + for team_product in team_products + ] + + # Get team AI keys (both team-owned and user-owned) and calculate total spend + team_users = db.query(DBUser).filter(DBUser.team_id == team.id).all() + team_user_ids = [user.id for user in team_users] + + team_keys = db.query(DBPrivateAIKey).filter( + (DBPrivateAIKey.team_id == team.id) | # Team-owned keys + (DBPrivateAIKey.owner_id.in_(team_user_ids)) # User-owned keys by team members + ).all() + + # Calculate total spend from all AI keys and build regions list as we go + total_spend = 0.0 + regions_set = set() + + for key in team_keys: + if key.litellm_token and key.region_id in regions_map: + try: + # Use pre-fetched region info and pre-created LiteLLM service + region = regions_map[key.region_id] + litellm_service = litellm_services[region.id] + + # Add region name to our set + regions_set.add(region.name) + + # Get spend data from LiteLLM + key_data = await litellm_service.get_key_info(key.litellm_token) + key_spend = key_data.get("info", {}).get("spend", 0.0) + total_spend += float(key_spend) + except Exception as e: + # Track unreachable endpoint for logging at the end (only once per region) + region = regions_map[key.region_id] + endpoint_info = f"Region: {region.name}" + unreachable_endpoints.add(endpoint_info) + + # Convert set to list for the response + regions = list(regions_set) + + # Calculate trial status + trial_status = _calculate_trial_status(team, products) + + sales_team = SalesTeam( + id=team.id, + name=team.name, + admin_email=team.admin_email, + created_at=team.created_at, + last_payment=team.last_payment, + is_always_free=team.is_always_free, + products=products, + regions=regions, + total_spend=round(total_spend, 4), + trial_status=trial_status + ) + + sales_teams.append(sales_team) + + # Log all unreachable endpoints at the end + if unreachable_endpoints: + logger.warning(f"Unreachable LiteLLM endpoints encountered: {len(unreachable_endpoints)} unique endpoints") + for endpoint in unreachable_endpoints: + logger.warning(f" - {endpoint}") + + return SalesTeamsResponse(teams=sales_teams) + + except Exception as e: + logger.error(f"Error in list_teams_for_sales: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve sales data: {str(e)}" + ) + +def _calculate_trial_status(team: DBTeam, products: List[SalesProduct]) -> str: + """ + Calculate trial status based on team creation, last payment, and active products. + """ + if team.is_always_free: + return "Always Free" + + if len(products) > 0: + return "Active Product" + + # Calculate days until expiry + trial_period_days = 30 + if team.last_payment: + days_since_last_payment = (datetime.now(UTC) - team.last_payment.replace(tzinfo=UTC)).days + days_remaining = trial_period_days - days_since_last_payment + else: + days_since_creation = (datetime.now(UTC) - team.created_at.replace(tzinfo=UTC)).days + days_remaining = trial_period_days - days_since_creation + + if days_remaining <= 0: + return "Expired" + else: + # Always show days remaining for active trials + return f"{days_remaining} days left" + def _check_key_name_conflicts(team1_keys: List[DBPrivateAIKey], team2_keys: List[DBPrivateAIKey]) -> List[str]: """Return list of conflicting key names between two teams""" team1_names = {key.name for key in team1_keys if key.name} @@ -278,7 +419,7 @@ async def _resolve_key_conflicts( else: raise ValueError(f"Unknown conflict resolution strategy: {strategy}") -@router.post("/{target_team_id}/merge", dependencies=[Depends(check_system_admin)]) +@router.post("/{target_team_id}/merge", dependencies=[Depends(get_role_min_system_admin)]) async def merge_teams( target_team_id: int, merge_request: TeamMergeRequest, diff --git a/app/api/users.py b/app/api/users.py index c78ca3f..6b60aea 100644 --- a/app/api/users.py +++ b/app/api/users.py @@ -1,19 +1,20 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from typing import List, get_args +from typing import List from app.core.config import settings from app.core.resource_limits import check_team_user_limit from app.db.database import get_db from app.schemas.models import User, UserUpdate, UserCreate, TeamOperation, UserRoleUpdate from app.db.models import DBUser, DBTeam -from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, UserRole, get_role_min_team_admin +from app.core.security import get_password_hash, get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin +from app.core.roles import UserRole from datetime import datetime, UTC router = APIRouter( tags=["users"] ) -@router.get("/search", response_model=List[User], dependencies=[Depends(check_system_admin)]) +@router.get("/search", response_model=List[User], dependencies=[Depends(get_role_min_system_admin)]) async def search_users( email: str, db: Session = Depends(get_db) @@ -77,10 +78,10 @@ async def create_user( check_team_user_limit(db, user.team_id) # Validate role if provided - if user.role and user.role not in get_args(UserRole): + if user.role and user.role not in UserRole.get_all_roles(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid role. Must be one of: {', '.join(get_args(UserRole))}" + detail=f"Invalid role. Must be one of: {', '.join(UserRole.get_all_roles())}" ) # Default to the lowest permissions for a user in a team @@ -198,7 +199,7 @@ async def add_user_to_team( db.refresh(db_user) return db_user -@router.post("/{user_id}/remove-from-team", response_model=User, dependencies=[Depends(check_system_admin)]) +@router.post("/{user_id}/remove-from-team", response_model=User, dependencies=[Depends(get_role_min_system_admin)]) async def remove_user_from_team( user_id: int, current_user: DBUser = Depends(get_current_user_from_auth), @@ -224,7 +225,7 @@ async def remove_user_from_team( db.refresh(db_user) return db_user -@router.delete("/{user_id}", dependencies=[Depends(check_system_admin)]) +@router.delete("/{user_id}", dependencies=[Depends(get_role_min_system_admin)]) async def delete_user( user_id: int, current_user: DBUser = Depends(get_current_user_from_auth), @@ -256,10 +257,10 @@ async def update_user_role( Update a user's role. Accessible by admin users or team admins for their team members. """ # Validate role - if role_update.role not in get_args(UserRole): + if role_update.role not in UserRole.get_all_roles(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid role. Must be one of: {', '.join(get_args(UserRole))}" + detail=f"Invalid role. Must be one of: {', '.join(UserRole.get_all_roles())}" ) # Get the user to update diff --git a/app/core/rbac.py b/app/core/rbac.py new file mode 100644 index 0000000..618b92e --- /dev/null +++ b/app/core/rbac.py @@ -0,0 +1,108 @@ +from fastapi import Depends, HTTPException, status +from typing import List, Union, Set, Callable +from app.core.roles import UserRole, UserType +from app.db.models import DBUser +import logging + +class RBACDependency: + """Base class for role-based access control dependencies""" + logger = logging.getLogger(__name__) + + def __init__(self, allowed_roles: List[str], require_team_membership: bool = False): + self.allowed_roles = set(allowed_roles) + self.require_team_membership = require_team_membership + + def __call__(self, current_user: DBUser) -> str: + return self.check_access(current_user) + + def check_access(self, user: DBUser) -> str: + """Check if user has access and return their effective role""" + # Validate user type constraints + if self._validate_user_type_constraints(user): + self.logger.info(f"User {user.id} has invalid user type constraints") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to perform this action" + ) + + # Check role permissions + effective_role = self._get_effective_role(user) + if effective_role not in self.allowed_roles: + self.logger.info(f"User {user.id} has invalid role {effective_role}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to perform this action" + ) + + # Check team membership if required (but allow system admins to bypass this) + if self.require_team_membership and not user.team_id and not user.is_admin: + self.logger.info(f"User {user.id} is not a team member") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to perform this action" + ) + + return effective_role + + def _validate_user_type_constraints(self, user: DBUser) -> bool: + """Validate that user type matches role constraints""" + # System admins (is_admin=True) cannot be team members + if user.is_admin and user.team_id is not None: + return True + + # Get the effective role for validation + effective_role = self._get_effective_role(user) + + # System users (team_id is None) cannot have team roles + if user.team_id is None and effective_role in UserRole.get_team_roles(): + return True + + # Team users (team_id is not None) cannot have system roles + if user.team_id is not None and effective_role in UserRole.get_system_roles(): + return True + + return False + + def _get_effective_role(self, user: DBUser) -> str: + """Get the effective role for a user""" + if user.is_admin: + return UserRole.SYSTEM_ADMIN + return user.role or UserRole.USER + +# Pre-defined dependency functions for common use cases +def require_system_admin(): + """Require system admin role""" + return RBACDependency([UserRole.SYSTEM_ADMIN]) + +def require_team_admin(): + """Require team admin role or system admin""" + return RBACDependency(UserRole.ADMIN_ROLES, require_team_membership=True) + +def require_key_creator_or_higher(): + """Require key creator role or higher (team context)""" + return RBACDependency(UserRole.KEY_MANAGEMENT_ROLES, require_team_membership=True) + +def require_private_ai_access(): + """Require access to private AI operations - allows system users or team key creators""" + return RBACDependency(UserRole.KEY_MANAGEMENT_ROLES + [UserRole.USER], require_team_membership=False) + +def require_read_only_or_higher(): + """Require read only role or higher (team context)""" + return RBACDependency(UserRole.READ_ACCESS_ROLES, require_team_membership=True) + +def require_sales_or_higher(): + """Require sales role or higher (system context)""" + return RBACDependency(UserRole.SYSTEM_ACCESS_ROLES) + +def require_any_role(): + """Allow any authenticated user""" + return RBACDependency(UserRole.get_all_roles()) + +# Custom role dependency creator +def require_roles(*roles: str): + """Create a dependency that requires specific roles""" + return RBACDependency(list(roles)) + +def require_roles_with_team(*roles: str): + """Create a dependency that requires specific roles and team membership""" + return RBACDependency(list(roles), require_team_membership=True) diff --git a/app/core/roles.py b/app/core/roles.py new file mode 100644 index 0000000..6e38210 --- /dev/null +++ b/app/core/roles.py @@ -0,0 +1,53 @@ + +from typing import List, Set, Literal +from enum import Enum + +class UserType(Enum): + SYSTEM = "system" + TEAM = "team" + +class UserRole: + # System roles + SYSTEM_ADMIN = "system_admin" + USER = "user" # Default system user + SALES = "sales" + + # Team roles + TEAM_ADMIN = "admin" + KEY_CREATOR = "key_creator" + READ_ONLY = "read_only" + + # Legacy support - these MUST match existing string values exactly + ADMIN = TEAM_ADMIN # "admin" + DEFAULT = USER # "user" + + # Role combinations for better readability + ADMIN_ROLES = [TEAM_ADMIN, SYSTEM_ADMIN] + KEY_MANAGEMENT_ROLES = [KEY_CREATOR] + ADMIN_ROLES + READ_ACCESS_ROLES = [READ_ONLY] + KEY_MANAGEMENT_ROLES + SYSTEM_ACCESS_ROLES = [SYSTEM_ADMIN, SALES] + + @staticmethod + def get_system_roles() -> List[str]: + """Get all valid system user roles""" + return [UserRole.SYSTEM_ADMIN, UserRole.USER, UserRole.SALES] + + @staticmethod + def get_team_roles() -> List[str]: + """Get all valid team user roles""" + return [UserRole.TEAM_ADMIN, UserRole.KEY_CREATOR, UserRole.READ_ONLY] + + @staticmethod + def get_all_roles() -> List[str]: + """Get all valid roles (backwards compatible)""" + return UserRole.get_system_roles() + UserRole.get_team_roles() + + @staticmethod + def is_system_role(role: str) -> bool: + """Check if role is valid for system users""" + return role in UserRole.get_system_roles() + + @staticmethod + def is_team_role(role: str) -> bool: + """Check if role is valid for team users""" + return role in UserRole.get_team_roles() diff --git a/app/core/security.py b/app/core/security.py index 43583ad..5a37861 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta, UTC -from typing import Optional, Literal, Dict +from typing import Optional from jose import JWTError, jwt from passlib.context import CryptContext from fastapi import Depends, HTTPException, status, Cookie, Header, Request @@ -9,6 +9,13 @@ from app.db.database import get_db from sqlalchemy.orm import Session from app.db.models import DBUser, DBAPIToken +from app.core.rbac import ( + require_system_admin, + require_team_admin, + require_key_creator_or_higher, + require_sales_or_higher, + require_private_ai_access, +) logger = logging.getLogger(__name__) @@ -18,18 +25,6 @@ # Custom bearer scheme bearer_scheme = HTTPBearer(auto_error=False) -# Define valid user roles as a Literal type -UserRole = Literal["admin", "key_creator", "read_only", "user", "system_admin"] - -# Define a hierarchy for roles -user_role_hierarchy: Dict[UserRole, int] = { - "admin": 0, - "user": 1, - "key_creator": 2, - "read_only": 3, - "system_admin": 4, -} - def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password against its hash.""" return pwd_context.verify(plain_password, hashed_password) @@ -137,35 +132,40 @@ async def get_current_user_from_auth( headers={"WWW-Authenticate": "Bearer"}, ) -async def check_system_admin(current_user: DBUser = Depends(get_current_user_from_auth)): +async def get_role_min_system_admin(current_user: DBUser = Depends(get_current_user_from_auth)): """Check if the current user is a system admin.""" - if not current_user.is_admin: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to perform this action" - ) - -def get_user_role(minimum_role: UserRole, current_user: DBUser): - if current_user.is_admin: - return "system_admin" - elif user_role_hierarchy[current_user.role] > user_role_hierarchy[minimum_role]: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authorized to perform this action" - ) - return current_user.role + dependency = require_system_admin() + return dependency.check_access(current_user) async def get_role_min_team_admin(current_user: DBUser = Depends(get_current_user_from_auth)): - return get_user_role("admin", current_user) + """Require team admin role or higher.""" + dependency = require_team_admin() + return dependency.check_access(current_user) + +async def get_role_min_specific_team_admin(current_user: DBUser = Depends(get_current_user_from_auth), team_id: int = None): + """Check if user is admin of specific team.""" + dependency = require_team_admin() + role = dependency.check_access(current_user) -async def check_specific_team_admin(current_user: DBUser = Depends(get_current_user_from_auth), team_id: int = None): - get_user_role("admin", current_user) - # system administrators will fail the team check + # Additional team-specific check if not current_user.is_admin and not current_user.team_id == team_id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to perform this action" ) + return role async def get_role_min_key_creator(current_user: DBUser = Depends(get_current_user_from_auth)): - return get_user_role("key_creator", current_user) + """Require key creator role or higher.""" + dependency = require_key_creator_or_higher() + return dependency.check_access(current_user) + +async def get_private_ai_access(current_user: DBUser = Depends(get_current_user_from_auth)): + """Require access to private AI operations - allows system users or team key creators.""" + dependency = require_private_ai_access() + return dependency.check_access(current_user) + +async def check_sales_or_higher(current_user: DBUser = Depends(get_current_user_from_auth)): + """Check if the current user is a sales user or system admin.""" + dependency = require_sales_or_higher() + return dependency.check_access(current_user) diff --git a/app/schemas/models.py b/app/schemas/models.py index 2f95186..a74651d 100644 --- a/app/schemas/models.py +++ b/app/schemas/models.py @@ -362,4 +362,28 @@ class SubscriptionResponse(BaseModel): product_id: str team_id: int created_at: datetime + model_config = ConfigDict(from_attributes=True) + +# Sales Dashboard schemas +class SalesProduct(BaseModel): + id: str + name: str + active: bool + model_config = ConfigDict(from_attributes=True) + +class SalesTeam(BaseModel): + id: int + name: str + admin_email: str + created_at: datetime + last_payment: Optional[datetime] = None + is_always_free: bool + products: List[SalesProduct] + regions: List[str] + total_spend: float + trial_status: str + model_config = ConfigDict(from_attributes=True) + +class SalesTeamsResponse(BaseModel): + teams: List[SalesTeam] model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/frontend/src/app/admin/layout.tsx b/frontend/src/app/admin/layout.tsx index 5a355cf..544a090 100644 --- a/frontend/src/app/admin/layout.tsx +++ b/frontend/src/app/admin/layout.tsx @@ -1,11 +1,15 @@ 'use client'; import { ReactNode } from 'react'; +import { usePathname } from 'next/navigation'; export default function AdminLayout({ children }: { children: ReactNode }) { + const pathname = usePathname(); + const isSalesDashboard = pathname === '/admin/sales-dashboard'; + return (
-
+
{children}
diff --git a/frontend/src/app/admin/page.tsx b/frontend/src/app/admin/page.tsx index 2231ede..2ae7507 100644 --- a/frontend/src/app/admin/page.tsx +++ b/frontend/src/app/admin/page.tsx @@ -4,7 +4,7 @@ import { useEffect } from 'react'; import { useRouter } from 'next/navigation'; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; import { Button } from '@/components/ui/button'; -import { Users, Globe, Key } from 'lucide-react'; +import { Users, Globe, Key, DollarSign } from 'lucide-react'; export default function AdminPage() { const router = useRouter(); @@ -54,6 +54,19 @@ export default function AdminPage() { + + router.push('/admin/sales-dashboard')}> + + + + Sales Dashboard + + Monitor team performance and revenue metrics + + + + +
); } \ No newline at end of file diff --git a/frontend/src/app/admin/sales-dashboard/page.tsx b/frontend/src/app/admin/sales-dashboard/page.tsx new file mode 100644 index 0000000..269e796 --- /dev/null +++ b/frontend/src/app/admin/sales-dashboard/page.tsx @@ -0,0 +1,843 @@ +'use client'; + +import { useState, useMemo, useCallback } from 'react'; +import { useQuery } from '@tanstack/react-query'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, + TablePagination, +} from '@/components/ui/table'; + +import { Loader2, ChevronUp, ChevronDown, ChevronsUpDown, DollarSign, Calendar, Users, Globe, Package, Plus, X } from 'lucide-react'; +import { get } from '@/utils/api'; +import { Badge } from '@/components/ui/badge'; +import { Button } from '@/components/ui/button'; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; +import { Input } from '@/components/ui/input'; + +interface Team { + id: number; + name: string; + admin_email: string; + created_at: string; + last_payment?: string; + is_always_free: boolean; + products: Product[]; + regions: string[]; + total_spend: number; + trial_status: string; +} + +interface Product { + id: string; + name: string; + active: boolean; +} + + + +type SortField = 'admin_email' | 'name' | 'created_at' | 'last_payment' | 'products' | 'trial_status' | 'regions' | 'total_spend' | null; +type SortDirection = 'asc' | 'desc'; + +interface Filter { + id: string; + column: string; + value: string; + operator: 'contains' | 'equals' | 'starts_with' | 'ends_with'; +} + +export default function SalesDashboardPage() { + + // Filter and sort state + const [filters, setFilters] = useState([]); + const [sortField, setSortField] = useState(null); + const [sortDirection, setSortDirection] = useState('asc'); + + + // Queries + const { data: teams = [], isLoading: isLoadingTeams } = useQuery({ + queryKey: ['sales-teams'], + queryFn: async () => { + const response = await get('/teams/sales/list-teams'); + const data = await response.json(); + return data.teams; + }, + }); + + + + + // Calculate trial time remaining + const getTrialTimeRemaining = useCallback((team: Team): string => { + return team.trial_status; + }, []); + + // Get unique regions for a team + const getTeamRegions = useCallback((team: Team): string[] => { + return team.regions; + }, []); + + // Get total spend for a team + const getTeamTotalSpend = useCallback((team: Team): number => { + return team.total_spend; + }, []); + + // Get all team spend values for calculating min/max + const allTeamSpends = useMemo(() => { + return teams.map(team => ({ + teamId: team.id, + spend: getTeamTotalSpend(team) + })).filter(item => item.spend > 0); // Only non-zero values for min calculation + }, [teams, getTeamTotalSpend]); + + // Calculate min and max spend values + const spendStats = useMemo(() => { + if (allTeamSpends.length === 0) { + return { minSpend: 0, maxSpend: 0 }; + } + + const spends = allTeamSpends.map(item => item.spend); + return { + minSpend: Math.min(...spends), + maxSpend: Math.max(...spends) + }; + }, [allTeamSpends]); + + + // Get color for spend value based on gradient rules + const getSpendColor = useCallback((spend: number): string => { + if (spend === 0) { + return '#6b7280'; // Dark grey for $0.00 + } + + if (spend === spendStats.maxSpend) { + return '#166534'; // Dark green for highest value + } + + if (spend === spendStats.minSpend) { + return '#000000'; // Black for lowest non-zero value + } + + // Gradient for values between min and max + if (spendStats.maxSpend === spendStats.minSpend) { + return '#166534'; // If all values are the same, use dark green + } + + const ratio = (spend - spendStats.minSpend) / (spendStats.maxSpend - spendStats.minSpend); + const red = Math.round(22 + (ratio * 0)); // Start from dark green (22, 101, 52) + const green = Math.round(101 + (ratio * 53)); // End at darker green (22, 154, 52) + const blue = Math.round(52 + (ratio * 0)); + + return `rgb(${red}, ${green}, ${blue})`; + }, [spendStats]); + + // Add a new filter + const addFilter = () => { + const newFilter: Filter = { + id: Date.now().toString(), + column: 'admin_email', + value: '', + operator: 'contains' + }; + setFilters([...filters, newFilter]); + }; + + // Update filter when column changes to set appropriate default value + const handleColumnChange = (filterId: string, column: string) => { + const columnInfo = filterColumns.find(col => col.value === column); + let defaultValue = ''; + let defaultOperator: Filter['operator'] = 'contains'; + + if (columnInfo?.type === 'select') { + const options = getFilterOptions(column); + if (options.length > 0) { + defaultValue = options[0].value; + } + // Regions use 'contains' since teams can have multiple regions + if (column === 'regions') { + defaultOperator = 'contains'; + } else { + defaultOperator = 'equals'; + } + } else if (columnInfo?.type === 'number') { + defaultOperator = 'equals'; + } + + updateFilter(filterId, { column, value: defaultValue, operator: defaultOperator }); + }; + + // Remove a filter + const removeFilter = (filterId: string) => { + setFilters(filters.filter(f => f.id !== filterId)); + }; + + // Update a filter + const updateFilter = (filterId: string, updates: Partial) => { + setFilters(filters.map(f => + f.id === filterId ? { ...f, ...updates } : f + )); + }; + + // Clear all filters + const clearAllFilters = () => { + setFilters([]); + setSortField(null); + setSortDirection('asc'); + }; + + + // Available filter columns + const filterColumns = [ + { value: 'admin_email', label: 'Team Email', type: 'text' }, + { value: 'name', label: 'Team Name', type: 'text' }, + { value: 'products', label: 'Products', type: 'select' }, + { value: 'trial_status', label: 'Trial Status', type: 'select' }, + { value: 'regions', label: 'Regions', type: 'select' }, + ]; + + // Filter operators + const filterOperators = [ + { value: 'contains', label: 'Contains' }, + { value: 'equals', label: 'Equals' }, + { value: 'starts_with', label: 'Starts with' }, + { value: 'ends_with', label: 'Ends with' }, + ]; + + // Get operators for specific column types + const getOperatorsForColumn = (column: string) => { + const columnInfo = filterColumns.find(col => col.value === column); + + if (columnInfo?.type === 'select') { + // Regions can use 'contains' since teams can have multiple regions + if (column === 'regions') { + return [ + { value: 'equals', label: 'Equals' }, + { value: 'contains', label: 'Contains' } + ]; + } + return [{ value: 'equals', label: 'Equals' }]; + } + + if (columnInfo?.type === 'number') { + return [ + { value: 'equals', label: 'Equals' }, + { value: 'contains', label: 'Contains' } + ]; + } + + // Text fields get all operators + return filterOperators; + }; + + // Get options for select-type filters + const getFilterOptions = (column: string) => { + switch (column) { + case 'trial_status': + return [ + { value: 'Active Product', label: 'Active Product' }, + { value: 'Always Free', label: 'Always Free' }, + { value: 'In Progress', label: 'In Progress' }, + { value: 'Expired', label: 'Expired' }, + ]; + case 'regions': + // Get unique regions from all teams + const allRegions = new Set(); + teams.forEach(team => { + const teamRegions = getTeamRegions(team); + teamRegions.forEach(region => allRegions.add(region)); + }); + return [ + { value: 'No Region', label: 'No Region' }, + ...Array.from(allRegions).sort().map(region => ({ + value: region, + label: region + })) + ]; + case 'products': + const allProducts = new Set(); + teams.forEach(team => { + const teamProducts = team.products || []; + teamProducts.forEach(product => allProducts.add(product.name)); + }); + return [ + { value: 'No Product', label: 'No Product' }, + ...Array.from(allProducts).sort().map(productName => ({ + value: productName, + label: productName + })) + ]; + default: + return []; + } + }; + + // Get filter input component based on column type + const getFilterInput = (filter: Filter) => { + const column = filterColumns.find(col => col.value === filter.column); + + if (column?.type === 'select') { + const options = getFilterOptions(filter.column); + return ( + + ); + } + + if (column?.type === 'number') { + return ( + updateFilter(filter.id, { value: e.target.value })} + className="flex-1" + /> + ); + } + + // Default to text input + return ( + updateFilter(filter.id, { value: e.target.value })} + className="flex-1" + /> + ); + }; + + // Apply filters to teams + const applyFilters = useCallback((teams: Team[]) => { + if (filters.length === 0) return teams; + + return teams.filter(team => { + return filters.every(filter => { + let teamValue: string | number; + + switch (filter.column) { + case 'admin_email': + teamValue = team.admin_email.toLowerCase(); + break; + case 'name': + teamValue = team.name.toLowerCase(); + break; + case 'products': + const teamProducts = team.products || []; + teamValue = teamProducts.length > 0 ? teamProducts.map(p => p.name).join(', ') : 'No Product'; + break; + case 'trial_status': + const trialStatus = getTrialTimeRemaining(team); + if (trialStatus.includes('days left')) { + teamValue = 'In Progress'; + } else { + teamValue = trialStatus; + } + break; + case 'regions': + const teamRegions = getTeamRegions(team); + teamValue = teamRegions.length > 0 ? teamRegions.join(', ') : 'No Region'; + break; + default: + return true; + } + + const filterValue = filter.value.toLowerCase(); + + if (typeof teamValue === 'number') { + // Handle numeric comparisons + const numValue = parseFloat(filterValue); + if (isNaN(numValue)) return true; + return teamValue === numValue; + } + + // Handle string comparisons + switch (filter.operator) { + case 'contains': + return teamValue.toLowerCase().includes(filterValue); + case 'equals': + return teamValue.toLowerCase() === filterValue; + case 'starts_with': + return teamValue.toLowerCase().startsWith(filterValue); + case 'ends_with': + return teamValue.toLowerCase().endsWith(filterValue); + default: + return true; + } + }); + }); + }, [filters, getTrialTimeRemaining, getTeamRegions]); + + // Filtered and sorted teams + const filteredAndSortedTeams = useMemo(() => { + const filtered = applyFilters(teams); + + if (sortField) { + filtered.sort((a, b) => { + let aValue: string | number; + let bValue: string | number; + + switch (sortField) { + case 'admin_email': + aValue = a.admin_email.toLowerCase(); + bValue = b.admin_email.toLowerCase(); + break; + case 'name': + aValue = a.name.toLowerCase(); + bValue = b.name.toLowerCase(); + break; + case 'created_at': + aValue = new Date(a.created_at).getTime(); + bValue = new Date(b.created_at).getTime(); + break; + case 'last_payment': + aValue = a.last_payment ? new Date(a.last_payment).getTime() : 0; + bValue = b.last_payment ? new Date(b.last_payment).getTime() : 0; + break; + case 'products': + aValue = (a.products || []).length; + bValue = (b.products || []).length; + break; + case 'trial_status': + aValue = getTrialTimeRemaining(a); + bValue = getTrialTimeRemaining(b); + break; + case 'regions': + aValue = getTeamRegions(a).length; + bValue = getTeamRegions(b).length; + break; + case 'total_spend': + aValue = getTeamTotalSpend(a); + bValue = getTeamTotalSpend(b); + break; + default: + return 0; + } + + if (sortDirection === 'asc') { + return aValue < bValue ? -1 : aValue > bValue ? 1 : 0; + } else { + return aValue > bValue ? -1 : aValue < bValue ? 1 : 0; + } + }); + } + + return filtered; + }, [teams, sortField, sortDirection, getTrialTimeRemaining, getTeamRegions, getTeamTotalSpend, applyFilters]); + + const hasActiveFilters = filters.length > 0; + + // Manual pagination to ensure it updates with local state changes + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + const totalItems = filteredAndSortedTeams.length; + const totalPages = Math.ceil(totalItems / pageSize); + + const startIndex = (currentPage - 1) * pageSize; + const endIndex = startIndex + pageSize; + const paginatedData = filteredAndSortedTeams.slice(startIndex, endIndex); + + const goToPage = (page: number) => setCurrentPage(page); + const changePageSize = (newPageSize: number) => { + setPageSize(newPageSize); + setCurrentPage(1); // Reset to first page when changing page size + }; + + // Handle sorting + const handleSort = (field: SortField) => { + if (sortField === field) { + setSortDirection(sortDirection === 'asc' ? 'desc' : 'asc'); + } else { + setSortField(field); + setSortDirection('asc'); + } + }; + + // Get sort icon + const getSortIcon = (field: SortField) => { + if (sortField !== field) { + return ; + } + return sortDirection === 'asc' ? : ; + }; + + // Format date + const formatDate = (dateString: string | undefined): string => { + if (!dateString) return 'Never'; + return new Date(dateString).toLocaleDateString(); + }; + + // Format currency + const formatCurrency = (amount: number): string => { + return new Intl.NumberFormat('en-US', { + style: 'currency', + currency: 'USD', + minimumFractionDigits: 4, + maximumFractionDigits: 4, + }).format(amount); + }; + + if (isLoadingTeams) { + return ( +
+ +
+ Loading teams... +
+
+ ); + } + + return ( +
+
+
+

Sales Dashboard

+

+ Monitor team performance, subscriptions, and revenue metrics +

+
+
+
+
+ {teams.length} +
+
Total Teams
+
+
+
+ + {/* Filters Section */} +
+
+

Filters

+
+ {hasActiveFilters && ( + + )} + +
+
+ + {filters.length > 0 && ( +
+ {filters.map((filter) => ( +
+ + + + + {getFilterInput(filter)} + + +
+ ))} +
+ )} + + {hasActiveFilters && ( +
+ Showing {filteredAndSortedTeams.length} of {teams.length} teams +
+ )} +
+ +
+
+ + + + handleSort('admin_email')} + > +
+ + Team Email + {getSortIcon('admin_email')} +
+
+ handleSort('name')} + > +
+ Team Name + {getSortIcon('name')} +
+
+ handleSort('created_at')} + > +
+ + Team Create Date + {getSortIcon('created_at')} +
+
+ handleSort('last_payment')} + > +
+ + Last Payment + {getSortIcon('last_payment')} +
+
+ handleSort('products')} + > +
+ + Products + {getSortIcon('products')} +
+
+ handleSort('trial_status')} + > +
+ + Trial Status + {getSortIcon('trial_status')} +
+
+ handleSort('regions')} + > +
+ + Regions + {getSortIcon('regions')} +
+
+ handleSort('total_spend')} + > +
+ + Total Spend + {getSortIcon('total_spend')} +
+
+ +
+
+ + {paginatedData.length === 0 ? ( + + + No teams found matching your filters. + + + ) : ( + paginatedData.map((team) => ( + + +
{team.admin_email}
+
+ +
{team.name}
+
+ ID: {team.id} +
+
+ + {formatDate(team.created_at)} + + + {formatDate(team.last_payment)} + + + {(() => { + const teamProducts = team.products || []; + return teamProducts.length > 0 ? ( +
+ {teamProducts.map((product) => ( + + {product.name} + + ))} +
+ ) : ( + No products + ); + })()} +
+ + {(() => { + const trialStatus = getTrialTimeRemaining(team); + let badgeVariant: "default" | "secondary" | "destructive" | "outline" = "outline"; + let customStyle = {}; + + if (trialStatus === 'Active Product') { + badgeVariant = "default"; + customStyle = { backgroundColor: '#166534', color: 'white' }; // dark green + } else if (trialStatus === 'Always Free') { + badgeVariant = "secondary"; + } else if (trialStatus === 'Expired') { + badgeVariant = "destructive"; + customStyle = { backgroundColor: '#991b1b', color: 'white' }; // dark red + } else if (trialStatus === 'In Progress') { + badgeVariant = "outline"; + customStyle = { backgroundColor: '#fef3c7', color: '#92400e' }; // amber/yellow + } else if (trialStatus.includes('days left')) { + // Extract the number of days + const daysMatch = trialStatus.match(/(\d+)/); + if (daysMatch) { + const days = parseInt(daysMatch[1], 10); + // Calculate color gradient: 30 days = green, 0 days = red + // Use a 30-day scale instead of 7-day for better gradient + const ratio = Math.max(0, Math.min(1, days / 30)); + + // Green to red gradient: green (34, 197, 94) to red (239, 68, 68) + const red = Math.round(34 + (205 * (1 - ratio))); // 34-239 range + const green = Math.round(197 - (129 * (1 - ratio))); // 197-68 range + const blue = Math.round(94 - (26 * (1 - ratio))); // 94-68 range + + customStyle = { + backgroundColor: `rgb(${red}, ${green}, ${blue})`, + color: 'white', + fontWeight: 'bold' + }; + } + } + + return ( + + {trialStatus} + + ); + })()} + + + {(() => { + const regions = getTeamRegions(team); + if (regions.length === 0) { + return No regions; + } + return ( +
+ {regions.map((region) => ( + + {region} + + ))} +
+ ); + })()} +
+ + {(() => { + const totalSpend = getTeamTotalSpend(team); + const spendColor = getSpendColor(totalSpend); + return ( +
+ {formatCurrency(totalSpend)} +
+ ); + })()} +
+ +
+ )) + )} +
+
+
+
+ + +
+ ); +} diff --git a/frontend/src/app/admin/users/page.tsx b/frontend/src/app/admin/users/page.tsx index e40422c..237e68c 100644 --- a/frontend/src/app/admin/users/page.tsx +++ b/frontend/src/app/admin/users/page.tsx @@ -81,6 +81,7 @@ const USER_ROLES = [ { value: 'admin', label: 'Admin' }, { value: 'key_creator', label: 'Key Creator' }, { value: 'read_only', label: 'Read Only' }, + { value: 'sales', label: 'Sales' }, ]; type SortField = 'email' | 'team_name' | 'role' | null; @@ -406,6 +407,15 @@ export default function UsersPage() { void fetchUsers(); }, [fetchUsers]); + // Update role when switching between system and team user types + useEffect(() => { + if (isSystemUser) { + setNewUserRole('admin'); // Default to admin for system users + } else { + setNewUserRole('read_only'); // Default to read_only for team users + } + }, [isSystemUser]); + if (isLoadingUsers) { return (
@@ -452,6 +462,9 @@ export default function UsersPage() { if (newUserTeamId) { userData.team_id = newUserTeamId; } + } else { + // For system users, also set the role + userData.role = newUserRole; } createUserMutation.mutate(userData); @@ -545,7 +558,7 @@ export default function UsersPage() {
- + + + + + + Admin + Sales + + +
+ )} + )} + + + + + {filters.length > 0 && ( +
+ {filters.map((filter) => ( +
+ + + + + {getFilterInput(filter)} + + +
+ ))} +
+ )} + + {hasActiveFilters && ( +
+ Showing {filteredAndSortedTeams.length} of {teams.length} teams +
+ )} + + +
+
+ + + + handleSort('admin_email')} + > +
+ + Team Email + {getSortIcon('admin_email')} +
+
+ handleSort('name')} + > +
+ Team Name + {getSortIcon('name')} +
+
+ handleSort('created_at')} + > +
+ + Team Create Date + {getSortIcon('created_at')} +
+
+ handleSort('last_payment')} + > +
+ + Last Payment + {getSortIcon('last_payment')} +
+
+ handleSort('products')} + > +
+ + Products + {getSortIcon('products')} +
+
+ handleSort('trial_status')} + > +
+ + Trial Status + {getSortIcon('trial_status')} +
+
+ handleSort('regions')} + > +
+ + Regions + {getSortIcon('regions')} +
+
+ handleSort('total_spend')} + > +
+ + Total Spend + {getSortIcon('total_spend')} +
+
+ +
+
+ + {paginatedData.length === 0 ? ( + + + No teams found matching your filters. + + + ) : ( + paginatedData.map((team) => ( + + +
{team.admin_email}
+
+ +
{team.name}
+
+ ID: {team.id} +
+
+ + {formatDate(team.created_at)} + + + {formatDate(team.last_payment)} + + + {(() => { + const teamProducts = team.products || []; + return teamProducts.length > 0 ? ( +
+ {teamProducts.map((product) => ( + + {product.name} + + ))} +
+ ) : ( + No products + ); + })()} +
+ + {(() => { + const trialStatus = getTrialTimeRemaining(team); + let badgeVariant: "default" | "secondary" | "destructive" | "outline" = "outline"; + let customStyle = {}; + + if (trialStatus === 'Active Product') { + badgeVariant = "default"; + customStyle = { backgroundColor: '#166534', color: 'white' }; // dark green + } else if (trialStatus === 'Always Free') { + badgeVariant = "secondary"; + } else if (trialStatus === 'Expired') { + badgeVariant = "destructive"; + customStyle = { backgroundColor: '#991b1b', color: 'white' }; // dark red + } else if (trialStatus === 'In Progress') { + badgeVariant = "outline"; + customStyle = { backgroundColor: '#fef3c7', color: '#92400e' }; // amber/yellow + } else if (trialStatus.includes('days left')) { + // Extract the number of days + const daysMatch = trialStatus.match(/(\d+)/); + if (daysMatch) { + const days = parseInt(daysMatch[1], 10); + // Calculate color gradient: 30 days = green, 0 days = red + // Use a 30-day scale instead of 7-day for better gradient + const ratio = Math.max(0, Math.min(1, days / 30)); + + // Green to red gradient: green (34, 197, 94) to red (239, 68, 68) + const red = Math.round(34 + (205 * (1 - ratio))); // 34-239 range + const green = Math.round(197 - (129 * (1 - ratio))); // 197-68 range + const blue = Math.round(94 - (26 * (1 - ratio))); // 94-68 range + + customStyle = { + backgroundColor: `rgb(${red}, ${green}, ${blue})`, + color: 'white', + fontWeight: 'bold' + }; + } + } + + return ( + + {trialStatus} + + ); + })()} + + + {(() => { + const regions = getTeamRegions(team); + if (regions.length === 0) { + return No regions; + } + return ( +
+ {regions.map((region) => ( + + {region} + + ))} +
+ ); + })()} +
+ + {(() => { + const totalSpend = getTeamTotalSpend(team); + const spendColor = getSpendColor(totalSpend); + return ( +
+ {formatCurrency(totalSpend)} +
+ ); + })()} +
+ +
+ )) + )} +
+
+
+
+ + + + ); +} diff --git a/frontend/src/components/auth/login-form.tsx b/frontend/src/components/auth/login-form.tsx index 937ed05..3b1d77a 100644 --- a/frontend/src/components/auth/login-form.tsx +++ b/frontend/src/components/auth/login-form.tsx @@ -80,7 +80,12 @@ export function LoginForm() { }); router.refresh(); - router.push('/private-ai-keys'); + // Redirect based on user role + if (profileData.role === 'sales') { + router.push('/sales'); + } else { + router.push('/private-ai-keys'); + } } catch (profileError) { console.error('Failed to fetch user profile:', profileError); setError('Successfully logged in but failed to fetch user profile. Please refresh the page.'); diff --git a/frontend/src/components/auth/passwordless-login-form.tsx b/frontend/src/components/auth/passwordless-login-form.tsx index 38d1e7a..5b5c03b 100644 --- a/frontend/src/components/auth/passwordless-login-form.tsx +++ b/frontend/src/components/auth/passwordless-login-form.tsx @@ -173,7 +173,12 @@ export function PasswordlessLoginForm({ onSwitchToPassword }: PasswordlessLoginF }); router.refresh(); - router.push('/private-ai-keys'); + // Redirect based on user role + if (profileData.role === 'sales') { + router.push('/sales'); + } else { + router.push('/private-ai-keys'); + } } catch (profileError) { console.error('Failed to fetch user profile:', profileError); setError('Successfully signed in but failed to fetch user profile. Please refresh the page.'); diff --git a/frontend/src/components/sidebar-layout.tsx b/frontend/src/components/sidebar-layout.tsx index 7f6df87..220f2b0 100644 --- a/frontend/src/components/sidebar-layout.tsx +++ b/frontend/src/components/sidebar-layout.tsx @@ -11,7 +11,8 @@ import { PanelLeftClose, PanelLeft, Users2, - Package + Package, + DollarSign } from 'lucide-react'; import { Sidebar, SidebarProvider } from '@/components/ui/sidebar'; import { NavUser } from '@/components/nav-user'; @@ -49,6 +50,7 @@ const navigation = [ { name: 'Products', href: '/admin/products', icon: }, { name: 'Private AI Keys', href: '/admin/private-ai-keys', icon: }, { name: 'Audit Logs', href: '/admin/audit-logs', icon: }, + { name: 'Sales Dashboard', href: '/admin/sales-dashboard', icon: }, ], }, { @@ -63,6 +65,10 @@ const navigation = [ }, ]; +const salesNavigation = [ + { name: 'Sales Dashboard', href: '/sales', icon: }, +]; + function NavMain({ navigation, pathname, collapsed }: { navigation: NavItem[]; pathname: string; collapsed: boolean }) { const [expandedItems, setExpandedItems] = useState(['/admin', '/team-admin']); @@ -190,12 +196,14 @@ export function SidebarLayout({ return <>{children}; } - // Filter out admin navigation for non-admin users - const filteredNavigation = navigation.filter((item) => { - if (item.name === 'Admin' && !user?.is_admin) return false; - if (item.name === 'Team Admin' && !isTeamAdmin(user)) return false; - return true; - }); + // Use sales navigation for sales users, otherwise filter regular navigation + const filteredNavigation = user?.role === 'sales' + ? salesNavigation + : navigation.filter((item) => { + if (item.name === 'Admin' && !user?.is_admin) return false; + if (item.name === 'Team Admin' && !isTeamAdmin(user)) return false; + return true; + }); return (
@@ -204,7 +212,7 @@ export function SidebarLayout({
{!collapsed && ( - +
- + {/* Main content */} -
+
{children}
diff --git a/scripts/add_test_data.py b/scripts/add_test_data.py new file mode 100644 index 0000000..b4041e7 --- /dev/null +++ b/scripts/add_test_data.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 + +import os +import sys +from datetime import datetime, timedelta, UTC +import random + +# Add the parent directory to the Python path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from sqlalchemy.orm import sessionmaker, Session +from app.db.database import engine +from app.db.models import DBTeam, DBUser, DBProduct, DBTeamProduct +from app.core.security import get_password_hash + +def create_test_data(): + """Create test data for teams, users, and products""" + + # Create database session + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + db = SessionLocal() + + try: + print("Creating test data...") + + # Check for existing test data on a case-by-case basis + existing_teams = {} + for team_name in [ + "Test Team 1", + "Test Team 2 - Always Free", + "Test Team 3 - With Product", + "Test Team 4 - With Payment History", + "Test Team 5 - No Products" + ]: + existing_team = db.query(DBTeam).filter(DBTeam.name == team_name).first() + if existing_team: + existing_teams[team_name] = existing_team + print(f"⚠️ {team_name} already exists (ID: {existing_team.id})") + else: + print(f"✅ {team_name} will be created") + + # Check if products exist, create one if none exist + existing_products = db.query(DBProduct).all() + if not existing_products: + print("No products found, creating a sample product...") + sample_product = DBProduct( + id="prod_test_sample", + name="Test Sample Product", + user_count=10, + keys_per_user=5, + total_key_count=50, + service_key_count=10, + max_budget_per_key=100.0, + rpm_per_key=1000, + vector_db_count=5, + vector_db_storage=500, + renewal_period_days=30, + active=True, + created_at=datetime.now(UTC) + ) + db.add(sample_product) + db.commit() + db.refresh(sample_product) + print(f"Created sample product: {sample_product.name} (ID: {sample_product.id})") + existing_products = [sample_product] + + # Pick a random product for teams that need products + selected_product = random.choice(existing_products) + print(f"Selected product for teams: {selected_product.name} (ID: {selected_product.id})") + + # 1. Team with one user, created 32 days ago + if "Test Team 1" not in existing_teams: + print("\n1. Creating team with one user (created 32 days ago)...") + team1 = DBTeam( + name="Test Team 1", + admin_email="admin1@test.com", + is_active=True, + is_always_free=False, + created_at=datetime.now(UTC) - timedelta(days=32) + ) + db.add(team1) + db.commit() + db.refresh(team1) + + user1 = DBUser( + email="user1@test.com", + hashed_password=get_password_hash("testpassword123"), + is_active=True, + is_admin=False, + role="user", + team_id=team1.id, + created_at=datetime.now(UTC) - timedelta(days=32) + ) + db.add(user1) + db.commit() + print(f" Created team: {team1.name} (ID: {team1.id})") + print(f" Created user: {user1.email} (ID: {user1.id})") + else: + team1 = existing_teams["Test Team 1"] + print(f"\n1. Team 1 already exists: {team1.name} (ID: {team1.id})") + + # 2. Team with one user, always_free=True, created 20 days ago + if "Test Team 2 - Always Free" not in existing_teams: + print("\n2. Creating team with one user, always_free=True (created 20 days ago)...") + team2 = DBTeam( + name="Test Team 2 - Always Free", + admin_email="admin2@test.com", + is_active=True, + is_always_free=True, + created_at=datetime.now(UTC) - timedelta(days=20) + ) + db.add(team2) + db.commit() + db.refresh(team2) + + user2 = DBUser( + email="user2@test.com", + hashed_password=get_password_hash("testpassword123"), + is_active=True, + is_admin=False, + role="user", + team_id=team2.id, + created_at=datetime.now(UTC) - timedelta(days=20) + ) + db.add(user2) + db.commit() + print(f" Created team: {team2.name} (ID: {team2.id}) - always_free: {team2.is_always_free}") + print(f" Created user: {user2.email} (ID: {user2.id})") + else: + team2 = existing_teams["Test Team 2 - Always Free"] + print(f"\n2. Team 2 already exists: {team2.name} (ID: {team2.id})") + + # 3. Team with one user and product association + if "Test Team 3 - With Product" not in existing_teams: + print("\n3. Creating team with one user and product association...") + team3 = DBTeam( + name="Test Team 3 - With Product", + admin_email="admin3@test.com", + is_active=True, + is_always_free=False, + created_at=datetime.now(UTC) + ) + db.add(team3) + db.commit() + db.refresh(team3) + + user3 = DBUser( + email="user3@test.com", + hashed_password=get_password_hash("testpassword123"), + is_active=True, + is_admin=False, + role="user", + team_id=team3.id, + created_at=datetime.now(UTC) + ) + db.add(user3) + db.commit() + + # Create team-product association + team_product = DBTeamProduct( + team_id=team3.id, + product_id=selected_product.id + ) + db.add(team_product) + db.commit() + + print(f" Created team: {team3.name} (ID: {team3.id})") + print(f" Created user: {user3.email} (ID: {user3.id})") + print(f" Associated with product: {selected_product.name} (ID: {selected_product.id})") + else: + team3 = existing_teams["Test Team 3 - With Product"] + print(f"\n3. Team 3 already exists: {team3.name} (ID: {team3.id})") + + # 4. Team with one user, created 40 days ago, with payment 35 days ago, and product association + if "Test Team 4 - With Payment History" not in existing_teams: + print("\n4. Creating team with one user, payment history, and product association...") + team4 = DBTeam( + name="Test Team 4 - With Payment History", + admin_email="admin4@test.com", + is_active=True, + is_always_free=False, + created_at=datetime.now(UTC) - timedelta(days=40), + last_payment=datetime.now(UTC) - timedelta(days=35) + ) + db.add(team4) + db.commit() + db.refresh(team4) + + user4 = DBUser( + email="user4@test.com", + hashed_password=get_password_hash("testpassword123"), + is_active=True, + is_admin=False, + role="user", + team_id=team4.id, + created_at=datetime.now(UTC) - timedelta(days=40) + ) + db.add(user4) + db.commit() + + # Create team-product association + team_product4 = DBTeamProduct( + team_id=team4.id, + product_id=selected_product.id + ) + db.add(team_product4) + db.commit() + + print(f" Created team: {team4.name} (ID: {team4.id})") + print(f" Created user: {user4.email} (ID: {user4.id})") + print(f" Payment made: {team4.last_payment.strftime('%Y-%m-%d')}") + print(f" Associated with product: {selected_product.name} (ID: {selected_product.id})") + else: + team4 = existing_teams["Test Team 4 - With Payment History"] + print(f"\n4. Team 4 already exists: {team4.name} (ID: {team4.id})") + + # 5. Team with one user, no products, created 20 days ago + if "Test Team 5 - No Products" not in existing_teams: + print("\n5. Creating team with one user, no products (created 20 days ago)...") + team5 = DBTeam( + name="Test Team 5 - No Products", + admin_email="admin5@test.com", + is_active=True, + is_always_free=False, + created_at=datetime.now(UTC) - timedelta(days=20) + ) + db.add(team5) + db.commit() + db.refresh(team5) + + user5 = DBUser( + email="user5@test.com", + hashed_password=get_password_hash("testpassword123"), + is_active=True, + is_admin=False, + role="user", + team_id=team5.id, + created_at=datetime.now(UTC) - timedelta(days=20) + ) + db.add(user5) + db.commit() + + print(f" Created team: {team5.name} (ID: {team5.id})") + print(f" Created user: {user5.email} (ID: {user5.id})") + print(f" No products associated") + else: + team5 = existing_teams["Test Team 5 - No Products"] + print(f"\n5. Team 5 already exists: {team5.name} (ID: {team5.id})") + + print("\n✅ Test data created successfully!") + print(f"\nSummary:") + print(f"- Team 1: {team1.name} (created 32 days ago)") + print(f"- Team 2: {team2.name} (always_free=True, created 20 days ago)") + print(f"- Team 3: {team3.name} (with product association)") + print(f"- Team 4: {team4.name} (payment history, product association, created 40 days ago)") + print(f"- Team 5: {team5.name} (no products, created 20 days ago)") + print(f"- Total users created: 5") + print(f"- Product used: {selected_product.name}") + + except Exception as e: + print(f"❌ Error creating test data: {str(e)}") + db.rollback() + raise + finally: + db.close() + +def main(): + """Main function to run the script""" + try: + create_test_data() + except Exception as e: + print(f"Script failed: {str(e)}") + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index 9d8916a..40cc50f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -143,7 +143,7 @@ def test_team_user(db, test_team): hashed_password=get_password_hash("password123"), is_active=True, is_admin=False, - role="user", + role="key_creator", team_id=test_team.id, created_at=datetime.now(UTC) ) diff --git a/tests/test_rbac.py b/tests/test_rbac.py new file mode 100644 index 0000000..3287248 --- /dev/null +++ b/tests/test_rbac.py @@ -0,0 +1,235 @@ +import pytest +from fastapi import HTTPException +from app.core.rbac import RBACDependency, require_system_admin, require_team_admin, require_key_creator_or_higher, require_read_only_or_higher, require_sales_or_higher +from app.core.roles import UserRole +from app.db.models import DBUser + +class TestRBACDependency: + """Test RBAC dependency functionality""" + + def test_system_admin_access(self): + """ + Given a system admin user + When checking access to system admin endpoint + Then access should be granted and return system_admin role + """ + user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None) + dependency = require_system_admin() + + result = dependency.check_access(user) + assert result == UserRole.SYSTEM_ADMIN + + def test_team_admin_access(self): + """ + Given a team admin user + When checking access to team admin endpoint + Then access should be granted and return admin role + """ + user = DBUser(id=1, email="teamadmin@test.com", is_admin=False, team_id=1, role="admin") + dependency = require_team_admin() + + result = dependency.check_access(user) + assert result == UserRole.TEAM_ADMIN + + def test_key_creator_access(self): + """ + Given a key creator user + When checking access to key creator endpoint + Then access should be granted and return key_creator role + """ + user = DBUser(id=1, email="keycreator@test.com", is_admin=False, team_id=1, role="key_creator") + dependency = require_key_creator_or_higher() + + result = dependency.check_access(user) + assert result == UserRole.KEY_CREATOR + + def test_read_only_access(self): + """ + Given a read only user + When checking access to read only endpoint + Then access should be granted and return read_only role + """ + user = DBUser(id=1, email="readonly@test.com", is_admin=False, team_id=1, role="read_only") + dependency = require_read_only_or_higher() + + result = dependency.check_access(user) + assert result == UserRole.READ_ONLY + + def test_system_user_cannot_be_team_member(self): + """ + Given a system admin user with team_id set + When checking access to any endpoint + Then access should be denied due to invalid user type + """ + user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=1, role=None) + dependency = require_system_admin() + + with pytest.raises(HTTPException) as exc_info: + dependency.check_access(user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + + def test_team_user_must_be_team_member(self): + """ + Given a team user without team_id + When checking access to team endpoint + Then access should be denied due to invalid user type + """ + user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="admin") + dependency = require_team_admin() + + with pytest.raises(HTTPException) as exc_info: + dependency.check_access(user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + + def test_insufficient_permissions(self): + """ + Given a read only user + When checking access to team admin endpoint + Then access should be denied due to insufficient permissions + """ + user = DBUser(id=1, email="readonly@test.com", is_admin=False, team_id=1, role="read_only") + dependency = require_team_admin() + + with pytest.raises(HTTPException) as exc_info: + dependency.check_access(user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + + def test_team_membership_required(self): + """ + Given a system admin + When checking access to team endpoint + Then access should be granted because system admins can do anything + """ + user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None) + dependency = require_team_admin() + + result = dependency.check_access(user) + assert result == UserRole.SYSTEM_ADMIN + + def test_sales_user_access(self): + """ + Given a sales user + When checking access to sales endpoint + Then access should be granted and return sales role + """ + user = DBUser(id=1, email="sales@test.com", is_admin=False, team_id=None, role="sales") + dependency = require_sales_or_higher() + + result = dependency.check_access(user) + assert result == UserRole.SALES + + def test_system_admin_access_sales_endpoint(self): + """ + Given a system admin user + When checking access to sales endpoint + Then access should be granted and return system_admin role + """ + user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None) + dependency = require_sales_or_higher() + + result = dependency.check_access(user) + assert result == UserRole.SYSTEM_ADMIN + + def test_regular_user_denied_sales_access(self): + """ + Given a regular system user + When checking access to sales endpoint + Then access should be denied due to insufficient permissions + """ + user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="user") + dependency = require_sales_or_higher() + + with pytest.raises(HTTPException) as exc_info: + dependency.check_access(user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + + def test_team_user_denied_sales_access(self): + """ + Given a team user + When checking access to sales endpoint + Then access should be denied due to invalid user type + """ + user = DBUser(id=1, email="teamuser@test.com", is_admin=False, team_id=1, role="admin") + dependency = require_sales_or_higher() + + with pytest.raises(HTTPException) as exc_info: + dependency.check_access(user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + +class TestUserRole: + """Test UserRole class functionality""" + + def test_system_roles(self): + """ + Given various roles + When checking if they are system roles + Then only system roles should return True + """ + assert UserRole.is_system_role("system_admin") + assert UserRole.is_system_role("user") + assert UserRole.is_system_role("sales") + assert not UserRole.is_system_role("admin") + assert not UserRole.is_system_role("key_creator") + assert not UserRole.is_system_role("read_only") + + def test_team_roles(self): + """ + Given various roles + When checking if they are team roles + Then only team roles should return True + """ + assert UserRole.is_team_role("admin") + assert UserRole.is_team_role("key_creator") + assert UserRole.is_team_role("read_only") + assert not UserRole.is_team_role("system_admin") + assert not UserRole.is_team_role("user") + assert not UserRole.is_team_role("sales") + + def test_get_system_roles(self): + """ + Given the UserRole class + When getting system roles + Then it should return all valid system roles + """ + roles = UserRole.get_system_roles() + assert "system_admin" in roles + assert "user" in roles + assert "sales" in roles + assert len(roles) == 3 + + def test_get_team_roles(self): + """ + Given the UserRole class + When getting team roles + Then it should return all valid team roles + """ + roles = UserRole.get_team_roles() + assert "admin" in roles + assert "key_creator" in roles + assert "read_only" in roles + assert len(roles) == 3 + + def test_get_all_roles(self): + """ + Given the UserRole class + When getting all roles + Then it should return both system and team roles + """ + roles = UserRole.get_all_roles() + assert len(roles) == 6 + assert "system_admin" in roles + assert "user" in roles + assert "sales" in roles + assert "admin" in roles + assert "key_creator" in roles + assert "read_only" in roles diff --git a/tests/test_sales_api.py b/tests/test_sales_api.py new file mode 100644 index 0000000..ed29291 --- /dev/null +++ b/tests/test_sales_api.py @@ -0,0 +1,537 @@ +""" +Tests for the consolidated sales API endpoint. +""" +import pytest +from datetime import datetime, UTC, timedelta +from unittest.mock import patch, AsyncMock +from sqlalchemy.orm import Session +from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBRegion, DBUser + + +@pytest.fixture +def test_ai_key(db: Session, test_team: DBTeam, test_region: DBRegion) -> DBPrivateAIKey: + """Create a test AI key.""" + ai_key = DBPrivateAIKey( + database_name="test-db", + name="Test Key", + database_host="test-host", + database_username="test-user", + database_password="test-pass", + litellm_token="test-token", + litellm_api_url="https://test-litellm.com", + owner_id=None, + team_id=test_team.id, + region_id=test_region.id + ) + db.add(ai_key) + db.commit() + db.refresh(ai_key) + return ai_key + + +@pytest.fixture +def test_always_free_team(db: Session) -> DBTeam: + """Create a test always-free team.""" + team = DBTeam( + name="Always Free Team", + admin_email="free@test.com", + is_active=True, + is_always_free=True, + created_at=datetime.now(UTC) - timedelta(days=45), # 45 days ago + last_payment=None + ) + db.add(team) + db.commit() + db.refresh(team) + return team + + +@pytest.fixture +def test_paid_team(db: Session) -> DBTeam: + """Create a test team with payment history.""" + team = DBTeam( + name="Paid Team", + admin_email="paid@test.com", + is_active=True, + is_always_free=False, + created_at=datetime.now(UTC) - timedelta(days=60), # 60 days ago + last_payment=datetime.now(UTC) - timedelta(days=5) # 5 days ago + ) + db.add(team) + db.commit() + db.refresh(team) + return team + + +@pytest.fixture +def mock_litellm_response(): + """Mock LiteLLM API response.""" + return { + "info": { + "spend": 25.50, + "expires": "2024-12-31T23:59:59Z", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "max_budget": 100.0, + "budget_duration": "monthly", + "budget_reset_at": "2024-02-01T00:00:00Z" + } + } + + +@pytest.fixture +def test_user_owned_ai_key(db: Session, test_team_user: DBUser, test_region: DBRegion) -> DBPrivateAIKey: + """Create a test AI key owned by a team member.""" + ai_key = DBPrivateAIKey( + database_name="user-db", + name="User Key", + database_host="user-host", + database_username="user-user", + database_password="user-pass", + litellm_token="user-token", + litellm_api_url="https://user-litellm.com", + owner_id=test_team_user.id, + team_id=None, # Not team-owned, user-owned + region_id=test_region.id + ) + db.add(ai_key) + db.commit() + db.refresh(ai_key) + return ai_key + + +def test_list_teams_for_sales_requires_admin(client, test_team): + """Test that only system admins can access the sales endpoint.""" + response = client.get("/teams/sales/list-teams") + assert response.status_code == 401 # Unauthorized + + +def test_list_teams_for_sales_success(client, admin_token, test_team, test_product, + test_region, test_ai_key, mock_litellm_response, db): + """Test successful retrieval of sales data.""" + # Create team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = mock_litellm_response + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert "teams" in data + assert len(data["teams"]) == 1 + + team_data = data["teams"][0] + assert team_data["id"] == test_team.id + assert team_data["name"] == test_team.name + assert team_data["admin_email"] == test_team.admin_email + assert team_data["is_always_free"] == False + assert len(team_data["products"]) == 1 + assert team_data["products"][0]["id"] == test_product.id + assert team_data["products"][0]["name"] == test_product.name + assert team_data["products"][0]["active"] == True + assert len(team_data["regions"]) == 1 + assert team_data["regions"][0] == test_region.name + assert team_data["total_spend"] == 25.50 + assert team_data["trial_status"] == "Active Product" + + +def test_always_free_team_trial_status(client, admin_token, test_always_free_team, + test_product, test_region, test_ai_key, db): + """Test that always-free teams show correct trial status.""" + # Create team-product association + team_product = DBTeamProduct( + team_id=test_always_free_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + assert team_data["trial_status"] == "Always Free" + + +def test_paid_team_trial_status(client, admin_token, test_paid_team, + test_product, test_region, test_ai_key, db): + """Test that teams with payment history show correct trial status.""" + # Create team-product association + team_product = DBTeamProduct( + team_id=test_paid_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert "teams" in data + team_data = data["teams"][0] + assert team_data["trial_status"] == "Active Product" + + +def test_team_without_products(client, admin_token, test_team, test_region, test_ai_key, db): + """Test team without any products shows trial status based on creation date.""" + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert "teams" in data + team_data = data["teams"][0] + # Team without products should show trial status based on creation date + assert team_data["trial_status"] == "30 days left" + assert len(team_data["products"]) == 0 + + +def test_team_with_multiple_ai_keys(client, admin_token, test_team, test_product, + test_region, test_ai_key, db): + """Test team with multiple AI keys aggregates spend correctly.""" + # Create second AI key + ai_key2 = DBPrivateAIKey( + database_name="test-db-2", + name="Test Key 2", + database_host="test-host", + database_username="test-user", + database_password="test-pass", + litellm_token="test-token-2", + litellm_api_url="https://test-litellm.com", + owner_id=None, + team_id=test_team.id, + region_id=test_region.id + ) + db.add(ai_key2) + db.commit() + + # Create team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Mock LiteLLM service with different spend values + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.side_effect = [ + {"info": {"spend": 25.50}}, + {"info": {"spend": 15.25}} + ] + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + assert team_data["total_spend"] == 40.75 # 25.50 + 15.25 + + +def test_litellm_service_error_handling(client, admin_token, test_team, test_product, + test_region, test_ai_key, db): + """Test that LiteLLM service errors don't break the entire response.""" + # Create team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Mock LiteLLM service to raise an exception + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.side_effect = Exception("LiteLLM service error") + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Should still succeed but with 0 spend + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + assert team_data["total_spend"] == 0.0 + + +def test_expired_trial_status(client, admin_token, test_team, test_region, test_ai_key, db): + """Test that expired trials show correct status.""" + # Update team to be older than 30 days (no products, so should show expired) + test_team.created_at = datetime.now(UTC) - timedelta(days=35) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + assert team_data["trial_status"] == "Expired" + + +def test_list_teams_for_sales_includes_user_owned_keys(client, admin_token, test_team, test_product, + test_region, test_ai_key, test_team_user, + test_user_owned_ai_key, mock_litellm_response, db): + """Test that sales data includes both team-owned and user-owned AI keys.""" + # Create team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + db.commit() + + # Mock LiteLLM service to return different spend for each key + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + # Return different spend values for team key vs user key + mock_get_info.side_effect = [ + {"info": {"spend": 25.50}}, # Team key spend + {"info": {"spend": 15.25}} # User key spend + ] + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert "teams" in data + assert len(data["teams"]) == 1 + + team_data = data["teams"][0] + assert team_data["id"] == test_team.id + + # Should include both keys in total spend (25.50 + 15.25 = 40.75) + assert team_data["total_spend"] == 40.75 + + # Should include region from both keys + assert len(team_data["regions"]) == 1 + assert team_data["regions"][0] == test_region.name + + +def test_team_with_15_days_remaining(client, admin_token, test_team, test_region, test_ai_key, db): + """Test that teams with 15 days remaining show correct format.""" + # Update team to be 15 days old + test_team.created_at = datetime.now(UTC) - timedelta(days=15) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + # Should show exact days remaining, not "In Progress" + assert team_data["trial_status"] == "15 days left" + + +def test_team_with_7_days_remaining(client, admin_token, test_team, test_region, test_ai_key, db): + """Test that teams with 7 days remaining show correct format.""" + # Update team to be 23 days old (7 days remaining) + test_team.created_at = datetime.now(UTC) - timedelta(days=23) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + # Should show exact days remaining + assert team_data["trial_status"] == "7 days left" + + +def test_team_with_last_payment_days_calculation(client, admin_token, test_team, test_region, test_ai_key, db): + """Test that teams with last_payment calculate days remaining from payment date.""" + # Set last_payment to 10 days ago (so 20 days remaining from payment) + test_team.last_payment = datetime.now(UTC) - timedelta(days=10) + db.commit() + + # Mock LiteLLM service + with patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) as mock_get_info: + mock_get_info.return_value = {"info": {"spend": 0.0}} + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + # Should calculate from last_payment date (30 - 10 = 20 days remaining) + assert team_data["trial_status"] == "20 days left" + + +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_list_teams_for_sales_unreachable_endpoints_logging(mock_get_info, client, admin_token, test_team, test_product, + test_region, test_ai_key, db): + """ + Given a team with AI keys in a region where LiteLLM endpoint is unreachable + When calling list_teams_for_sales + Then the function should default spend to 0 and still return successful response + """ + # Create team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + + # Create a second AI key in the same region to test multiple keys + ai_key2 = DBPrivateAIKey( + database_name="test-db-2", + name="Test Key 2", + database_host="test-host", + database_username="test-user", + database_password="test-pass", + litellm_token="test-token-2", + litellm_api_url="https://test-litellm.com", + owner_id=None, + team_id=test_team.id, + region_id=test_region.id + ) + db.add(ai_key2) + db.commit() + + # Mock LiteLLM service to raise an exception (simulating unreachable endpoint) + mock_get_info.side_effect = Exception("Connection timeout") + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Should still succeed despite LiteLLM errors + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + + # Spend should default to 0 for failed litellm calls + assert team_data["total_spend"] == 0.0 + + +@patch('app.services.litellm.LiteLLMService.get_key_info', new_callable=AsyncMock) +def test_list_teams_for_sales_multiple_unreachable_regions(mock_get_info, client, admin_token, test_team, test_product, + test_region, test_ai_key, db): + """ + Given a team with AI keys in multiple regions where some endpoints are unreachable + When calling list_teams_for_sales + Then the function should default spend to 0 and include all regions in response + """ + # Create a second region + region2 = DBRegion( + name="Test Region 2", + postgres_host="test-host-2", + postgres_port=5432, + postgres_admin_user="test-user-2", + postgres_admin_password="test-pass-2", + litellm_api_url="https://test-litellm-2.com", + litellm_api_key="test-key-2", + is_active=True + ) + db.add(region2) + db.commit() + + # Create team-product association + team_product = DBTeamProduct( + team_id=test_team.id, + product_id=test_product.id + ) + db.add(team_product) + + # Create AI key in second region + ai_key2 = DBPrivateAIKey( + database_name="test-db-2", + name="Test Key 2", + database_host="test-host-2", + database_username="test-user-2", + database_password="test-pass-2", + litellm_token="test-token-2", + litellm_api_url="https://test-litellm-2.com", + owner_id=None, + team_id=test_team.id, + region_id=region2.id + ) + db.add(ai_key2) + db.commit() + + # Mock LiteLLM service to raise different exceptions for different regions + mock_get_info.side_effect = [ + Exception("Connection timeout"), # First region fails + Exception("Service unavailable") # Second region fails + ] + + response = client.get( + "/teams/sales/list-teams", + headers={"Authorization": f"Bearer {admin_token}"} + ) + + # Should still succeed despite LiteLLM errors + assert response.status_code == 200 + data = response.json() + team_data = data["teams"][0] + + # Spend should default to 0 for failed litellm calls + assert team_data["total_spend"] == 0.0 + + # Should include both regions in the response + assert len(team_data["regions"]) == 2 + assert test_region.name in team_data["regions"] + assert region2.name in team_data["regions"] diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..21178df --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,91 @@ +import pytest +from unittest.mock import AsyncMock, patch +from fastapi import HTTPException +from app.core.security import check_sales_or_higher, get_role_min_system_admin +from app.core.roles import UserRole +from app.db.models import DBUser + + +class TestSecurityFunctions: + """Test security function functionality""" + + @pytest.mark.asyncio + async def test_check_sales_or_higher_sales_user(self): + """ + Given a sales user + When calling check_sales_or_higher + Then it should return sales role + """ + user = DBUser(id=1, email="sales@test.com", is_admin=False, team_id=None, role="sales") + + result = await check_sales_or_higher(current_user=user) + assert result == UserRole.SALES + + @pytest.mark.asyncio + async def test_check_sales_or_higher_system_admin(self): + """ + Given a system admin user + When calling check_sales_or_higher + Then it should return system_admin role + """ + user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None) + + result = await check_sales_or_higher(current_user=user) + assert result == UserRole.SYSTEM_ADMIN + + @pytest.mark.asyncio + async def test_check_sales_or_higher_regular_user_denied(self): + """ + Given a regular system user + When calling check_sales_or_higher + Then it should raise HTTPException with 403 status + """ + user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="user") + + with pytest.raises(HTTPException) as exc_info: + await check_sales_or_higher(current_user=user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_check_sales_or_higher_team_user_denied(self): + """ + Given a team user + When calling check_sales_or_higher + Then it should raise HTTPException with 403 status + """ + user = DBUser(id=1, email="teamuser@test.com", is_admin=False, team_id=1, role="admin") + + with pytest.raises(HTTPException) as exc_info: + await check_sales_or_higher(current_user=user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail) + + @pytest.mark.asyncio + async def test_check_system_admin_system_admin(self): + """ + Given a system admin user + When calling check_system_admin + Then it should return system_admin role + """ + user = DBUser(id=1, email="admin@test.com", is_admin=True, team_id=None, role=None) + + result = await get_role_min_system_admin(current_user=user) + assert result == UserRole.SYSTEM_ADMIN + + @pytest.mark.asyncio + async def test_check_system_admin_regular_user_denied(self): + """ + Given a regular system user + When calling check_system_admin + Then it should raise HTTPException with 403 status + """ + user = DBUser(id=1, email="user@test.com", is_admin=False, team_id=None, role="user") + + with pytest.raises(HTTPException) as exc_info: + await get_role_min_system_admin(current_user=user) + + assert exc_info.value.status_code == 403 + assert "Not authorized to perform this action" in str(exc_info.value.detail)