Skip to content

Commit c204d41

Browse files
committed
Refactor Security Role Checks
- Replaced direct role checks with new role management functions across multiple API endpoints to enhance consistency and maintainability. - Updated dependencies in billing.py, pricing_tables.py, private_ai_keys.py, products.py, regions.py, teams.py, users.py to utilize get_role_min_* functions. - Streamlined role validation logic in rbac.py and roles.py for improved clarity and organization. - Removed deprecated role assignment validation methods from tests, ensuring alignment with the new role management structure.
1 parent 34ddeb8 commit c204d41

File tree

12 files changed

+58
-151
lines changed

12 files changed

+58
-151
lines changed

app/api/billing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from datetime import datetime, UTC
66
from app.db.database import get_db
7-
from app.core.security import check_specific_team_admin, check_system_admin
7+
from app.core.security import get_role_min_specific_team_admin, get_role_min_system_admin
88
from app.db.models import DBTeam, DBSystemSecret, DBProduct, DBTeamProduct
99
from app.schemas.models import PricingTableSession, SubscriptionCreate, SubscriptionResponse
1010
from app.services.stripe import (
@@ -91,7 +91,7 @@ async def handle_events(
9191
detail="Error processing webhook"
9292
)
9393

94-
@router.post("/teams/{team_id}/portal", dependencies=[Depends(check_specific_team_admin)])
94+
@router.post("/teams/{team_id}/portal", dependencies=[Depends(get_role_min_specific_team_admin)])
9595
async def get_portal(
9696
team_id: int,
9797
db: Session = Depends(get_db)
@@ -135,7 +135,7 @@ async def get_portal(
135135
detail="Error creating portal session"
136136
)
137137

138-
@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(check_specific_team_admin)], response_model=PricingTableSession)
138+
@router.get("/teams/{team_id}/pricing-table-session", dependencies=[Depends(get_role_min_specific_team_admin)], response_model=PricingTableSession)
139139
async def get_pricing_table_session(
140140
team_id: int,
141141
db: Session = Depends(get_db)
@@ -178,7 +178,7 @@ async def get_pricing_table_session(
178178
detail="Error creating customer session"
179179
)
180180

181-
@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(check_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED)
181+
@router.post("/teams/{team_id}/subscriptions", dependencies=[Depends(get_role_min_system_admin)], response_model=SubscriptionResponse, status_code=status.HTTP_201_CREATED)
182182
async def create_team_subscription(
183183
team_id: int,
184184
subscription_data: SubscriptionCreate,

app/api/pricing_tables.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from app.db.database import get_db
77
from app.db.models import DBTeam, DBPricingTable
8-
from app.core.security import check_system_admin, get_role_min_team_admin, get_current_user_from_auth
8+
from app.core.security import get_role_min_system_admin, get_role_min_team_admin, get_current_user_from_auth
99
from app.schemas.models import PricingTableCreate, PricingTableResponse, PricingTablesResponse
1010
from app.core.config import settings
1111

@@ -16,8 +16,8 @@
1616
tags=["pricing-tables"]
1717
)
1818

19-
@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
20-
@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
19+
@router.post("", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
20+
@router.post("/", response_model=PricingTableResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
2121
async def create_pricing_table(
2222
pricing_table: PricingTableCreate,
2323
db: Session = Depends(get_db)
@@ -113,8 +113,8 @@ async def get_pricing_table(
113113
updated_at=pricing_table.updated_at or pricing_table.created_at
114114
)
115115

116-
@router.delete("", dependencies=[Depends(check_system_admin)])
117-
@router.delete("/", dependencies=[Depends(check_system_admin)])
116+
@router.delete("", dependencies=[Depends(get_role_min_system_admin)])
117+
@router.delete("/", dependencies=[Depends(get_role_min_system_admin)])
118118
async def delete_pricing_table(
119119
table_type: str,
120120
db: Session = Depends(get_db)
@@ -146,7 +146,7 @@ async def delete_pricing_table(
146146

147147
return {"message": f"Pricing table of type '{table_type}' deleted successfully"}
148148

149-
@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(check_system_admin)])
149+
@router.get("/list", response_model=PricingTablesResponse, dependencies=[Depends(get_role_min_system_admin)])
150150
async def get_all_pricing_tables(
151151
db: Session = Depends(get_db)
152152
):

app/api/private_ai_keys.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
get_current_user_from_auth,
1818
get_role_min_team_admin,
1919
get_private_ai_access,
20-
check_system_admin
20+
get_role_min_system_admin
2121
)
2222
from app.core.roles import UserRole
2323
from app.core.config import settings
@@ -477,7 +477,7 @@ async def list_private_ai_keys(
477477
private_ai_keys = query.all()
478478
return [key.to_dict() for key in private_ai_keys]
479479

480-
@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(check_system_admin)])
480+
@router.get("/{key_id}", response_model=PrivateAIKeyDetail, dependencies=[Depends(get_role_min_system_admin)])
481481
async def get_private_ai_key(
482482
key_id: int,
483483
current_user = Depends(get_current_user_from_auth),

app/api/products.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55

66
from app.db.database import get_db
77
from app.db.models import DBProduct, DBTeamProduct, DBTeam
8-
from app.core.security import check_system_admin, get_current_user_from_auth, get_role_min_team_admin
8+
from app.core.security import get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin
99
from app.schemas.models import Product, ProductCreate, ProductUpdate
1010

1111
router = APIRouter(
1212
tags=["products"]
1313
)
1414

15-
@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
16-
@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(check_system_admin)])
15+
@router.post("", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
16+
@router.post("/", response_model=Product, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_role_min_system_admin)])
1717
async def create_product(
1818
product: ProductCreate,
1919
db: Session = Depends(get_db)
@@ -105,7 +105,7 @@ async def get_product(
105105
)
106106
return product
107107

108-
@router.put("/{product_id}", response_model=Product, dependencies=[Depends(check_system_admin)])
108+
@router.put("/{product_id}", response_model=Product, dependencies=[Depends(get_role_min_system_admin)])
109109
async def update_product(
110110
product_id: str,
111111
product_update: ProductUpdate,
@@ -132,7 +132,7 @@ async def update_product(
132132

133133
return product
134134

135-
@router.delete("/{product_id}", dependencies=[Depends(check_system_admin)])
135+
@router.delete("/{product_id}", dependencies=[Depends(get_role_min_system_admin)])
136136
async def delete_product(
137137
product_id: str,
138138
db: Session = Depends(get_db)

app/api/regions.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from app.api.auth import get_current_user_from_auth
1010
from app.schemas.models import Region, RegionCreate, RegionResponse, User, RegionUpdate, TeamSummary
1111
from app.db.models import DBRegion, DBPrivateAIKey, DBTeamRegion, DBTeam
12-
from app.core.security import check_system_admin
12+
from app.core.security import get_role_min_system_admin
1313

1414
logger = logging.getLogger(__name__)
1515

@@ -89,8 +89,8 @@ async def validate_database_connection(host: str, port: int, user: str, password
8989
detail=f"Database connection validation failed: {str(e)}"
9090
)
9191

92-
@router.post("", response_model=Region, dependencies=[Depends(check_system_admin)])
93-
@router.post("/", response_model=Region, dependencies=[Depends(check_system_admin)])
92+
@router.post("", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
93+
@router.post("/", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
9494
async def create_region(
9595
region: RegionCreate,
9696
db: Session = Depends(get_db)
@@ -158,13 +158,13 @@ async def list_regions(
158158

159159
return non_dedicated_regions + team_dedicated_regions
160160

161-
@router.get("/admin", response_model=List[Region], dependencies=[Depends(check_system_admin)])
161+
@router.get("/admin", response_model=List[Region], dependencies=[Depends(get_role_min_system_admin)])
162162
async def list_admin_regions(
163163
db: Session = Depends(get_db)
164164
):
165165
return db.query(DBRegion).all()
166166

167-
@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(check_system_admin)])
167+
@router.get("/{region_id}", response_model=RegionResponse, dependencies=[Depends(get_role_min_system_admin)])
168168
async def get_region(
169169
region_id: int,
170170
db: Session = Depends(get_db)
@@ -177,7 +177,7 @@ async def get_region(
177177
)
178178
return region
179179

180-
@router.delete("/{region_id}", dependencies=[Depends(check_system_admin)])
180+
@router.delete("/{region_id}", dependencies=[Depends(get_role_min_system_admin)])
181181
async def delete_region(
182182
region_id: int,
183183
db: Session = Depends(get_db)
@@ -209,7 +209,7 @@ async def delete_region(
209209
)
210210
return {"message": "Region deleted successfully"}
211211

212-
@router.put("/{region_id}", response_model=Region, dependencies=[Depends(check_system_admin)])
212+
@router.put("/{region_id}", response_model=Region, dependencies=[Depends(get_role_min_system_admin)])
213213
async def update_region(
214214
region_id: int,
215215
region: RegionUpdate,
@@ -251,7 +251,7 @@ async def update_region(
251251
)
252252
return db_region
253253

254-
@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
254+
@router.post("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
255255
async def associate_team_with_region(
256256
region_id: int,
257257
team_id: int,
@@ -311,7 +311,7 @@ async def associate_team_with_region(
311311

312312
return {"message": "Team associated with region successfully"}
313313

314-
@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(check_system_admin)])
314+
@router.delete("/{region_id}/teams/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
315315
async def disassociate_team_from_region(
316316
region_id: int,
317317
team_id: int,
@@ -345,7 +345,7 @@ async def disassociate_team_from_region(
345345

346346
return {"message": "Team disassociated from region successfully"}
347347

348-
@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(check_system_admin)])
348+
@router.get("/{region_id}/teams", response_model=List[TeamSummary], dependencies=[Depends(get_role_min_system_admin)])
349349
async def list_teams_for_region(
350350
region_id: int,
351351
db: Session = Depends(get_db)

app/api/teams.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from app.db.database import get_db
99
from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion, DBTeamRegion, DBProduct
10-
from app.core.security import check_system_admin, check_specific_team_admin, get_current_user_from_auth, check_sales_or_higher
10+
from app.core.security import get_role_min_system_admin, get_role_min_specific_team_admin, get_current_user_from_auth, check_sales_or_higher
1111
from app.schemas.models import (
1212
Team, TeamCreate, TeamUpdate,
1313
TeamWithUsers, TeamMergeRequest, TeamMergeResponse
@@ -67,8 +67,8 @@ async def register_team(
6767

6868
return db_team
6969

70-
@router.get("", response_model=List[Team], dependencies=[Depends(check_system_admin)])
71-
@router.get("/", response_model=List[Team], dependencies=[Depends(check_system_admin)])
70+
@router.get("", response_model=List[Team], dependencies=[Depends(get_role_min_system_admin)])
71+
@router.get("/", response_model=List[Team], dependencies=[Depends(get_role_min_system_admin)])
7272
async def list_teams(
7373
db: Session = Depends(get_db)
7474
):
@@ -77,7 +77,7 @@ async def list_teams(
7777
"""
7878
return db.query(DBTeam).all()
7979

80-
@router.get("/{team_id}", response_model=TeamWithUsers, dependencies=[Depends(check_specific_team_admin)])
80+
@router.get("/{team_id}", response_model=TeamWithUsers, dependencies=[Depends(get_role_min_specific_team_admin)])
8181
async def get_team(
8282
team_id: int,
8383
db: Session = Depends(get_db)
@@ -93,7 +93,7 @@ async def get_team(
9393
# Convert directly to TeamWithUsers model
9494
return TeamWithUsers.model_validate(db_team)
9595

96-
@router.put("/{team_id}", response_model=Team, dependencies=[Depends(check_specific_team_admin)])
96+
@router.put("/{team_id}", response_model=Team, dependencies=[Depends(get_role_min_specific_team_admin)])
9797
async def update_team(
9898
team_id: int,
9999
team_update: TeamUpdate,
@@ -144,7 +144,7 @@ async def update_team(
144144

145145
return db_team
146146

147-
@router.delete("/{team_id}", dependencies=[Depends(check_system_admin)])
147+
@router.delete("/{team_id}", dependencies=[Depends(get_role_min_system_admin)])
148148
async def delete_team(
149149
team_id: int,
150150
db: Session = Depends(get_db)
@@ -167,7 +167,7 @@ async def delete_team(
167167

168168
return {"message": "Team deleted successfully"}
169169

170-
@router.post("/{team_id}/extend-trial", dependencies=[Depends(check_system_admin)])
170+
@router.post("/{team_id}/extend-trial", dependencies=[Depends(get_role_min_system_admin)])
171171
async def extend_team_trial(
172172
team_id: int,
173173
db: Session = Depends(get_db)
@@ -408,7 +408,7 @@ async def _resolve_key_conflicts(
408408
else:
409409
raise ValueError(f"Unknown conflict resolution strategy: {strategy}")
410410

411-
@router.post("/{target_team_id}/merge", dependencies=[Depends(check_system_admin)])
411+
@router.post("/{target_team_id}/merge", dependencies=[Depends(get_role_min_system_admin)])
412412
async def merge_teams(
413413
target_team_id: int,
414414
merge_request: TeamMergeRequest,

app/api/users.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from app.db.database import get_db
77
from app.schemas.models import User, UserUpdate, UserCreate, TeamOperation, UserRoleUpdate
88
from app.db.models import DBUser, DBTeam
9-
from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, get_role_min_team_admin
9+
from app.core.security import get_password_hash, get_role_min_system_admin, get_current_user_from_auth, get_role_min_team_admin
1010
from app.core.roles import UserRole
1111
from datetime import datetime, UTC
1212

1313
router = APIRouter(
1414
tags=["users"]
1515
)
1616

17-
@router.get("/search", response_model=List[User], dependencies=[Depends(check_system_admin)])
17+
@router.get("/search", response_model=List[User], dependencies=[Depends(get_role_min_system_admin)])
1818
async def search_users(
1919
email: str,
2020
db: Session = Depends(get_db)
@@ -199,7 +199,7 @@ async def add_user_to_team(
199199
db.refresh(db_user)
200200
return db_user
201201

202-
@router.post("/{user_id}/remove-from-team", response_model=User, dependencies=[Depends(check_system_admin)])
202+
@router.post("/{user_id}/remove-from-team", response_model=User, dependencies=[Depends(get_role_min_system_admin)])
203203
async def remove_user_from_team(
204204
user_id: int,
205205
current_user: DBUser = Depends(get_current_user_from_auth),
@@ -225,7 +225,7 @@ async def remove_user_from_team(
225225
db.refresh(db_user)
226226
return db_user
227227

228-
@router.delete("/{user_id}", dependencies=[Depends(check_system_admin)])
228+
@router.delete("/{user_id}", dependencies=[Depends(get_role_min_system_admin)])
229229
async def delete_user(
230230
user_id: int,
231231
current_user: DBUser = Depends(get_current_user_from_auth),

app/core/rbac.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ def _validate_user_type_constraints(self, user: DBUser) -> bool:
5454
effective_role = self._get_effective_role(user)
5555

5656
# System users (team_id is None) cannot have team roles
57-
if user.team_id is None and effective_role in ["admin", "key_creator", "read_only"]:
57+
if user.team_id is None and effective_role in UserRole.get_team_roles():
5858
return True
5959

6060
# Team users (team_id is not None) cannot have system roles
61-
if user.team_id is not None and effective_role in ["system_admin", "user", "sales"]:
61+
if user.team_id is not None and effective_role in UserRole.get_system_roles():
6262
return True
6363

6464
return False
@@ -76,23 +76,23 @@ def require_system_admin():
7676

7777
def require_team_admin():
7878
"""Require team admin role or system admin"""
79-
return RBACDependency([UserRole.TEAM_ADMIN, UserRole.SYSTEM_ADMIN], require_team_membership=True)
79+
return RBACDependency(UserRole.ADMIN_ROLES, require_team_membership=True)
8080

8181
def require_key_creator_or_higher():
8282
"""Require key creator role or higher (team context)"""
83-
return RBACDependency([UserRole.TEAM_ADMIN, UserRole.KEY_CREATOR, UserRole.SYSTEM_ADMIN], require_team_membership=True)
83+
return RBACDependency(UserRole.KEY_MANAGEMENT_ROLES, require_team_membership=True)
8484

8585
def require_private_ai_access():
8686
"""Require access to private AI operations - allows system users or team key creators"""
87-
return RBACDependency([UserRole.TEAM_ADMIN, UserRole.KEY_CREATOR, UserRole.SYSTEM_ADMIN, UserRole.USER], require_team_membership=False)
87+
return RBACDependency(UserRole.KEY_MANAGEMENT_ROLES + [UserRole.USER], require_team_membership=False)
8888

8989
def require_read_only_or_higher():
9090
"""Require read only role or higher (team context)"""
91-
return RBACDependency([UserRole.TEAM_ADMIN, UserRole.KEY_CREATOR, UserRole.READ_ONLY, UserRole.SYSTEM_ADMIN], require_team_membership=True)
91+
return RBACDependency(UserRole.READ_ACCESS_ROLES, require_team_membership=True)
9292

9393
def require_sales_or_higher():
9494
"""Require sales role or higher (system context)"""
95-
return RBACDependency([UserRole.SYSTEM_ADMIN, UserRole.SALES])
95+
return RBACDependency(UserRole.SYSTEM_ACCESS_ROLES)
9696

9797
def require_any_role():
9898
"""Allow any authenticated user"""

0 commit comments

Comments
 (0)