Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 120 additions & 9 deletions app/core/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from datetime import datetime, UTC, timedelta
from sqlalchemy.orm import Session
from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser, DBRegion
Expand All @@ -17,7 +18,7 @@
INVOICE_SUCCESS_EVENTS
)
from prometheus_client import Gauge, Counter, Summary
from typing import Dict, List
from typing import Dict, List, Optional
from app.core.security import create_access_token
from app.core.config import settings
from urllib.parse import urljoin
Expand Down Expand Up @@ -95,6 +96,8 @@ def get_team_keys_by_region(db: Session, team_id: int) -> Dict[DBRegion, List[DB
logger.info(f"Found {len(team_keys)} keys in {len(keys_by_region)} regions for team {team_id}")
return keys_by_region



async def handle_stripe_event_background(event, db: Session):
"""
Background task to handle Stripe webhook events.
Expand Down Expand Up @@ -247,19 +250,28 @@ async def remove_product_from_team(db: Session, customer_id: str, product_id: st
logger.error(f"Error removing product from team: {str(e)}")
raise e

async def monitor_team_keys(team: DBTeam, keys_by_region: Dict[DBRegion, List[DBPrivateAIKey]], expire_keys: bool) -> float:
async def monitor_team_keys(
team: DBTeam,
keys_by_region: Dict[DBRegion, List[DBPrivateAIKey]],
expire_keys: bool,
renewal_period_days: Optional[int] = None,
max_budget_amount: Optional[float] = None
) -> float:
"""
Monitor spend for all keys in a team across different regions.
Monitor spend for all keys in a team across different regions and optionally update keys after renewal period.

Args:
db: Database session
team: The team to monitor keys for
keys_by_region: Dictionary mapping regions to lists of keys
expire_keys: Whether to expire keys (set duration to 0)
renewal_period_days: Optional renewal period in days. If provided, will check for and update keys renewed within the last hour.
max_budget_amount: Optional maximum budget amount. If provided, will update the budget amount for the keys.

Returns:
float: Total spend across all keys for the team
"""
team_total = 0
current_time = datetime.now(UTC)

# Monitor keys for each region
for region, keys in keys_by_region.items():
Expand All @@ -281,6 +293,86 @@ async def monitor_team_keys(team: DBTeam, keys_by_region: Dict[DBRegion, List[DB
budget = info.get("max_budget", 0) or 0.0
key_alias = info.get("key_alias", f"key-{key.id}") # Fallback to key-{id} if no alias

# Check for renewal period update if renewal_period_days is provided
if renewal_period_days is not None:
# Check current values and only update if they don't match the parameters
current_budget_duration = info.get("budget_duration")
current_max_budget = info.get("max_budget")

needs_update = False
data = {"litellm_token": key.litellm_token}

# Check if budget_duration needs updating
expected_budget_duration = f"{renewal_period_days}d"
if current_budget_duration != expected_budget_duration:
data["budget_duration"] = expected_budget_duration
needs_update = True
logger.info(f"Key {key.id} budget_duration will be updated from '{current_budget_duration}' to '{expected_budget_duration}'")

# Check if budget_amount needs updating
if max_budget_amount is not None and current_max_budget != max_budget_amount:
data["budget_amount"] = max_budget_amount
needs_update = True
logger.info(f"Key {key.id} budget_amount will be updated from {current_max_budget} to {max_budget_amount}")

# Only check reset timestamp if we need to update
if needs_update:
budget_reset_at_str = info.get("budget_reset_at")
if budget_reset_at_str:
try:
# Parse the budget_reset_at timestamp
budget_reset_at = datetime.fromisoformat(budget_reset_at_str.replace('Z', '+00:00'))
if budget_reset_at.tzinfo is None:
budget_reset_at = budget_reset_at.replace(tzinfo=UTC)
logger.info(f"Key {key.id} budget_reset_at_str: {budget_reset_at_str}, budget_reset_at: {budget_reset_at}")

# Check if budget was reset recently using heuristics
# budget_reset_at represents when the next reset will occur
current_spend = info.get("spend", 0) or 0.0
current_budget_duration = info.get("budget_duration")

should_update = False
update_reason = ""

# Heuristic 1: Check if (now + current_budget_duration) is within an hour of budget_reset_at
if current_budget_duration is not None:
try:
# Parse current budget duration (e.g., "30d" -> 30 days)
duration_match = re.match(r'(\d+)d', current_budget_duration)
if duration_match:
duration_days = int(duration_match.group(1))
expected_reset_time = current_time + timedelta(days=duration_days)
hours_diff = abs((expected_reset_time - budget_reset_at).total_seconds() / 3600)

if hours_diff <= 1.0:
should_update = True
update_reason = f"reset time alignment (within {hours_diff:.2f} hours)"
except (ValueError, AttributeError):
logger.warning(f"Key {key.id} has invalid budget_duration format: {current_budget_duration}")
else:
logger.debug(f"Key {key.id} has no budget_duration set, skipping reset time alignment heuristic")
should_update = True
update_reason = "no budget_duration set, forcing update"

# Heuristic 2: Update if amount spent is $0.00 (indicating fresh reset)
if current_spend == 0.0:
should_update = True
update_reason = "zero spend (fresh reset)"

if should_update:
logger.info(f"Key {key.id} budget update triggered: {update_reason}, updating budget settings")
await litellm_service.update_budget(**data)
logger.info(f"Updated key {key.id} budget settings")
else:
logger.debug(f"Key {key.id} budget update not triggered, skipping update")
except ValueError:
logger.warning(f"Key {key.id} has invalid budget_reset_at timestamp: {budget_reset_at_str}")
else:
logger.warning(f"Key {key.id} has no budget_reset_at timestamp, forcing update")
await litellm_service.update_budget(**data)
else:
logger.info(f"Key {key.id} budget settings already match the expected values, no update needed")

# Set the key duration to 0 days to end its usability.
if expire_keys:
await litellm_service.update_key_duration(key.litellm_token, "0d")
Expand Down Expand Up @@ -332,15 +424,18 @@ async def monitor_teams(db: Session):
"""
logger.info("Monitoring teams")
try:
# Initialize SES service
ses_service = SESService()

# Get all teams
teams = db.query(DBTeam).all()
current_time = datetime.now(UTC)

# Track current active team labels
current_team_labels = set()
try:
# Initialize SES service
ses_service = SESService()
except Exception as e:
logger.error(f"Error initializing SES service: {str(e)}")
pass

logger.info(f"Found {len(teams)} teams to track")
for team in teams:
Expand Down Expand Up @@ -436,8 +531,24 @@ async def monitor_teams(db: Session):
if not has_products and days_remaining <= 0 and should_send_notifications:
expire_keys = True

# Monitor keys and get total spend
team_total = await monitor_team_keys(team, keys_by_region, expire_keys)
# Determine if we should check for renewal period updates
renewal_period_days = None
max_budget_amount = None
if has_products and team.last_payment:
# Get the product with the longest renewal period
active_products = db.query(DBTeamProduct).filter(
DBTeamProduct.team_id == team.id
).all()
product_ids = [tp.product_id for tp in active_products]
products = db.query(DBProduct).filter(DBProduct.id.in_(product_ids)).all()

if products:
max_renewal_product = max(products, key=lambda product: product.renewal_period_days)
renewal_period_days = max_renewal_product.renewal_period_days
max_budget_amount = max(products, key=lambda product: product.max_budget_per_key).max_budget_per_key

# Monitor keys and get total spend (includes renewal period updates if applicable)
team_total = await monitor_team_keys(team, keys_by_region, expire_keys, renewal_period_days, max_budget_amount)

# Set the total spend metric for the team (always emit metrics)
team_total_spend.labels(
Expand Down
92 changes: 35 additions & 57 deletions frontend/src/app/admin/private-ai-keys/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,16 @@ interface User {
created_at: string;
}



export default function PrivateAIKeysPage() {
const { toast } = useToast();
const queryClient = useQueryClient();
const [searchTerm, setSearchTerm] = useState('');
const [isUserSearchOpen, setIsUserSearchOpen] = useState(false);
const [selectedUser, setSelectedUser] = useState<User | null>(null);
const [loadedSpendKeys, setLoadedSpendKeys] = useState<Set<number>>(new Set());
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
const [selectedUser, setSelectedUser] = useState<User | null>(null);
const [isUserSearchOpen, setIsUserSearchOpen] = useState(false);
const [searchTerm, setSearchTerm] = useState('');

const debouncedSearchTerm = useDebounce(searchTerm, 300);

// Queries
const { data: privateAIKeys = [], isLoading: isLoadingPrivateAIKeys } = useQuery<PrivateAIKey[]>({
queryKey: ['private-ai-keys', selectedUser?.id],
queryFn: async () => {
Expand All @@ -56,15 +53,24 @@ export default function PrivateAIKeysPage() {
const response = await get(url);
const data = await response.json();
return data;
},
refetchInterval: 30000, // Refetch every 30 seconds to detect new keys
refetchIntervalInBackground: true, // Continue polling even when tab is not active
}
});

// Use shared hook for data fetching
const { teamDetails, teamMembers, spendMap, regions } = usePrivateAIKeysData(privateAIKeys, loadedSpendKeys);
// Use shared hook for data fetching (only for team details and regions)
const { teamDetails, teamMembers, regions } = usePrivateAIKeysData(privateAIKeys, new Set());

// Query to get all users for displaying emails
// Search users
const { data: users = [], isLoading: isSearching } = useQuery<User[]>({
queryKey: ['users', debouncedSearchTerm],
queryFn: async () => {
if (!debouncedSearchTerm) return [];
const response = await get(`/users?search=${debouncedSearchTerm}`);
return response.json();
},
enabled: debouncedSearchTerm.length > 0,
});

// Get all users for the dropdown
const { data: usersMap = {} } = useQuery<Record<number, User>>({
queryKey: ['users-map'],
queryFn: async () => {
Expand All @@ -77,50 +83,22 @@ export default function PrivateAIKeysPage() {
},
});

const { data: users = [], isLoading: isLoadingUsers, isFetching: isFetchingUsers } = useQuery<User[], Error, User[]>({
queryKey: ['users', debouncedSearchTerm],
queryFn: async () => {
if (!debouncedSearchTerm) return [];
await new Promise(resolve => setTimeout(resolve, 100)); // Small delay to ensure loading state shows
const response = await get(`/users/search?email=${encodeURIComponent(debouncedSearchTerm)}`);
const data = await response.json();
return data;
},
enabled: isUserSearchOpen && !!debouncedSearchTerm,
gcTime: 60000,
staleTime: 30000,
refetchOnMount: false,
refetchOnWindowFocus: false,
});

// Show loading state immediately when search term changes
const isSearching = searchTerm.length > 0 && (
isLoadingUsers ||
isFetchingUsers ||
debouncedSearchTerm !== searchTerm
);

const handleSearchChange = (value: string) => {
setSearchTerm(value);
// Prefetch the query if we have a value
if (value) {
queryClient.prefetchQuery({
queryKey: ['users', value],
queryFn: async () => {
const response = await get(`/users/search?email=${encodeURIComponent(value)}`);
const data = await response.json();
return data;
},
});
}
};

// Mutations
// Create key mutation
const createKeyMutation = useMutation({
mutationFn: async (data: { name: string; region_id: number; owner_id?: number; team_id?: number; key_type: 'full' | 'llm' | 'vector' }) => {
const endpoint = data.key_type === 'full' ? 'private-ai-keys' :
data.key_type === 'llm' ? 'private-ai-keys/token' :
'private-ai-keys/vector-db';
mutationFn: async (data: {
name: string
region_id: number
key_type: 'full' | 'llm' | 'vector'
owner_id?: number
team_id?: number
}) => {
const endpoint = data.key_type === 'full' ? '/private-ai-keys' :
data.key_type === 'llm' ? '/private-ai-keys/token' :
'/private-ai-keys/vector-db';
const response = await post(endpoint, data);
return response.json();
},
Expand All @@ -142,9 +120,11 @@ export default function PrivateAIKeysPage() {
},
});

// Delete key mutation
const deletePrivateAIKeyMutation = useMutation({
mutationFn: async (keyId: number) => {
await del(`/private-ai-keys/${keyId}`);
const response = await del(`/private-ai-keys/${keyId}`);
return response.json();
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['private-ai-keys'] });
Expand Down Expand Up @@ -172,8 +152,8 @@ export default function PrivateAIKeysPage() {
return response.json();
},
onSuccess: (data, variables) => {
// Update the spend information for this specific key
queryClient.setQueryData(['private-ai-key-spend', variables.keyId], data);
// Invalidate the specific key's spend query to refresh the data
queryClient.invalidateQueries({ queryKey: ['private-ai-key-spend', variables.keyId] });
toast({
title: 'Success',
description: 'Budget period updated successfully',
Expand Down Expand Up @@ -284,8 +264,6 @@ export default function PrivateAIKeysPage() {
isDeleting={deletePrivateAIKeyMutation.isPending}
allowModification={true}
showOwner={true}
spendMap={spendMap}
onLoadSpend={(keyId) => setLoadedSpendKeys(prev => new Set([...prev, keyId]))}
onUpdateBudget={(keyId, budgetDuration) => {
updateBudgetPeriodMutation.mutate({ keyId, budgetDuration });
}}
Expand Down
20 changes: 5 additions & 15 deletions frontend/src/app/private-ai-keys/page.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'use client';

import { useState, useCallback } from 'react';
import { useState } from 'react';
import { useMutation, useQueryClient, useQuery } from '@tanstack/react-query';
import { Card, CardContent } from '@/components/ui/card';
import { useToast } from '@/hooks/use-toast';
Expand Down Expand Up @@ -32,7 +32,6 @@ export default function DashboardPage() {
const queryClient = useQueryClient();
const { user } = useAuth();
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
const [loadedSpendKeys, setLoadedSpendKeys] = useState<Set<number>>(new Set());

// Fetch private AI keys using React Query
const { data: privateAIKeys = [] } = useQuery<PrivateAIKey[]>({
Expand All @@ -48,13 +47,8 @@ export default function DashboardPage() {
refetchOnWindowFocus: false, // Prevent unnecessary refetches
});

// Use shared hook for data fetching
const { teamDetails, teamMembers, spendMap, regions } = usePrivateAIKeysData(privateAIKeys, loadedSpendKeys);

// Load spend for a key
const loadSpend = useCallback(async (keyId: number) => {
setLoadedSpendKeys(prev => new Set([...prev, keyId]));
}, []);
// Use shared hook for data fetching (only for team details and regions)
const { teamDetails, teamMembers, regions } = usePrivateAIKeysData(privateAIKeys, new Set());

// Update budget period mutation
const updateBudgetMutation = useMutation({
Expand All @@ -63,10 +57,8 @@ export default function DashboardPage() {
return response.json();
},
onSuccess: (_, { keyId }) => {
// Refresh spend data
setLoadedSpendKeys(prev => new Set([...prev, keyId]));
// Invalidate spend queries
queryClient.invalidateQueries({ queryKey: ['private-ai-keys-spend'] });
// Invalidate the specific key's spend query to refresh the data
queryClient.invalidateQueries({ queryKey: ['private-ai-key-spend', keyId] });
toast({
title: 'Success',
description: 'Budget period updated successfully',
Expand Down Expand Up @@ -174,8 +166,6 @@ export default function DashboardPage() {
isLoading={createKeyMutation.isPending}
showOwner={true}
allowModification={false}
spendMap={spendMap}
onLoadSpend={loadSpend}
onUpdateBudget={(keyId, budgetDuration) => updateBudgetMutation.mutate({ keyId, budgetDuration })}
isDeleting={deleteKeyMutation.isPending}
isUpdatingBudget={updateBudgetMutation.isPending}
Expand Down
Loading