@@ -89,23 +89,23 @@ def init(backend, **kwargs):
 async def startup_event():
     """Event that is triggered when the application starts up"""
     global startup_event_triggered
-
+
     # Only log once
     if startup_event_triggered:
         return
-
+
     logger.info("FastAPI application startup event triggered")
     startup_event_triggered = True
-
+
     # Wait a short time to ensure logs are processed
     await asyncio.sleep(0.5)
-
+
     # Log a special message that our callback handler will detect
     root_logger = logging.getLogger()
     root_logger.info("Application startup complete - banner display trigger")
-
+
     logger.info(f"{Fore.CYAN}Starting LocalLab server...{Style.RESET_ALL}")
-
+
     # Get HuggingFace token and set it in environment if available
     from ..config import get_hf_token
     hf_token = get_hf_token(interactive=False)
@@ -114,43 +114,43 @@ async def startup_event():
         logger.info(f"{Fore.GREEN}HuggingFace token loaded from configuration{Style.RESET_ALL}")
     else:
         logger.warning(f"{Fore.YELLOW}No HuggingFace token found. Some models may not be accessible.{Style.RESET_ALL}")
-
+
     # Check if ngrok should be enabled
     from ..cli.config import get_config_value
     use_ngrok = get_config_value("use_ngrok", False)
     if use_ngrok:
         from ..utils.networking import setup_ngrok
         port = int(os.environ.get("LOCALLAB_PORT", SERVER_PORT))  # Use SERVER_PORT as fallback
-
+
         # Handle ngrok setup synchronously since it's not async
         ngrok_url = setup_ngrok(port)
         if ngrok_url:
             logger.info(f"{Fore.GREEN}Ngrok tunnel established successfully{Style.RESET_ALL}")
         else:
             logger.warning("Failed to establish ngrok tunnel. Server will run locally only.")
-
+
     # Initialize cache if available
     if FASTAPI_CACHE_AVAILABLE:
         FastAPICache.init(InMemoryBackend(), prefix="locallab-cache")
         logger.info("FastAPICache initialized")
     else:
         logger.warning("FastAPICache not available, caching disabled")
-
+
     # Check for model specified in environment variables or CLI config
     model_to_load = (
-        os.environ.get("HUGGINGFACE_MODEL") or
-        get_config_value("model_id") or
+        os.environ.get("HUGGINGFACE_MODEL") or
+        get_config_value("model_id") or
         DEFAULT_MODEL
     )
-
+
     # Log model configuration
     logger.info(f"{Fore.CYAN}Model configuration:{Style.RESET_ALL}")
     logger.info(f"  - Model to load: {model_to_load}")
     logger.info(f"  - Quantization: {'Enabled - ' + os.environ.get('LOCALLAB_QUANTIZATION_TYPE', QUANTIZATION_TYPE) if os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() == 'true' else 'Disabled'}")
     logger.info(f"  - Attention slicing: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_ATTENTION_SLICING', '').lower() == 'true' else 'Disabled'}")
     logger.info(f"  - Flash attention: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_FLASH_ATTENTION', '').lower() == 'true' else 'Disabled'}")
     logger.info(f"  - Better transformer: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_BETTERTRANSFORMER', '').lower() == 'true' else 'Disabled'}")
-
+
     # Start loading the model in background if specified
     if model_to_load:
         try:
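Note on the model-selection lines above: the `or` chain resolves the model ID by precedence (environment variable, then CLI config, then built-in default), and any falsy value, including an unset or empty env var, falls through to the next source. A minimal standalone sketch of the same pattern; the function name and default value here are illustrative, not LocalLab's actual API:

import os

DEFAULT_MODEL = "some-org/some-model"  # placeholder default for the sketch

def resolve_model_id(cli_config: dict) -> str:
    # `or` falls through on falsy values, so an unset or empty env var
    # and a missing config key both defer to the built-in default.
    return (
        os.environ.get("HUGGINGFACE_MODEL")
        or cli_config.get("model_id")
        or DEFAULT_MODEL
    )

print(resolve_model_id({}))                    # -> the default
print(resolve_model_id({"model_id": "gpt2"}))  # -> "gpt2"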
@@ -166,66 +166,89 @@ async def startup_event():
 async def shutdown_event():
     """Cleanup tasks when the server shuts down"""
     logger.info(f"{Fore.YELLOW}Shutting down server...{Style.RESET_ALL}")
-
+
     # Unload model to free GPU memory
     try:
         # Get current model ID before unloading
         current_model = model_manager.current_model
-
+
         # Unload the model
         if hasattr(model_manager, 'unload_model'):
             model_manager.unload_model()
         else:
             # Fallback if unload_model method doesn't exist
             model_manager.model = None
             model_manager.current_model = None
-
+
         # Clean up memory
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
-
+
         # Log model unloading
         if current_model:
             log_model_unloaded(current_model)
-
+
         logger.info("Model unloaded and memory freed")
     except Exception as e:
         logger.error(f"Error during shutdown cleanup: {str(e)}")
-
+
     # Clean up any pending tasks
     try:
-        tasks = [t for t in asyncio.all_tasks()
-                 if t is not asyncio.current_task() and not t.done()]
+        # Get all tasks except the current one
+        current_task = asyncio.current_task()
+        tasks = [t for t in asyncio.all_tasks()
+                 if t is not current_task and not t.done()]
+
         if tasks:
             logger.debug(f"Cancelling {len(tasks)} remaining tasks")
+
+            # Cancel all tasks
             for task in tasks:
                 task.cancel()
-            await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Wait for tasks to complete with a timeout
+            try:
+                # Use wait_for with a timeout to avoid hanging
+                await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3.0)
+                logger.debug("All tasks cancelled successfully")
+            except asyncio.TimeoutError:
+                logger.warning("Timeout waiting for tasks to cancel")
+            except asyncio.CancelledError:
+                # This is expected during shutdown
+                logger.debug("Task cancellation was itself cancelled - this is normal during shutdown")
+            except Exception as e:
+                logger.warning(f"Error during task cancellation: {str(e)}")
     except Exception as e:
         logger.warning(f"Error cleaning up tasks: {str(e)}")
-
+
     # Set server status to stopped
     set_server_status("stopped")
-
+
     logger.info(f"{Fore.GREEN}Server shutdown complete{Style.RESET_ALL}")
-
+
     # Only force exit if this is a true shutdown initiated by SIGINT/SIGTERM
     # Check if this was triggered by an actual signal
     if hasattr(shutdown_event, 'force_exit_required') and shutdown_event.force_exit_required:
         import threading
         def force_exit():
             import time
             import os
-            import signal
             time.sleep(3)  # Give a little time for clean shutdown
-            logger.info("Forcing exit after shutdown to ensure clean termination")
-            try:
-                os._exit(0)  # Direct exit instead of sending another signal
-            except:
-                pass
-
-        threading.Thread(target=force_exit, daemon=True).start()
+
+            # Check if we need to force exit
+            if hasattr(shutdown_event, 'force_exit_required') and shutdown_event.force_exit_required:
+                logger.info("Forcing exit after shutdown to ensure clean termination")
+                try:
+                    # Reset the flag to avoid multiple exit attempts
+                    shutdown_event.force_exit_required = False
+                    os._exit(0)  # Direct exit instead of sending another signal
+                except:
+                    pass
+
+        # Start a daemon thread that will force exit if needed
+        exit_thread = threading.Thread(target=force_exit, daemon=True)
+        exit_thread.start()
 
 # Initialize the flag (default to not forcing exit)
 shutdown_event.force_exit_required = False
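The rewritten task cleanup above follows the common asyncio shutdown recipe: snapshot the pending tasks, cancel them all, then await them behind a timeout so a task that swallows or ignores cancellation cannot hang the process. A self-contained sketch of that pattern using only the standard library; the 3-second timeout mirrors the diff, while the function names here are illustrative:

import asyncio

async def cancel_pending_tasks(timeout: float = 3.0) -> None:
    # Snapshot everything except the task running this coroutine.
    current = asyncio.current_task()
    tasks = [t for t in asyncio.all_tasks() if t is not current and not t.done()]
    if not tasks:
        return
    for task in tasks:
        task.cancel()
    try:
        # gather(return_exceptions=True) collects each CancelledError as a
        # result instead of raising; wait_for bounds the wait in case a task
        # refuses to die.
        await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=timeout)
    except asyncio.TimeoutError:
        print("some tasks did not cancel in time")

async def main():
    asyncio.create_task(asyncio.sleep(60))  # stand-in for a lingering background task
    await asyncio.sleep(0)                  # yield once so the task actually starts
    await cancel_pending_tasks()

asyncio.run(main())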
@@ -234,40 +257,40 @@ def force_exit():
 async def add_process_time_header(request: Request, call_next):
     """Middleware to track request processing time"""
     start_time = time.time()
-
+
     # Extract path and some basic params for logging
     path = request.url.path
     method = request.method
     client = request.client.host if request.client else "unknown"
-
+
     # Skip detailed logging for health check endpoints to reduce noise
     is_health_check = path.endswith("/health") or path.endswith("/startup-status")
-
+
     if not is_health_check:
         log_request(f"{method} {path}", {"client": client})
-
+
     # Process the request
     response = await call_next(request)
-
+
     # Calculate processing time
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = f"{process_time:.4f}"
-
+
     # Add request stats to response headers
     response.headers["X-Request-Count"] = str(get_request_count())
-
+
     # Log slow requests for performance monitoring (if not a health check)
     if process_time > 1.0 and not is_health_check:
         logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")
-
+
     return response
 
 
 async def load_model_in_background(model_id: str):
     """Load the model asynchronously in the background"""
     logger.info(f"Loading model {model_id} in background...")
     start_time = time.time()
-
+
     try:
         # Ensure HF token is set before loading model
         from ..config import get_hf_token
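`add_process_time_header` above is written as FastAPI HTTP middleware; the registration is not shown in this diff, but such a function is typically attached with FastAPI's `@app.middleware("http")` decorator. A minimal sketch of that wiring, where the `app` instance and the `/health` endpoint are assumptions for illustration:

import time
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    # Measure wall-clock time across the downstream handler.
    start_time = time.time()
    response = await call_next(request)
    # Expose server-side latency to clients for debugging.
    response.headers["X-Process-Time"] = f"{time.time() - start_time:.4f}"
    return response

@app.get("/health")
async def health():
    return {"status": "ok"}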
@@ -277,13 +300,13 @@ async def load_model_in_background(model_id: str):
             logger.debug("Using HuggingFace token from configuration")
         else:
             logger.warning("No HuggingFace token found. Some models may not be accessible.")
-
+
         # Wait for the model to load
         await model_manager.load_model(model_id)
-
+
         # Calculate load time
         load_time = time.time() - start_time
-
+
         # We don't need to call log_model_loaded here since it's already done in the model_manager
         logger.info(f"{Fore.GREEN}Model {model_id} loaded successfully in {load_time:.2f} seconds!{Style.RESET_ALL}")
     except Exception as e:
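`load_model_in_background` is a plain coroutine, so its call site (not shown in these hunks) presumably schedules it with `asyncio.create_task` rather than awaiting it, letting the server start serving requests while the weights load. A hedged sketch of that fire-and-forget pattern, with a stub coroutine standing in for the real one:

import asyncio

async def load_model_in_background(model_id: str) -> None:
    # Stub standing in for the coroutine in this diff.
    await asyncio.sleep(0.1)
    print(f"loaded {model_id}")

background_tasks = set()

def schedule_model_load(model_id: str) -> None:
    # Fire-and-forget: keep a reference so the task isn't garbage-collected
    # mid-flight, and drop it once the task finishes.
    task = asyncio.create_task(load_model_in_background(model_id))
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)

async def main():
    schedule_model_load("gpt2")
    await asyncio.sleep(0.2)  # the server's event loop would keep running here

asyncio.run(main())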