Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions app/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,14 @@ async def get_sign_in_data(
return None
return None

def create_and_set_access_token(response: Response, user_email: str) -> Token:
def create_and_set_access_token(response: Response, user_email: str, user: Optional[DBUser] = None) -> Token:
"""
Create an access token for the user and set it as a cookie.

Args:
response: The FastAPI response object to set the cookie on
user_email: The email of the user to create the token for
user: The user object to check if they are a system administrator

Returns:
Token: The created access token
Expand All @@ -116,13 +117,17 @@ def create_and_set_access_token(response: Response, user_email: str) -> Token:
# Get cookie domain from LAGOON_ROUTES
cookie_domain = get_cookie_domain()

# Set cookie expiration based on user role
# System administrators get 8 hours (28800 seconds), regular users get 30 minutes (1800 seconds)
cookie_expiration = 28800 if user and user.is_admin else 1800

# Prepare cookie settings
cookie_settings = {
"key": "access_token",
"value": access_token,
"httponly": True,
"max_age": 1800,
"expires": 1800,
"max_age": cookie_expiration,
"expires": cookie_expiration,
"samesite": 'none',
"secure": True,
"path": '/',
Expand Down Expand Up @@ -176,7 +181,7 @@ async def login(
)

auth_logger.info(f"Successful login for user: {login_data.username}")
return create_and_set_access_token(response, user.email)
return create_and_set_access_token(response, user.email, user)

@router.post("/logout")
async def logout(response: Response):
Expand Down Expand Up @@ -476,7 +481,7 @@ async def sign_in(
auth_logger.info(f"Successfully created new user and team for: {sign_in_data.username}")

auth_logger.info(f"Successful sign-in for user: {sign_in_data.username}")
return create_and_set_access_token(response, user.email)
return create_and_set_access_token(response, user.email, user)

# API Token routes (as apposed to AI Token routes)
def generate_api_token() -> str:
Expand Down Expand Up @@ -623,7 +628,7 @@ async def validate_jwt(

# Token is valid, create new access token
auth_logger.info(f"Successfully validated JWT for user: {user.email}")
return create_and_set_access_token(response, user.email)
return create_and_set_access_token(response, user.email, user)

except JWTError as e:
if isinstance(e, jwt.ExpiredSignatureError):
Expand Down
2 changes: 1 addition & 1 deletion app/api/private_ai_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def create_llm_token(
email=owner_email,
name=private_ai_key.name,
user_id=owner_id,
team_id=f"{region.name.replace(' ', '_')}_{litellm_team}",
team_id=LiteLLMService.format_team_id(region.name, litellm_team),
duration=f"{days_left_in_period}d",
max_budget=max_max_spend,
rpm_limit=max_rpm_limit
Expand Down
206 changes: 202 additions & 4 deletions app/api/teams.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from typing import List, Optional
from datetime import datetime, UTC
import logging

from app.db.database import get_db
from app.db.models import DBTeam, DBTeamProduct, DBUser
from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion
from app.core.security import check_system_admin, check_specific_team_admin, get_current_user_from_auth
from app.schemas.models import (
Team, TeamCreate, TeamUpdate,
TeamWithUsers
TeamWithUsers, TeamMergeRequest, TeamMergeResponse
)
from app.core.resource_limits import DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
from app.services.litellm import LiteLLMService
from app.services.ses import SESService
from app.core.worker import get_team_keys_by_region, generate_pricing_url, get_team_admin_email
from app.api.private_ai_keys import delete_private_ai_key


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,7 +203,6 @@ async def extend_team_trial(
except Exception as e:
logger.error(f"Failed to update key {key.id} via LiteLLM: {str(e)}")
# Continue with other keys even if one fails
continue

# Send trial extension email
try:
Expand All @@ -219,3 +220,200 @@ async def extend_team_trial(
# Don't fail the request if email fails

return {"message": "Team trial extended successfully"}

def _check_key_name_conflicts(team1_keys: List[DBPrivateAIKey], team2_keys: List[DBPrivateAIKey]) -> List[str]:
"""Return list of conflicting key names between two teams"""
team1_names = {key.name for key in team1_keys if key.name}
team2_names = {key.name for key in team2_keys if key.name}
return list(team1_names.intersection(team2_names))

async def _resolve_key_conflicts(
conflicts: List[str],
strategy: str,
team2_keys: List[DBPrivateAIKey],
rename_suffix: str,
db: Session = None,
current_user = None
) -> List[DBPrivateAIKey]:
"""Apply conflict resolution strategy to team2 keys"""
if strategy == "delete":
# Remove conflicting keys from team2 and delete them from database
keys_to_delete = [key for key in team2_keys if key.name in conflicts]
remaining_keys = [key for key in team2_keys if key.name not in conflicts]

# Delete conflicting keys from database if db session provided
if db and current_user:
for key in keys_to_delete:
try:
await delete_private_ai_key(
key_id=key.id,
current_user=current_user,
user_role="system_admin", # System admin context for merge operations
db=db
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any scenario where these get deleted but then for whatever reason the save/commit call fails when we actually do the migration and then the keys are gone?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is, but I don't have a good way to rollback in that case since this delete is fully destructive including tearing down vector DBs if they are part of the key. Given that the keys being deleted is the expected behaviour I'm OK with the risk.

except Exception as e:
logger.error(f"Failed to delete key {key.id}: {str(e)}")
# Continue with other keys even if one fails

return remaining_keys
elif strategy == "rename":
# Rename conflicting keys in team2
suffix = rename_suffix
for key in team2_keys:
if key.name in conflicts:
key.name = f"{key.name}{suffix}"
return team2_keys
elif strategy == "cancel":
# Return original keys unchanged
return team2_keys
else:
raise ValueError(f"Unknown conflict resolution strategy: {strategy}")

@router.post("/{target_team_id}/merge", dependencies=[Depends(check_system_admin)])
async def merge_teams(
target_team_id: int,
merge_request: TeamMergeRequest,
db: Session = Depends(get_db),
current_user: DBUser = Depends(get_current_user_from_auth)
):
"""
Merge source team into target team. Only accessible by system administrators.

This endpoint will:
1. Validate both teams exist
2. Check if source team has active product associations (fails if it does)
3. Check for key name conflicts
4. Apply conflict resolution strategy
5. Migrate users and keys
6. Update LiteLLM key associations
7. Delete the source team
"""
try:
# Validate teams exist
target_team = db.query(DBTeam).filter(DBTeam.id == target_team_id).first()
if not target_team:
raise HTTPException(status_code=404, detail="Target team not found")

source_team = db.query(DBTeam).filter(DBTeam.id == merge_request.source_team_id).first()
if not source_team:
raise HTTPException(status_code=404, detail="Source team not found")

# Prevent merging a team into itself
if source_team.id == target_team.id:
raise HTTPException(
status_code=400,
detail="Cannot merge a team into itself"
)

# Check if source team has active product associations first
source_products = db.query(DBTeamProduct).filter(DBTeamProduct.team_id == source_team.id).all()
if source_products:
product_names = [product.product_id for product in source_products]
raise HTTPException(
status_code=400,
detail=f"Cannot merge team '{source_team.name}' - it has active product associations: {', '.join(product_names)}. Please remove product associations before merging."
)

# Get team keys and users (only if no product associations found)
source_keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.team_id == source_team.id).all()
target_keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.team_id == target_team.id).all()
source_users = db.query(DBUser).filter(DBUser.team_id == source_team.id).all()

# Check for conflicts
conflicts = _check_key_name_conflicts(target_keys, source_keys)

# Apply conflict resolution strategy
if conflicts:
if merge_request.conflict_resolution_strategy == "cancel":
return TeamMergeResponse(
success=False,
message=f"Merge cancelled due to {len(conflicts)} key name conflicts",
conflicts_resolved=conflicts,
keys_migrated=0,
users_migrated=0
)

source_keys = await _resolve_key_conflicts(
conflicts,
merge_request.conflict_resolution_strategy,
source_keys,
merge_request.rename_suffix if merge_request.rename_suffix is not None else f"_team{source_team.id}",
db,
current_user
)

# Store team names before deletion
source_team_name = source_team.name
target_team_name = target_team.name

# Migrate users from source team to target team
users_migrated = 0
for user in source_users:
if user.team_id != target_team.id:
user.team_id = target_team.id
users_migrated += 1

# Migrate keys from source team to target team
keys_migrated = 0
for key in source_keys:
if key.team_id != target_team.id:
key.team_id = target_team.id
keys_migrated += 1

# Flush changes to ensure they're persisted in the current transaction
db.flush()

# Update LiteLLM key associations
# Create a map of keys by region to avoid unnecessary DB queries
keys_by_region = {}
for key in source_keys:
if key.region_id not in keys_by_region:
keys_by_region[key.region_id] = []
keys_by_region[key.region_id].append(key)

# Update LiteLLM key associations for each region
for region_id, region_keys in keys_by_region.items():
# Get region info
region = db.query(DBRegion).filter(
DBRegion.id == region_id,
DBRegion.is_active == True
).first()

# Initialize LiteLLM service for this region
litellm_service = LiteLLMService(
api_url=region.litellm_api_url,
api_key=region.litellm_api_key
)

# Update team association for each key in this region
for key in region_keys:
try:
await litellm_service.update_key_team_association(
key.litellm_token,
LiteLLMService.format_team_id(region.name, target_team.id)
)
except Exception as e:
logger.error(f"Failed to update LiteLLM key {key.id}: {str(e)}")

# Delete source team
db.delete(source_team)
db.commit()

return TeamMergeResponse(
success=True,
message=f"Successfully merged team '{source_team_name}' into '{target_team_name}'",
conflicts_resolved=conflicts if conflicts else None,
keys_migrated=keys_migrated,
users_migrated=users_migrated
)

except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except Exception as e:
db.rollback()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this actually work if commits have been made in the _migrate_keys_to_team and _migrate_users_to_team functions?

logger.error(f"Error during team merge: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Team merge failed: {str(e)}"
)
12 changes: 12 additions & 0 deletions app/schemas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ class TeamSummary(BaseModel):
class TeamOperation(BaseModel):
team_id: int

class TeamMergeRequest(BaseModel):
source_team_id: int
conflict_resolution_strategy: Literal["delete", "rename", "cancel"]
rename_suffix: Optional[str] = None # For rename strategy

class TeamMergeResponse(BaseModel):
success: bool
message: str
conflicts_resolved: Optional[List[str]] = None
keys_migrated: int
users_migrated: int

class UserRoleUpdate(BaseModel):
role: str
model_config = ConfigDict(from_attributes=True)
Expand Down
32 changes: 32 additions & 0 deletions app/services/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def __init__(self, api_url: str, api_key: str):
if not self.master_key:
raise ValueError("LiteLLM API key is required")

@staticmethod
def format_team_id(region_name: str, team_id: int) -> str:
"""Generate the correctly formatted team_id for LiteLLM"""
return f"{region_name.replace(' ', '_')}_{team_id}"

async def create_key(self, email: str, name: str, user_id: int, team_id: str, duration: str = f"{DEFAULT_KEY_DURATION}d", max_budget: float = DEFAULT_MAX_SPEND, rpm_limit: int = DEFAULT_RPM_PER_KEY) -> str:
"""Create a new API key for LiteLLM"""
try:
Expand Down Expand Up @@ -228,3 +233,30 @@ async def set_key_restrictions(self, litellm_token: str, duration: str, budget_a
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to set LiteLLM key restrictions: {error_msg}"
)

async def update_key_team_association(self, litellm_token: str, new_team_id: str):
"""Update the team association for a LiteLLM API key"""
try:
response = requests.post(
f"{self.api_url}/key/update",
headers={
"Authorization": f"Bearer {self.master_key}"
},
json={
"key": litellm_token,
"team_id": new_team_id
}
)
response.raise_for_status()
except requests.exceptions.RequestException as e:
error_msg = str(e)
if hasattr(e, 'response') and e.response is not None:
try:
error_details = e.response.json()
error_msg = f"Status {e.response.status_code}: {error_details}"
except ValueError:
error_msg = f"Status {e.response.status_code}: {e.response.text}"
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to update LiteLLM key team association: {error_msg}"
)
Loading