Skip to content

Commit d9f4430

Browse files
committed
Refactor Security and User Role Management
- Organized imports in private_ai_keys.py and users.py for better readability. - Updated role validation logic to utilize UserRole.get_all_roles() for consistency across user role checks. - Removed unused role definitions and hierarchy from security.py, streamlining the codebase. - Enhanced role management by centralizing role-related functionality in the roles module.
1 parent 92139d1 commit d9f4430

File tree

3 files changed

+23
-35
lines changed

3 files changed

+23
-35
lines changed

app/api/private_ai_keys.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,22 @@
1313
from app.db.postgres import PostgresManager
1414
from app.db.models import DBPrivateAIKey, DBRegion, DBUser, DBTeam
1515
from app.services.litellm import LiteLLMService
16-
from app.core.security import get_current_user_from_auth, get_role_min_key_creator, get_role_min_team_admin, get_private_ai_access, UserRole, check_system_admin
16+
from app.core.security import (
17+
get_current_user_from_auth,
18+
get_role_min_team_admin,
19+
get_private_ai_access,
20+
check_system_admin
21+
)
22+
from app.core.roles import UserRole
1723
from app.core.config import settings
18-
from app.core.resource_limits import check_key_limits, check_vector_db_limits, get_token_restrictions, DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
24+
from app.core.resource_limits import (
25+
check_key_limits,
26+
check_vector_db_limits,
27+
get_token_restrictions,
28+
DEFAULT_KEY_DURATION,
29+
DEFAULT_MAX_SPEND,
30+
DEFAULT_RPM_PER_KEY
31+
)
1932

2033
router = APIRouter(
2134
tags=["private-ai-keys"]

app/api/users.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from fastapi import APIRouter, Depends, HTTPException, status
22
from sqlalchemy.orm import Session
3-
from typing import List, get_args
3+
from typing import List
44
from app.core.config import settings
55
from app.core.resource_limits import check_team_user_limit
66
from app.db.database import get_db
77
from app.schemas.models import User, UserUpdate, UserCreate, TeamOperation, UserRoleUpdate
88
from app.db.models import DBUser, DBTeam
9-
from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, UserRole, get_role_min_team_admin
9+
from app.core.security import get_password_hash, check_system_admin, get_current_user_from_auth, get_role_min_team_admin
10+
from app.core.roles import UserRole
1011
from datetime import datetime, UTC
1112

1213
router = APIRouter(
@@ -77,10 +78,10 @@ async def create_user(
7778
check_team_user_limit(db, user.team_id)
7879

7980
# Validate role if provided
80-
if user.role and user.role not in get_args(UserRole):
81+
if user.role and user.role not in UserRole.get_all_roles():
8182
raise HTTPException(
8283
status_code=status.HTTP_400_BAD_REQUEST,
83-
detail=f"Invalid role. Must be one of: {', '.join(get_args(UserRole))}"
84+
detail=f"Invalid role. Must be one of: {', '.join(UserRole.get_all_roles())}"
8485
)
8586

8687
# Default to the lowest permissions for a user in a team
@@ -256,10 +257,10 @@ async def update_user_role(
256257
Update a user's role. Accessible by admin users or team admins for their team members.
257258
"""
258259
# Validate role
259-
if role_update.role not in get_args(UserRole):
260+
if role_update.role not in UserRole.get_all_roles():
260261
raise HTTPException(
261262
status_code=status.HTTP_400_BAD_REQUEST,
262-
detail=f"Invalid role. Must be one of: {', '.join(get_args(UserRole))}"
263+
detail=f"Invalid role. Must be one of: {', '.join(UserRole.get_all_roles())}"
263264
)
264265

265266
# Get the user to update

app/core/security.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime, timedelta, UTC
2-
from typing import Optional, Literal, Dict
2+
from typing import Optional
33
from jose import JWTError, jwt
44
from passlib.context import CryptContext
55
from fastapi import Depends, HTTPException, status, Cookie, Header, Request
@@ -9,14 +9,11 @@
99
from app.db.database import get_db
1010
from sqlalchemy.orm import Session
1111
from app.db.models import DBUser, DBAPIToken
12-
from app.core.roles import UserRole as RoleClass
1312
from app.core.rbac import (
1413
require_system_admin,
1514
require_team_admin,
1615
require_key_creator_or_higher,
17-
require_read_only_or_higher,
1816
require_sales_or_higher,
19-
require_any_role
2017
)
2118

2219
logger = logging.getLogger(__name__)
@@ -27,19 +24,6 @@
2724
# Custom bearer scheme
2825
bearer_scheme = HTTPBearer(auto_error=False)
2926

30-
# Define valid user roles as a Literal type - MUST match existing values exactly
31-
UserRole = Literal["admin", "key_creator", "read_only", "user", "system_admin", "sales"]
32-
33-
# Define a hierarchy for roles - updated to include new roles
34-
user_role_hierarchy: Dict[UserRole, int] = {
35-
"system_admin": 0,
36-
"admin": 1,
37-
"user": 2,
38-
"key_creator": 3,
39-
"read_only": 4,
40-
"sales": 5,
41-
}
42-
4327
def verify_password(plain_password: str, hashed_password: str) -> bool:
4428
"""Verify a password against its hash."""
4529
return pwd_context.verify(plain_password, hashed_password)
@@ -152,16 +136,6 @@ async def check_system_admin(current_user: DBUser = Depends(get_current_user_fro
152136
dependency = require_system_admin()
153137
return dependency.check_access(current_user)
154138

155-
def get_user_role(minimum_role: UserRole, current_user: DBUser):
156-
if current_user.is_admin:
157-
return "system_admin"
158-
elif user_role_hierarchy[current_user.role] > user_role_hierarchy[minimum_role]:
159-
raise HTTPException(
160-
status_code=status.HTTP_403_FORBIDDEN,
161-
detail="Not authorized to perform this action"
162-
)
163-
return current_user.role
164-
165139
async def get_role_min_team_admin(current_user: DBUser = Depends(get_current_user_from_auth)):
166140
"""Require team admin role or higher."""
167141
dependency = require_team_admin()

0 commit comments

Comments
 (0)