Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions app/api/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from datetime import datetime, UTC
from app.db.database import get_db
from app.core.security import check_specific_team_admin, check_system_admin
from app.core.security import get_role_min_specific_team_admin, get_role_min_system_admin
from app.db.models import DBTeam, DBSystemSecret, DBProduct, DBTeamProduct
from app.schemas.models import PricingTableSession, SubscriptionCreate, SubscriptionResponse
from app.services.stripe import (
Expand Down Expand Up @@ -91,7 +91,7 @@ async def handle_events(
detail="Error processing webhook"
)

@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)])
@router.post("/teams/{team_id}/portal", dependencies=[Depends(get_role_min_specific_team_admin)])
async def get_portal(
team_id: int,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -135,7 +135,7 @@ async def get_portal(
detail="Error creating portal session"
)

@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession)
@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(get_role_min_specific_team_admin)], response_model=PricingTableSession)
async def get_pricing_table_session(
team_id: int,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -178,7 +178,7 @@ async def get_pricing_table_session(
detail="Error creating customer session"
)

@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(check_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED)
@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(get_role_min_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED)
async def create_team_subscription(
team_id: int,
subscription_data: SubscriptionCreate,
Expand Down
12 changes: 6 additions & 6 deletions app/api/pricing_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from app.db.database import get_db
from app.db.models import DBTeam, DBPricingTable
from app.core.security import check_system_admin, get_role_min_team_admin, get_current_user_from_auth
from app.core.security import get_role_min_system_admin, get_role_min_team_admin, get_current_user_from_auth
from app.schemas.models import PricingTableCreate, PricingTableResponse, PricingTablesResponse
from app.core.config import settings

Expand All @@ -16,8 +16,8 @@
tags=["pricing-tables"]
)

@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
async def create_pricing_table(
pricing_table: PricingTableCreate,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -113,8 +113,8 @@ async def get_pricing_table(
updated_at=pricing_table.updated_at or pricing_table.created_at
)

@router.delete("", dependencies=[Depends(check_system_admin)])
@router.delete("/", dependencies=[Depends(check_system_admin)])
@router.delete("", dependencies=[Depends(get_role_min_system_admin)])
@router.delete("/", dependencies=[Depends(get_role_min_system_admin)])
async def delete_pricing_table(
table_type: str,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -146,7 +146,7 @@ async def delete_pricing_table(

return {"message": f"Pricing table of type '{table_type}' deleted successfully"}

@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(check_system_admin)])
@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(get_role_min_system_admin)])
async def get_all_pricing_tables(
db: Session = Depends(get_db)
):
Expand Down
27 changes: 20 additions & 7 deletions app/api/private_ai_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,22 @@
from app.db.postgres import PostgresManager
from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam
from app.services.litellm import LiteLLMService
from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, UserRole, check_system_admin
from app.core.security import (
get_current_user_from_auth,
get_role_min_team_admin,
get_private_ai_access,
get_role_min_system_admin
)
from app.core.roles import UserRole
from app.core.config import settings
from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
from app.core.resource_limits import (
check_key_limits,
check_vector_db_limits,
get_token_restrictions,
DEFAULT_KEY_DURATION,
DEFAULT_MAX_SPEND,
DEFAULT_RPM_PER_KEY
)

