Skip to content

Commit 685bc20

Browse files
committed
Improved LocalLab Shutdown
1 parent e990a5a commit 685bc20

File tree

4 files changed

+338
-243
lines changed

4 files changed

+338
-243
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,19 @@
22

33
All notable changes to LocalLab will be documented in this file.
44

5+
## [0.4.49] - 2024-04-21
6+
7+
### Fixed
8+
9+
- Fixed server shutdown issues when pressing Ctrl+C
10+
- Improved error handling during server shutdown process
11+
- Enhanced handling of asyncio.CancelledError during shutdown
12+
- Added proper handling for asyncio.Server objects during shutdown
13+
- Reduced duplicate log messages during shutdown
14+
- Added clean shutdown banner for better user experience
15+
- Improved task cancellation with proper timeout handling
16+
- Enhanced force exit mechanism to ensure clean termination
17+
518
## [0.4.48] - 2024-03-15
619

720
### Client Library Changes (v0.2.1)

locallab/core/app.py

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,23 @@ def init(backend, **kwargs):
8989
async def startup_event():
9090
"""Event that is triggered when the application starts up"""
9191
global startup_event_triggered
92-
92+
9393
# Only log once
9494
if startup_event_triggered:
9595
return
96-
96+
9797
logger.info("FastAPI application startup event triggered")
9898
startup_event_triggered = True
99-
99+
100100
# Wait a short time to ensure logs are processed
101101
await asyncio.sleep(0.5)
102-
102+
103103
# Log a special message that our callback handler will detect
104104
root_logger = logging.getLogger()
105105
root_logger.info("Application startup complete - banner display trigger")
106-
106+
107107
logger.info(f"{Fore.CYAN}Starting LocalLab server...{Style.RESET_ALL}")
108-
108+
109109
# Get HuggingFace token and set it in environment if available
110110
from ..config import get_hf_token
111111
hf_token = get_hf_token(interactive=False)
@@ -114,43 +114,43 @@ async def startup_event():
114114
logger.info(f"{Fore.GREEN}HuggingFace token loaded from configuration{Style.RESET_ALL}")
115115
else:
116116
logger.warning(f"{Fore.YELLOW}No HuggingFace token found. Some models may not be accessible.{Style.RESET_ALL}")
117-
117+
118118
# Check if ngrok should be enabled
119119
from ..cli.config import get_config_value
120120
use_ngrok = get_config_value("use_ngrok", False)
121121
if use_ngrok:
122122
from ..utils.networking import setup_ngrok
123123
port = int(os.environ.get("LOCALLAB_PORT", SERVER_PORT)) # Use SERVER_PORT as fallback
124-
124+
125125
# Handle ngrok setup synchronously since it's not async
126126
ngrok_url = setup_ngrok(port)
127127
if ngrok_url:
128128
logger.info(f"{Fore.GREEN}Ngrok tunnel established successfully{Style.RESET_ALL}")
129129
else:
130130
logger.warning("Failed to establish ngrok tunnel. Server will run locally only.")
131-
131+
132132
# Initialize cache if available
133133
if FASTAPI_CACHE_AVAILABLE:
134134
FastAPICache.init(InMemoryBackend(), prefix="locallab-cache")
135135
logger.info("FastAPICache initialized")
136136
else:
137137
logger.warning("FastAPICache not available, caching disabled")
138-
138+
139139
# Check for model specified in environment variables or CLI config
140140
model_to_load = (
141-
os.environ.get("HUGGINGFACE_MODEL") or
142-
get_config_value("model_id") or
141+
os.environ.get("HUGGINGFACE_MODEL") or
142+
get_config_value("model_id") or
143143
DEFAULT_MODEL
144144
)
145-
145+
146146
# Log model configuration
147147
logger.info(f"{Fore.CYAN}Model configuration:{Style.RESET_ALL}")
148148
logger.info(f" - Model to load: {model_to_load}")
149149
logger.info(f" - Quantization: {'Enabled - ' + os.environ.get('LOCALLAB_QUANTIZATION_TYPE', QUANTIZATION_TYPE) if os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() == 'true' else 'Disabled'}")
150150
logger.info(f" - Attention slicing: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_ATTENTION_SLICING', '').lower() == 'true' else 'Disabled'}")
151151
logger.info(f" - Flash attention: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_FLASH_ATTENTION', '').lower() == 'true' else 'Disabled'}")
152152
logger.info(f" - Better transformer: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_BETTERTRANSFORMER', '').lower() == 'true' else 'Disabled'}")
153-
153+
154154
# Start loading the model in background if specified
155155
if model_to_load:
156156
try:
@@ -166,66 +166,89 @@ async def startup_event():
166166
async def shutdown_event():
    """Cleanup tasks when the server shuts down.

    Order of operations:
      1. Unload the current model and free GPU/CPU memory.
      2. Cancel any still-pending asyncio tasks, bounded by a timeout so
         shutdown can never hang on a stuck task.
      3. Mark the server status as "stopped".
      4. If a SIGINT/SIGTERM handler set ``shutdown_event.force_exit_required``,
         start a daemon thread that force-exits the process a few seconds later
         as a last-resort guarantee of termination.

    Each stage is wrapped in its own try/except so a failure in one stage
    never prevents the later stages from running.
    """
    logger.info(f"{Fore.YELLOW}Shutting down server...{Style.RESET_ALL}")

    # Unload model to free GPU memory
    try:
        # Get current model ID before unloading so it can be logged afterwards
        current_model = model_manager.current_model

        # Unload the model
        if hasattr(model_manager, 'unload_model'):
            model_manager.unload_model()
        else:
            # Fallback if unload_model method doesn't exist
            model_manager.model = None
            model_manager.current_model = None

        # Clean up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        # Log model unloading
        if current_model:
            log_model_unloaded(current_model)

        logger.info("Model unloaded and memory freed")
    except Exception as e:
        logger.error(f"Error during shutdown cleanup: {str(e)}")

    # Clean up any pending tasks
    try:
        # Get all tasks except the current one (cancelling ourselves would
        # abort this very shutdown handler)
        current_task = asyncio.current_task()
        tasks = [t for t in asyncio.all_tasks()
                 if t is not current_task and not t.done()]

        if tasks:
            logger.debug(f"Cancelling {len(tasks)} remaining tasks")

            # Cancel all tasks
            for task in tasks:
                task.cancel()

            # Wait for tasks to complete with a timeout
            try:
                # Use wait_for with a timeout to avoid hanging
                await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3.0)
                logger.debug("All tasks cancelled successfully")
            except asyncio.TimeoutError:
                logger.warning("Timeout waiting for tasks to cancel")
            except asyncio.CancelledError:
                # This is expected during shutdown
                logger.debug("Task cancellation was itself cancelled - this is normal during shutdown")
            except Exception as e:
                logger.warning(f"Error during task cancellation: {str(e)}")
    except Exception as e:
        logger.warning(f"Error cleaning up tasks: {str(e)}")

    # Set server status to stopped
    set_server_status("stopped")

    logger.info(f"{Fore.GREEN}Server shutdown complete{Style.RESET_ALL}")

    # Only force exit if this is a true shutdown initiated by SIGINT/SIGTERM
    # (the signal handler sets the flag below before triggering shutdown)
    if hasattr(shutdown_event, 'force_exit_required') and shutdown_event.force_exit_required:
        import threading

        def force_exit():
            # Last-resort hard exit from a daemon thread; local imports keep
            # this closure self-contained even during interpreter teardown.
            import time
            import os
            time.sleep(3)  # Give a little time for clean shutdown

            # Re-check the flag: a clean shutdown may have cleared it meanwhile
            if hasattr(shutdown_event, 'force_exit_required') and shutdown_event.force_exit_required:
                logger.info("Forcing exit after shutdown to ensure clean termination")
                try:
                    # Reset the flag to avoid multiple exit attempts
                    shutdown_event.force_exit_required = False
                    os._exit(0)  # Direct exit instead of sending another signal
                except Exception:
                    # was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt
                    # are never silently swallowed
                    pass

        # Start a daemon thread that will force exit if needed
        exit_thread = threading.Thread(target=force_exit, daemon=True)
        exit_thread.start()


