Skip to content

Commit cdb7577

Browse files
authored
Merge pull request #112 from amazeeio/dev
Budget Bug Update
2 parents a218b18 + c843d9c commit cdb7577

File tree

8 files changed

+932
-224
lines changed

8 files changed

+932
-224
lines changed

app/core/worker.py

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from datetime import datetime, UTC, timedelta
23
from sqlalchemy.orm import Session
34
from app.db.models import DBTeam, DBProduct, DBTeamProduct, DBPrivateAIKey, DBUser, DBRegion
@@ -17,7 +18,7 @@
1718
INVOICE_SUCCESS_EVENTS
1819
)
1920
from prometheus_client import Gauge, Counter, Summary
20-
from typing import Dict, List
21+
from typing import Dict, List, Optional
2122
from app.core.security import create_access_token
2223
from app.core.config import settings
2324
from urllib.parse import urljoin
@@ -95,6 +96,8 @@ def get_team_keys_by_region(db: Session, team_id: int) -> Dict[DBRegion, List[DB
9596
logger.info(f"Found {len(team_keys)} keys in {len(keys_by_region)} regions for team {team_id}")
9697
return keys_by_region
9798

99+
100+
98101
async def handle_stripe_event_background(event, db: Session):
99102
"""
100103
Background task to handle Stripe webhook events.
@@ -247,19 +250,28 @@ async def remove_product_from_team(db: Session, customer_id: str, product_id: st
247250
logger.error(f"Error removing product from team: {str(e)}")
248251
raise e
249252

250-
async def monitor_team_keys(team: DBTeam, keys_by_region: Dict[DBRegion, List[DBPrivateAIKey]], expire_keys: bool) -> float:
253+
async def monitor_team_keys(
254+
team: DBTeam,
255+
keys_by_region: Dict[DBRegion, List[DBPrivateAIKey]],
256+
expire_keys: bool,
257+
renewal_period_days: Optional[int] = None,
258+
max_budget_amount: Optional[float] = None
259+
) -> float:
251260
"""
252-
Monitor spend for all keys in a team across different regions.
261+
Monitor spend for all keys in a team across different regions and optionally update keys after renewal period.
253262
254263
Args:
255-
db: Database session
256264
team: The team to monitor keys for
257265
keys_by_region: Dictionary mapping regions to lists of keys
266+
expire_keys: Whether to expire keys (set duration to 0)
267+
renewal_period_days: Optional renewal period in days. If provided, will check for and update keys renewed within the last hour.
268+
max_budget_amount: Optional maximum budget amount. If provided, will update the budget amount for the keys.
258269
259270
Returns:
260271
float: Total spend across all keys for the team
261272
"""
262273
team_total = 0
274+
current_time = datetime.now(UTC)
263275

264276
# Monitor keys for each region
265277
for region, keys in keys_by_region.items():
@@ -281,6 +293,86 @@ async def monitor_team_keys(team: DBTeam, keys_by_region: Dict[DBRegion, List[DB
281293
budget = info.get("max_budget", 0) or 0.0
282294
key_alias = info.get("key_alias", f"key-{key.id}") # Fallback to key-{id} if no alias
283295

296+
# Check for renewal period update if renewal_period_days is provided
297+
if renewal_period_days is not None:
298+
# Check current values and only update if they don't match the parameters
299+
current_budget_duration = info.get("budget_duration")
300+
current_max_budget = info.get("max_budget")
301+
302+
needs_update = False
303+
data = {"litellm_token": key.litellm_token}
304+
305+
# Check if budget_duration needs updating
306+
expected_budget_duration = f"{renewal_period_days}d"
307+
if current_budget_duration != expected_budget_duration:
308+
data["budget_duration"] = expected_budget_duration
309+
needs_update = True
310+
logger.info(f"Key {key.id} budget_duration will be updated from '{current_budget_duration}' to '{expected_budget_duration}'")
311+
312+
# Check if budget_amount needs updating
313+
if max_budget_amount is not None and current_max_budget != max_budget_amount:
314+
data["budget_amount"] = max_budget_amount
315+
needs_update = True
316+
logger.info(f"Key {key.id} budget_amount will be updated from {current_max_budget} to {max_budget_amount}")
317+
318+
# Only check reset timestamp if we need to update
319+
if needs_update:
320+
budget_reset_at_str = info.get("budget_reset_at")
321+
if budget_reset_at_str:
322+
try:
323+
# Parse the budget_reset_at timestamp
324+
budget_reset_at = datetime.fromisoformat(budget_reset_at_str.replace('Z', '+00:00'))
325+
if budget_reset_at.tzinfo is None:
326+
budget_reset_at = budget_reset_at.replace(tzinfo=UTC)
327+
logger.info(f"Key {key.id} budget_reset_at_str: {budget_reset_at_str}, budget_reset_at: {budget_reset_at}")
328+
329+
# Check if budget was reset recently using heuristics
330+
# budget_reset_at represents when the next reset will occur
331+
current_spend = info.get("spend", 0) or 0.0
332+
current_budget_duration = info.get("budget_duration")
333+
334+
should_update = False
335+
update_reason = ""
336+
337+
# Heuristic 1: Check if (now + current_budget_duration) is within an hour of budget_reset_at
338+
if current_budget_duration is not None:
339+
try:
340+
# Parse current budget duration (e.g., "30d" -> 30 days)
341+
duration_match = re.match(r'(\d+)d', current_budget_duration)
342+
if duration_match:
343+
duration_days = int(duration_match.group(1))
344+
expected_reset_time = current_time + timedelta(days=duration_days)
345+
hours_diff = abs((expected_reset_time - budget_reset_at).total_seconds() / 3600)
346+
347+
if hours_diff <= 1.0:
348+
should_update = True
349+
update_reason = f"reset time alignment (within {hours_diff:.2f} hours)"
350+
except (ValueError, AttributeError):
351+
logger.warning(f"Key {key.id} has invalid budget_duration format: {current_budget_duration}")
352+
else:
353+
logger.debug(f"Key {key.id} has no budget_duration set, skipping reset time alignment heuristic")
354+
should_update = True
355+
update_reason = "no budget_duration set, forcing update"
356+
357+
# Heuristic 2: Update if amount spent is $0.00 (indicating fresh reset)
358+
if current_spend == 0.0:
359+
should_update = True
360+
update_reason = "zero spend (fresh reset)"
361+
362+
if should_update:
363+
logger.info(f"Key {key.id} budget update triggered: {update_reason}, updating budget settings")
364+
await litellm_service.update_budget(**data)
365+
logger.info(f"Updated key {key.id} budget settings")
366+
else:
367+
logger.debug(f"Key {key.id} budget update not triggered, skipping update")
368+
except ValueError:
369+
logger.warning(f"Key {key.id} has invalid budget_reset_at timestamp: {budget_reset_at_str}")
370+
else:
371+
logger.warning(f"Key {key.id} has no budget_reset_at timestamp, forcing update")
372+
await litellm_service.update_budget(**data)
373+
else:
374+
logger.info(f"Key {key.id} budget settings already match the expected values, no update needed")
375+
284376
# Set the key duration to 0 days to end its usability.
285377
if expire_keys:
286378
await litellm_service.update_key_duration(key.litellm_token, "0d")
@@ -332,15 +424,18 @@ async def monitor_teams(db: Session):
332424
"""
333425
logger.info("Monitoring teams")
334426
try:
335-
# Initialize SES service
336-
ses_service = SESService()
337-
338427
# Get all teams
339428
teams = db.query(DBTeam).all()
340429
current_time = datetime.now(UTC)
341430

342431
# Track current active team labels
343432
current_team_labels = set()
433+
try:
434+
# Initialize SES service
435+
ses_service = SESService()
436+
except Exception as e:
437+
logger.error(f"Error initializing SES service: {str(e)}")
438+
pass
344439

345440
logger.info(f"Found {len(teams)} teams to track")
346441
for team in teams:
@@ -436,8 +531,24 @@ async def monitor_teams(db: Session):
436531
if not has_products and days_remaining <= 0 and should_send_notifications:
437532
expire_keys = True
438533

439-
# Monitor keys and get total spend
440-
team_total = await monitor_team_keys(team, keys_by_region, expire_keys)
534+
# Determine if we should check for renewal period updates
535+
renewal_period_days = None
536+
max_budget_amount = None
537+
if has_products and team.last_payment:
538+
# Get the product with the longest renewal period
539+
active_products = db.query(DBTeamProduct).filter(
540+
DBTeamProduct.team_id == team.id
541+
).all()
542+
product_ids = [tp.product_id for tp in active_products]
543+
products = db.query(DBProduct).filter(DBProduct.id.in_(product_ids)).all()
544+
545+
if products:
546+
max_renewal_product = max(products, key=lambda product: product.renewal_period_days)
547+
renewal_period_days = max_renewal_product.renewal_period_days
548+
max_budget_amount = max(products, key=lambda product: product.max_budget_per_key).max_budget_per_key
549+
550+
# Monitor keys and get total spend (includes renewal period updates if applicable)
551+
team_total = await monitor_team_keys(team, keys_by_region, expire_keys, renewal_period_days, max_budget_amount)
441552

442553
# Set the total spend metric for the team (always emit metrics)
443554
team_total_spend.labels(

frontend/src/app/admin/private-ai-keys/page.tsx

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,16 @@ interface User {
3434
created_at: string;
3535
}
3636

37-
38-
3937
export default function PrivateAIKeysPage() {
4038
const { toast } = useToast();
4139
const queryClient = useQueryClient();
42-
const [searchTerm, setSearchTerm] = useState('');
43-
const [isUserSearchOpen, setIsUserSearchOpen] = useState(false);
44-
const [selectedUser, setSelectedUser] = useState<User | null>(null);
45-
const [loadedSpendKeys, setLoadedSpendKeys] = useState<Set<number>>(new Set());
4640
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
41+
const [selectedUser, setSelectedUser] = useState<User | null>(null);
42+
const [isUserSearchOpen, setIsUserSearchOpen] = useState(false);
43+
const [searchTerm, setSearchTerm] = useState('');
44+
4745
const debouncedSearchTerm = useDebounce(searchTerm, 300);
4846

49-
// Queries
5047
const { data: privateAIKeys = [], isLoading: isLoadingPrivateAIKeys } = useQuery<PrivateAIKey[]>({
5148
queryKey: ['private-ai-keys', selectedUser?.id],
5249
queryFn: async () => {
@@ -56,15 +53,24 @@ export default function PrivateAIKeysPage() {
5653
const response = await get(url);
5754
const data = await response.json();
5855
return data;
59-
},
60-
refetchInterval: 30000, // Refetch every 30 seconds to detect new keys
61-
refetchIntervalInBackground: true, // Continue polling even when tab is not active
56+
}
6257
});
6358

64-
// Use shared hook for data fetching
65-
const { teamDetails, teamMembers, spendMap, regions } = usePrivateAIKeysData(privateAIKeys, loadedSpendKeys);
59+
// Use shared hook for data fetching (only for team details and regions)
60+
const { teamDetails, teamMembers, regions } = usePrivateAIKeysData(privateAIKeys, new Set());
6661

67-
// Query to get all users for displaying emails
62+
// Search users
63+
const { data: users = [], isLoading: isSearching } = useQuery<User[]>({
64+
queryKey: ['users', debouncedSearchTerm],
65+
queryFn: async () => {
66+
if (!debouncedSearchTerm) return [];
67+
const response = await get(`/users?search=${debouncedSearchTerm}`);
68+
return response.json();
69+
},
70+
enabled: debouncedSearchTerm.length > 0,
71+
});
72+
73+
// Get all users for the dropdown
6874
const { data: usersMap = {} } = useQuery<Record<number, User>>({
6975
queryKey: ['users-map'],
7076
queryFn: async () => {
@@ -77,50 +83,22 @@ export default function PrivateAIKeysPage() {
7783
},
7884
});
7985

80-
const { data: users = [], isLoading: isLoadingUsers, isFetching: isFetchingUsers } = useQuery<User[], Error, User[]>({
81-
queryKey: ['users', debouncedSearchTerm],
82-
queryFn: async () => {
83-
if (!debouncedSearchTerm) return [];
84-
await new Promise(resolve => setTimeout(resolve, 100)); // Small delay to ensure loading state shows
85-
const response = await get(`/users/search?email=${encodeURIComponent(debouncedSearchTerm)}`);
86-
const data = await response.json();
87-
return data;
88-
},
89-
enabled: isUserSearchOpen && !!debouncedSearchTerm,
90-
gcTime: 60000,
91-
staleTime: 30000,
92-
refetchOnMount: false,
93-
refetchOnWindowFocus: false,
94-
});
95-
96-
// Show loading state immediately when search term changes
97-
const isSearching = searchTerm.length > 0 && (
98-
isLoadingUsers ||
99-
isFetchingUsers ||
100-
debouncedSearchTerm !== searchTerm
101-
);
102-
10386
const handleSearchChange = (value: string) => {
10487
setSearchTerm(value);
105-
// Prefetch the query if we have a value
106-
if (value) {
107-
queryClient.prefetchQuery({
108-
queryKey: ['users', value],
109-
queryFn: async () => {
110-
const response = await get(`/users/search?email=${encodeURIComponent(value)}`);
111-
const data = await response.json();
112-
return data;
113-
},
114-
});
115-
}
11688
};
11789

118-
// Mutations
90+
// Create key mutation
11991
const createKeyMutation = useMutation({
120-
mutationFn: async (data: { name: string; region_id: number; owner_id?: number; team_id?: number; key_type: 'full' | 'llm' | 'vector' }) => {
121-
const endpoint = data.key_type === 'full' ? 'private-ai-keys' :
122-
data.key_type === 'llm' ? 'private-ai-keys/token' :
123-
'private-ai-keys/vector-db';
92+
mutationFn: async (data: {
93+
name: string
94+
region_id: number
95+
key_type: 'full' | 'llm' | 'vector'
96+
owner_id?: number
97+
team_id?: number
98+
}) => {
99+
const endpoint = data.key_type === 'full' ? '/private-ai-keys' :
100+
data.key_type === 'llm' ? '/private-ai-keys/token' :
101+
'/private-ai-keys/vector-db';
124102
const response = await post(endpoint, data);
125103
return response.json();
126104
},
@@ -142,9 +120,11 @@ export default function PrivateAIKeysPage() {
142120
},
143121
});
144122

123+
// Delete key mutation
145124
const deletePrivateAIKeyMutation = useMutation({
146125
mutationFn: async (keyId: number) => {
147-
await del(`/private-ai-keys/${keyId}`);
126+
const response = await del(`/private-ai-keys/${keyId}`);
127+
return response.json();
148128
},
149129
onSuccess: () => {
150130
queryClient.invalidateQueries({ queryKey: ['private-ai-keys'] });
@@ -172,8 +152,8 @@ export default function PrivateAIKeysPage() {
172152
return response.json();
173153
},
174154
onSuccess: (data, variables) => {
175-
// Update the spend information for this specific key
176-
queryClient.setQueryData(['private-ai-key-spend', variables.keyId], data);
155+
// Invalidate the specific key's spend query to refresh the data
156+
queryClient.invalidateQueries({ queryKey: ['private-ai-key-spend', variables.keyId] });
177157
toast({
178158
title: 'Success',
179159
description: 'Budget period updated successfully',
@@ -284,8 +264,6 @@ export default function PrivateAIKeysPage() {
284264
isDeleting={deletePrivateAIKeyMutation.isPending}
285265
allowModification={true}
286266
showOwner={true}
287-
spendMap={spendMap}
288-
onLoadSpend={(keyId) => setLoadedSpendKeys(prev => new Set([...prev, keyId]))}
289267
onUpdateBudget={(keyId, budgetDuration) => {
290268
updateBudgetPeriodMutation.mutate({ keyId, budgetDuration });
291269
}}

frontend/src/app/private-ai-keys/page.tsx

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
'use client';
22

3-
import { useState, useCallback } from 'react';
3+
import { useState } from 'react';
44
import { useMutation, useQueryClient, useQuery } from '@tanstack/react-query';
55
import { Card, CardContent } from '@/components/ui/card';
66
import { useToast } from '@/hooks/use-toast';
@@ -32,7 +32,6 @@ export default function DashboardPage() {
3232
const queryClient = useQueryClient();
3333
const { user } = useAuth();
3434
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
35-
const [loadedSpendKeys, setLoadedSpendKeys] = useState<Set<number>>(new Set());
3635

3736
// Fetch private AI keys using React Query
3837
const { data: privateAIKeys = [] } = useQuery<PrivateAIKey[]>({
@@ -48,13 +47,8 @@ export default function DashboardPage() {
4847
refetchOnWindowFocus: false, // Prevent unnecessary refetches
4948
});
5049

51-
// Use shared hook for data fetching
52-
const { teamDetails, teamMembers, spendMap, regions } = usePrivateAIKeysData(privateAIKeys, loadedSpendKeys);
53-
54-
// Load spend for a key
55-
const loadSpend = useCallback(async (keyId: number) => {
56-
setLoadedSpendKeys(prev => new Set([...prev, keyId]));
57-
}, []);
50+
// Use shared hook for data fetching (only for team details and regions)
51+
const { teamDetails, teamMembers, regions } = usePrivateAIKeysData(privateAIKeys, new Set());
5852

5953
// Update budget period mutation
6054
const updateBudgetMutation = useMutation({
@@ -63,10 +57,8 @@ export default function DashboardPage() {
6357
return response.json();
6458
},
6559
onSuccess: (_, { keyId }) => {
66-
// Refresh spend data
67-
setLoadedSpendKeys(prev => new Set([...prev, keyId]));
68-
// Invalidate spend queries
69-
queryClient.invalidateQueries({ queryKey: ['private-ai-keys-spend'] });
60+
// Invalidate the specific key's spend query to refresh the data
61+
queryClient.invalidateQueries({ queryKey: ['private-ai-key-spend', keyId] });
7062
toast({
7163
title: 'Success',
7264
description: 'Budget period updated successfully',
@@ -174,8 +166,6 @@ export default function DashboardPage() {
174166
isLoading={createKeyMutation.isPending}
175167
showOwner={true}
176168
allowModification={false}
177-
spendMap={spendMap}
178-
onLoadSpend={loadSpend}
179169
onUpdateBudget={(keyId, budgetDuration) => updateBudgetMutation.mutate({ keyId, budgetDuration })}
180170
isDeleting={deleteKeyMutation.isPending}
181171
isUpdatingBudget={updateBudgetMutation.isPending}

0 commit comments

Comments
 (0)