router = APIRouter(
tags=["private-ai-keys"]
Expand Down Expand Up @@ -68,7 +81,7 @@ def _validate_permissions_and_get_ownership_info(
async def create_vector_db(
vector_db: VectorDBCreate,
current_user = Depends(get_current_user_from_auth),
user_role: UserRole = Depends(get_role_min_key_creator),
user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db),
store_result: bool = True
):
Expand Down Expand Up @@ -164,7 +177,7 @@ async def create_vector_db(
async def create_private_ai_key(
private_ai_key: PrivateAIKeyCreate,
current_user = Depends(get_current_user_from_auth),
user_role: UserRole = Depends(get_role_min_key_creator),
user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db)
):
"""
Expand Down Expand Up @@ -279,7 +292,7 @@ async def create_private_ai_key(
async def create_llm_token(
private_ai_key: PrivateAIKeyCreate,
current_user = Depends(get_current_user_from_auth),
user_role: UserRole = Depends(get_role_min_key_creator),
user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db),
store_result: bool = True
):
Expand Down Expand Up @@ -464,7 +477,7 @@ async def list_private_ai_keys(
private_ai_keys = query.all()
return [key.to_dict() for key in private_ai_keys]

@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)])
@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(get_role_min_system_admin)])
async def get_private_ai_key(
key_id: int,
current_user = Depends(get_current_user_from_auth),
Expand Down Expand Up @@ -571,7 +584,7 @@ def _get_key_if_allowed(key_id: int, current_user: DBUser, user_role: UserRole,
async def delete_private_ai_key(
key_id: int,
current_user = Depends(get_current_user_from_auth),
user_role: UserRole = Depends(get_role_min_key_creator),
user_role: UserRole = Depends(get_private_ai_access),
db: Session = Depends(get_db)
):
private_ai_key = _get_key_if_allowed(key_id, current_user, user_role, db)
Expand Down
10 changes: 5 additions & 5 deletions app/api/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from app.db.database import get_db
from app.db.models import DBProduct, DBTeamProduct, DBTeam
from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin
from app.core.security import get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin
from app.schemas.models import Product, ProductCreate, ProductUpdate

router = APIRouter(
tags=["products"]
)

@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
async def create_product(
product: ProductCreate,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -105,7 +105,7 @@ async def get_product(
)
return product

@router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)])
@router.put("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_system_admin)])
async def update_product(
product_id: str,
product_update: ProductUpdate,
Expand All @@ -132,7 +132,7 @@ async def update_product(

return product

@router.delete("/{product_id}", dependencies=[Depends(check_system_admin)])
@router.delete("/{product_id}", dependencies=[Depends(get_role_min_system_admin)])
async def delete_product(
product_id: str,
db: Session = Depends(get_db)
Expand Down
20 changes: 10 additions & 10 deletions app/api/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.api.auth import get_current_user_from_auth
from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate, TeamSummary
from app.db.models import DBRegion, DBPrivateAIKey, DBTeamRegion, DBTeam
from app.core.security import check_system_admin
from app.core.security import get_role_min_system_admin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,8 +89,8 @@ async def validate_database_connection(host: str, port: int, user: str, password
detail=f"Database connection validation failed: {str(e)}"
)

@router.post("", response_model=Region, dependencies=[Depends(check_system_admin)])
@router.post("/", response_model=Region, dependencies=[Depends(check_system_admin)])
@router.post("", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
@router.post("/", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
async def create_region(
region: RegionCreate,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -158,13 +158,13 @@ async def list_regions(

return non_dedicated_regions + team_dedicated_regions

@router.get("/admin", response_model=List[Region], dependencies=[Depends(check_system_admin)])
@router.get("/admin", response_model=List[Region], dependencies=[Depends(get_role_min_system_admin)])
async def list_admin_regions(
db: Session = Depends(get_db)
):
return db.query(DBRegion).all()

@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(check_system_admin)])
@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(get_role_min_system_admin)])
async def get_region(
region_id: int,
db: Session = Depends(get_db)
Expand All @@ -177,7 +177,7 @@ async def get_region(
)
return region

@router.delete("/{region_id}", dependencies=[Depends(check_system_admin)])
@router.delete("/{region_id}", dependencies=[Depends(get_role_min_system_admin)])
async def delete_region(
region_id: int,
db: Session = Depends(get_db)
Expand Down Expand Up @@ -209,7 +209,7 @@ async def delete_region(
)
return {"message": "Region deleted successfully"}

@router.put("/{region_id}", response_model=Region, dependencies=[Depends(check_system_admin)])
@router.put("/{region_id}", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
async def update_region(
region_id: int,
region: RegionUpdate,
Expand Down Expand Up @@ -251,7 +251,7 @@ async def update_region(
)
return db_region

@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
async def associate_team_with_region(
region_id: int,
team_id: int,
Expand Down Expand Up @@ -311,7 +311,7 @@ async def associate_team_with_region(

return {"message": "Team associated with region successfully"}

@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
async def disassociate_team_from_region(
region_id: int,
team_id: int,
Expand Down Expand Up @@ -345,7 +345,7 @@ async def disassociate_team_from_region(

return {"message": "Team disassociated from region successfully"}

@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(check_system_admin)])
@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(get_role_min_system_admin)])
async def list_teams_for_region(
region_id: int,
db: Session = Depends(get_db)
Expand Down
Loading