# Initialize the flag (default to not forcing exit)
shutdown_event.force_exit_required = False
@@ -234,40 +257,40 @@ def force_exit():
234257
async def add_process_time_header(request: Request, call_next):
    """Middleware to track request processing time.

    Logs each incoming request (except health-check endpoints, to reduce
    noise), forwards it down the middleware chain, then stamps the response
    with ``X-Process-Time`` (seconds) and ``X-Request-Count`` headers and
    warns about requests slower than one second.

    Args:
        request: The incoming HTTP request.
        call_next: Continuation that invokes the rest of the app and
            returns the response.

    Returns:
        The response produced by the downstream handlers, with timing
        headers added.
    """
    start_time = time.time()

    # Extract path and some basic params for logging
    path = request.url.path
    method = request.method
    client = request.client.host if request.client else "unknown"

    # Skip detailed logging for health check endpoints to reduce noise
    is_health_check = path.endswith("/health") or path.endswith("/startup-status")

    if not is_health_check:
        log_request(f"{method} {path}", {"client": client})

    # Process the request
    response = await call_next(request)

    # Calculate processing time
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = f"{process_time:.4f}"

    # Add request stats to response headers
    response.headers["X-Request-Count"] = str(get_request_count())

    # Log slow requests for performance monitoring (if not a health check)
    if process_time > 1.0 and not is_health_check:
        logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")

    return response
264287

265288

266289
async def load_model_in_background(model_id: str):
267290
"""Load the model asynchronously in the background"""
268291
logger.info(f"Loading model {model_id} in background...")
269292
start_time = time.time()
270-
293+
271294
try:
272295
# Ensure HF token is set before loading model
273296
from ..config import get_hf_token
@@ -277,13 +300,13 @@ async def load_model_in_background(model_id: str):
277300
logger.debug("Using HuggingFace token from configuration")
278301
else:
279302
logger.warning("No HuggingFace token found. Some models may not be accessible.")
280-
303+
281304
# Wait for the model to load
282305
await model_manager.load_model(model_id)
283-
306+
284307
# Calculate load time
285308
load_time = time.time() - start_time
286-
309+
287310
# We don't need to call log_model_loaded here since it's already done in the model_manager
288311
logger.info(f"{Fore.GREEN}Model {model_id} loaded successfully in {load_time:.2f} seconds!{Style.RESET_ALL}")
289312
except Exception as e:

0 commit comments

Comments
 (0)