Skip to content

Commit 972d61b

Browse files
authored
Merge pull request #118 from amazeeio/dev
Add Merge Teams Functionality
2 parents 708c562 + 37fc6fe commit 972d61b

File tree

11 files changed

+1508
-97
lines changed

11 files changed

+1508
-97
lines changed

app/api/auth.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,14 @@ async def get_sign_in_data(
9797
return None
9898
return None
9999

100-
def create_and_set_access_token(response: Response, user_email: str) -> Token:
100+
def create_and_set_access_token(response: Response, user_email: str, user: Optional[DBUser] = None) -> Token:
101101
"""
102102
Create an access token for the user and set it as a cookie.
103103
104104
Args:
105105
response: The FastAPI response object to set the cookie on
106106
user_email: The email of the user to create the token for
107+
user: The user object to check if they are a system administrator
107108
108109
Returns:
109110
Token: The created access token
@@ -116,13 +117,17 @@ def create_and_set_access_token(response: Response, user_email: str) -> Token:
116117
# Get cookie domain from LAGOON_ROUTES
117118
cookie_domain = get_cookie_domain()
118119

120+
# Set cookie expiration based on user role
121+
# System administrators get 8 hours (28800 seconds), regular users get 30 minutes (1800 seconds)
122+
cookie_expiration = 28800 if user and user.is_admin else 1800
123+
119124
# Prepare cookie settings
120125
cookie_settings = {
121126
"key": "access_token",
122127
"value": access_token,
123128
"httponly": True,
124-
"max_age": 1800,
125-
"expires": 1800,
129+
"max_age": cookie_expiration,
130+
"expires": cookie_expiration,
126131
"samesite": 'none',
127132
"secure": True,
128133
"path": '/',
@@ -176,7 +181,7 @@ async def login(
176181
)
177182

178183
auth_logger.info(f"Successful login for user: {login_data.username}")
179-
return create_and_set_access_token(response, user.email)
184+
return create_and_set_access_token(response, user.email, user)
180185

181186
@router.post("/logout")
182187
async def logout(response: Response):
@@ -476,7 +481,7 @@ async def sign_in(
476481
auth_logger.info(f"Successfully created new user and team for: {sign_in_data.username}")
477482

478483
auth_logger.info(f"Successful sign-in for user: {sign_in_data.username}")
479-
return create_and_set_access_token(response, user.email)
484+
return create_and_set_access_token(response, user.email, user)
480485

481486
# API Token routes (as apposed to AI Token routes)
482487
def generate_api_token() -> str:
@@ -623,7 +628,7 @@ async def validate_jwt(
623628

624629
# Token is valid, create new access token
625630
auth_logger.info(f"Successfully validated JWT for user: {user.email}")
626-
return create_and_set_access_token(response, user.email)
631+
return create_and_set_access_token(response, user.email, user)
627632

628633
except JWTError as e:
629634
if isinstance(e, jwt.ExpiredSignatureError):

app/api/private_ai_keys.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ async def create_llm_token(
373373
email=owner_email,
374374
name=private_ai_key.name,
375375
user_id=owner_id,
376-
team_id=f"{region.name.replace(' ', '_')}_{litellm_team}",
376+
team_id=LiteLLMService.format_team_id(region.name, litellm_team),
377377
duration=f"{days_left_in_period}d",
378378
max_budget=max_max_spend,
379379
rpm_limit=max_rpm_limit

app/api/teams.py

Lines changed: 202 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from fastapi import APIRouter, Depends, HTTPException, status
22
from sqlalchemy.orm import Session
3-
from typing import List
3+
from typing import List, Optional
44
from datetime import datetime, UTC
55
import logging
66

77
from app.db.database import get_db
8-
from app.db.models import DBTeam, DBTeamProduct, DBUser
8+
from app.db.models import DBTeam, DBTeamProduct, DBUser, DBPrivateAIKey, DBRegion
99
from app.core.security import check_system_admin, check_specific_team_admin, get_current_user_from_auth
1010
from app.schemas.models import (
1111
Team, TeamCreate, TeamUpdate,
12-
TeamWithUsers
12+
TeamWithUsers, TeamMergeRequest, TeamMergeResponse
1313
)
1414
from app.core.resource_limits import DEFAULT_KEY_DURATION, DEFAULT_MAX_SPEND, DEFAULT_RPM_PER_KEY
1515
from app.services.litellm import LiteLLMService
1616
from app.services.ses import SESService
1717
from app.core.worker import get_team_keys_by_region, generate_pricing_url, get_team_admin_email
18+
from app.api.private_ai_keys import delete_private_ai_key
19+
1820

1921
logger = logging.getLogger(__name__)
2022

@@ -201,7 +203,6 @@ async def extend_team_trial(
201203
except Exception as e:
202204
logger.error(f"Failed to update key {key.id} via LiteLLM: {str(e)}")
203205
# Continue with other keys even if one fails
204-
continue
205206

206207
# Send trial extension email
207208
try:
@@ -219,3 +220,200 @@ async def extend_team_trial(
219220
# Don't fail the request if email fails
220221

221222
return {"message": "Team trial extended successfully"}
223+
224+
def _check_key_name_conflicts(team1_keys: List[DBPrivateAIKey], team2_keys: List[DBPrivateAIKey]) -> List[str]:
225+
"""Return list of conflicting key names between two teams"""
226+
team1_names = {key.name for key in team1_keys if key.name}
227+
team2_names = {key.name for key in team2_keys if key.name}
228+
return list(team1_names.intersection(team2_names))
229+
230+
async def _resolve_key_conflicts(
231+
conflicts: List[str],
232+
strategy: str,
233+
team2_keys: List[DBPrivateAIKey],
234+
rename_suffix: str,
235+
db: Session = None,
236+
current_user = None
237+
) -> List[DBPrivateAIKey]:
238+
"""Apply conflict resolution strategy to team2 keys"""
239+
if strategy == "delete":
240+
# Remove conflicting keys from team2 and delete them from database
241+
keys_to_delete = [key for key in team2_keys if key.name in conflicts]
242+
remaining_keys = [key for key in team2_keys if key.name not in conflicts]
243+
244+
# Delete conflicting keys from database if db session provided
245+
if db and current_user:
246+
for key in keys_to_delete:
247+
try:
248+
await delete_private_ai_key(
249+
key_id=key.id,
250+
current_user=current_user,
251+
user_role="system_admin", # System admin context for merge operations
252+
db=db
253+
)
254+
except Exception as e:
255+
logger.error(f"Failed to delete key {key.id}: {str(e)}")
256+
# Continue with other keys even if one fails
257+
258+
return remaining_keys
259+
elif strategy == "rename":
260+
# Rename conflicting keys in team2
261+
suffix = rename_suffix
262+
for key in team2_keys:
263+
if key.name in conflicts:
264+
key.name = f"{key.name}{suffix}"
265+
return team2_keys
266+
elif strategy == "cancel":
267+
# Return original keys unchanged
268+
return team2_keys
269+
else:
270+
raise ValueError(f"Unknown conflict resolution strategy: {strategy}")
271+
272+
@router.post("/{target_team_id}/merge", dependencies=[Depends(check_system_admin)])
273+
async def merge_teams(
274+
target_team_id: int,
275+
merge_request: TeamMergeRequest,
276+
db: Session = Depends(get_db),
277+
current_user: DBUser = Depends(get_current_user_from_auth)
278+
):
279+
"""
280+
Merge source team into target team. Only accessible by system administrators.
281+
282+
This endpoint will:
283+
1. Validate both teams exist
284+
2. Check if source team has active product associations (fails if it does)
285+
3. Check for key name conflicts
286+
4. Apply conflict resolution strategy
287+
5. Migrate users and keys
288+
6. Update LiteLLM key associations
289+
7. Delete the source team
290+
"""
291+
try:
292+
# Validate teams exist
293+
target_team = db.query(DBTeam).filter(DBTeam.id == target_team_id).first()
294+
if not target_team:
295+
raise HTTPException(status_code=404, detail="Target team not found")
296+
297+
source_team = db.query(DBTeam).filter(DBTeam.id == merge_request.source_team_id).first()
298+
if not source_team:
299+
raise HTTPException(status_code=404, detail="Source team not found")
300+
301+
# Prevent merging a team into itself
302+
if source_team.id == target_team.id:
303+
raise HTTPException(
304+
status_code=400,
305+
detail="Cannot merge a team into itself"
306+
)
307+
308+
# Check if source team has active product associations first
309+
source_products = db.query(DBTeamProduct).filter(DBTeamProduct.team_id == source_team.id).all()
310+
if source_products:
311+
product_names = [product.product_id for product in source_products]
312+
raise HTTPException(
313+
status_code=400,
314+
detail=f"Cannot merge team '{source_team.name}' - it has active product associations: {', '.join(product_names)}. Please remove product associations before merging."
315+
)
316+
317+
# Get team keys and users (only if no product associations found)
318+
source_keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.team_id == source_team.id).all()
319+
target_keys = db.query(DBPrivateAIKey).filter(DBPrivateAIKey.team_id == target_team.id).all()
320+
source_users = db.query(DBUser).filter(DBUser.team_id == source_team.id).all()
321+
322+
# Check for conflicts
323+
conflicts = _check_key_name_conflicts(target_keys, source_keys)
324+
325+
# Apply conflict resolution strategy
326+
if conflicts:
327+
if merge_request.conflict_resolution_strategy == "cancel":
328+
return TeamMergeResponse(
329+
success=False,
330+
message=f"Merge cancelled due to {len(conflicts)} key name conflicts",
331+
conflicts_resolved=conflicts,
332+
keys_migrated=0,
333+
users_migrated=0
334+
)
335+
336+
source_keys = await _resolve_key_conflicts(
337+
conflicts,
338+
merge_request.conflict_resolution_strategy,
339+
source_keys,
340+
merge_request.rename_suffix if merge_request.rename_suffix is not None else f"_team{source_team.id}",
341+
db,
342+
current_user
343+
)
344+
345+
# Store team names before deletion
346+
source_team_name = source_team.name
347+
target_team_name = target_team.name
348+
349+
# Migrate users from source team to target team
350+
users_migrated = 0
351+
for user in source_users:
352+
if user.team_id != target_team.id:
353+
user.team_id = target_team.id
354+
users_migrated += 1
355+
356+
# Migrate keys from source team to target team
357+
keys_migrated = 0
358+
for key in source_keys:
359+
if key.team_id != target_team.id:
360+
key.team_id = target_team.id
361+
keys_migrated += 1
362+
363+
# Flush changes to ensure they're persisted in the current transaction
364+
db.flush()
365+
366+
# Update LiteLLM key associations
367+
# Create a map of keys by region to avoid unnecessary DB queries
368+
keys_by_region = {}
369+
for key in source_keys:
370+
if key.region_id not in keys_by_region:
371+
keys_by_region[key.region_id] = []
372+
keys_by_region[key.region_id].append(key)
373+
374+
# Update LiteLLM key associations for each region
375+
for region_id, region_keys in keys_by_region.items():
376+
# Get region info
377+
region = db.query(DBRegion).filter(
378+
DBRegion.id == region_id,
379+
DBRegion.is_active == True
380+
).first()
381+
382+
# Initialize LiteLLM service for this region
383+
litellm_service = LiteLLMService(
384+
api_url=region.litellm_api_url,
385+
api_key=region.litellm_api_key
386+
)
387+
388+
# Update team association for each key in this region
389+
for key in region_keys:
390+
try:
391+
await litellm_service.update_key_team_association(
392+
key.litellm_token,
393+
LiteLLMService.format_team_id(region.name, target_team.id)
394+
)
395+
except Exception as e:
396+
logger.error(f"Failed to update LiteLLM key {key.id}: {str(e)}")
397+
398+
# Delete source team
399+
db.delete(source_team)
400+
db.commit()
401+
402+
return TeamMergeResponse(
403+
success=True,
404+
message=f"Successfully merged team '{source_team_name}' into '{target_team_name}'",
405+
conflicts_resolved=conflicts if conflicts else None,
406+
keys_migrated=keys_migrated,
407+
users_migrated=users_migrated
408+
)
409+
410+
except HTTPException:
411+
# Re-raise HTTP exceptions as-is
412+
raise
413+
except Exception as e:
414+
db.rollback()
415+
logger.error(f"Error during team merge: {str(e)}")
416+
raise HTTPException(
417+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
418+
detail=f"Team merge failed: {str(e)}"
419+
)

app/schemas/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,18 @@ class TeamSummary(BaseModel):
274274
class TeamOperation(BaseModel):
275275
team_id: int
276276

277+
class TeamMergeRequest(BaseModel):
278+
source_team_id: int
279+
conflict_resolution_strategy: Literal["delete", "rename", "cancel"]
280+
rename_suffix: Optional[str] = None # For rename strategy
281+
282+
class TeamMergeResponse(BaseModel):
283+
success: bool
284+
message: str
285+
conflicts_resolved: Optional[List[str]] = None
286+
keys_migrated: int
287+
users_migrated: int
288+
277289
class UserRoleUpdate(BaseModel):
278290
role: str
279291
model_config = ConfigDict(from_attributes=True)

app/services/litellm.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ def __init__(self, api_url: str, api_key: str):
1717
if not self.master_key:
1818
raise ValueError("LiteLLM API key is required")
1919

20+
@staticmethod
21+
def format_team_id(region_name: str, team_id: int) -> str:
22+
"""Generate the correctly formatted team_id for LiteLLM"""
23+
return f"{region_name.replace(' ', '_')}_{team_id}"
24+
2025
async def create_key(self, email: str, name: str, user_id: int, team_id: str, duration: str = f"{DEFAULT_KEY_DURATION}d", max_budget: float = DEFAULT_MAX_SPEND, rpm_limit: int = DEFAULT_RPM_PER_KEY) -> str:
2126
"""Create a new API key for LiteLLM"""
2227
try:
@@ -228,3 +233,30 @@ async def set_key_restrictions(self, litellm_token: str, duration: str, budget_a
228233
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
229234
detail=f"Failed to set LiteLLM key restrictions: {error_msg}"
230235
)
236+
237+
async def update_key_team_association(self, litellm_token: str, new_team_id: str):
238+
"""Update the team association for a LiteLLM API key"""
239+
try:
240+
response = requests.post(
241+
f"{self.api_url}/key/update",
242+
headers={
243+
"Authorization": f"Bearer {self.master_key}"
244+
},
245+
json={
246+
"key": litellm_token,
247+
"team_id": new_team_id
248+
}
249+
)
250+
response.raise_for_status()
251+
except requests.exceptions.RequestException as e:
252+
error_msg = str(e)
253+
if hasattr(e, 'response') and e.response is not None:
254+
try:
255+
error_details = e.response.json()
256+
error_msg = f"Status {e.response.status_code}: {error_details}"
257+
except ValueError:
258+
error_msg = f"Status {e.response.status_code}: {e.response.text}"
259+
raise HTTPException(
260+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
261+
detail=f"Failed to update LiteLLM key team association: {error_msg}"
262+
)

0 commit comments

Comments
 (0)