@@ -205,11 +205,23 @@ def __init__(self, config):
205
205
self ._serve_task = None
206
206
self ._socket = None
207
207
self .app = config .app
208
+ self .callback_triggered = False
208
209
209
210
async def start (self ):
210
211
self .started = True
211
212
logger .info ("Started SimpleTCPServer as fallback" )
212
213
214
+ # Trigger callback if server has one and it hasn't been triggered yet
215
+ if (hasattr (self , 'server' ) and self .server and
216
+ hasattr (self .server , 'on_startup_callback' ) and
217
+ not self .callback_triggered and
218
+ not (hasattr (self .server , 'callback_triggered' ) and self .server .callback_triggered )):
219
+ logger .info ("Executing startup callback from SimpleTCPServer.start" )
220
+ self .server .on_startup_callback ()
221
+ self .callback_triggered = True
222
+ if hasattr (self .server , 'callback_triggered' ):
223
+ self .server .callback_triggered = True
224
+
213
225
if not self ._serve_task :
214
226
self ._serve_task = asyncio .create_task (self ._run_server ())
215
227
@@ -428,6 +440,7 @@ def __init__(self, config):
428
440
super ().__init__ (config )
429
441
self .servers = [] # Initialize servers list
430
442
self .should_exit = False
443
+ self .callback_triggered = False # Flag to track if callback has been triggered
431
444
432
445
def install_signal_handlers (self ):
433
446
def handle_exit (signum , frame ):
@@ -444,21 +457,43 @@ async def startup(self, sockets=None):
444
457
try :
445
458
await super ().startup (sockets = sockets )
446
459
logger .info ("Using uvicorn's built-in Server implementation" )
460
+
461
+ # Execute callback after successful startup
462
+ # This is critical to show the running banner
463
+ if hasattr (self , 'on_startup_callback' ) and not self .callback_triggered :
464
+ logger .info ("Executing server startup callback" )
465
+ self .on_startup_callback ()
466
+ self .callback_triggered = True
467
+
447
468
except Exception as e :
448
469
logger .error (f"Error during server startup: { str (e )} " )
449
470
logger .debug (f"Server startup error details: { traceback .format_exc ()} " )
450
471
self .servers = []
472
+
473
+ # Create SimpleTCPServer as fallback
451
474
if sockets :
452
475
for socket in sockets :
453
476
server = SimpleTCPServer (config = self .config )
454
477
server .server = self
455
478
await server .start ()
456
479
self .servers .append (server )
480
+
481
+ # Make sure callback is executed for the fallback server too
482
+ if hasattr (self , 'on_startup_callback' ) and not self .callback_triggered :
483
+ logger .info ("Executing server startup callback (fallback server)" )
484
+ self .on_startup_callback ()
485
+ self .callback_triggered = True
457
486
else :
458
487
server = SimpleTCPServer (config = self .config )
459
488
server .server = self
460
489
await server .start ()
461
490
self .servers .append (server )
491
+
492
+ # Make sure callback is executed for the fallback server too
493
+ if hasattr (self , 'on_startup_callback' ) and not self .callback_triggered :
494
+ logger .info ("Executing server startup callback (fallback server)" )
495
+ self .on_startup_callback ()
496
+ self .callback_triggered = True
462
497
463
498
async def shutdown (self , sockets = None ):
464
499
logger .debug ("Starting server shutdown process" )
@@ -562,15 +597,17 @@ def start_server(use_ngrok: bool = None, port: int = None, ngrok_auth_token: Opt
562
597
logger .error (f"{ Fore .YELLOW } Please ensure all dependencies are installed: pip install -e .{ Style .RESET_ALL } " )
563
598
raise
564
599
565
- # Create a function to display the Running banner when the server is ready
566
- startup_complete = False # Flag to track if startup has been completed
600
+ # Flag to track if startup has been completed
601
+ startup_complete = [ False ] # Using a list as a mutable reference
567
602
568
603
def on_startup ():
569
- nonlocal startup_complete
570
- if startup_complete :
604
+ # Use the mutable reference to track startup
605
+ if startup_complete [ 0 ] :
571
606
return
572
607
573
608
try :
609
+ logger .info ("Server startup callback triggered" )
610
+
574
611
# Set server status to running
575
612
set_server_status ("running" )
576
613
@@ -614,12 +651,14 @@ def on_startup():
614
651
logger .debug (f"Footer display error details: { traceback .format_exc ()} " )
615
652
616
653
# Set flag to indicate startup is complete
617
- startup_complete = True
654
+ startup_complete [0 ] = True
655
+ logger .info ("Server startup display completed successfully" )
656
+
618
657
except Exception as e :
619
658
logger .error (f"Error during server startup display: { str (e )} " )
620
659
logger .debug (f"Startup display error details: { traceback .format_exc ()} " )
621
660
# Still mark startup as complete to avoid repeated attempts
622
- startup_complete = True
661
+ startup_complete [ 0 ] = True
623
662
# Ensure server status is set to running even if display fails
624
663
set_server_status ("running" )
625
664
@@ -639,21 +678,21 @@ def on_startup():
639
678
640
679
# Define the callback for Colab
641
680
async def on_startup_async ():
642
- # This will only run once due to the flag in on_startup
643
- on_startup ()
681
+ # This is an async callback that uvicorn might call
682
+ if not startup_complete [0 ]:
683
+ on_startup ()
644
684
645
685
config = uvicorn .Config (
646
686
app ,
647
687
host = "0.0.0.0" , # Bind to all interfaces in Colab
648
688
port = port ,
649
689
reload = False ,
650
690
log_level = "info" ,
651
- # Use an async callback function, not a list
652
- callback_notify = on_startup_async
691
+ callback_notify = [on_startup_async ] # Use a list for the callback
653
692
)
654
693
655
694
server = ServerWithCallback (config )
656
- server .on_startup_callback = on_startup # Set the callback
695
+ server .on_startup_callback = on_startup # Also set the direct callback
657
696
658
697
# Use the appropriate event loop method based on Python version
659
698
try :
@@ -664,7 +703,8 @@ async def on_startup_async():
664
703
if "'Server' object has no attribute 'start'" in str (e ):
665
704
# If we get the 'start' attribute error, use our SimpleTCPServer directly
666
705
logger .warning ("Falling back to direct SimpleTCPServer implementation" )
667
- direct_server = SimpleTCPServer (config = self .config )
706
+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
707
+ direct_server .server = server # Set reference to the server for callbacks
668
708
asyncio .run (direct_server .serve ())
669
709
else :
670
710
raise
@@ -679,7 +719,8 @@ async def on_startup_async():
679
719
if "'Server' object has no attribute 'start'" in str (e ):
680
720
# If we get the 'start' attribute error, use our SimpleTCPServer directly
681
721
logger .warning ("Falling back to direct SimpleTCPServer implementation" )
682
- direct_server = SimpleTCPServer (config = self .config )
722
+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
723
+ direct_server .server = server # Set reference to the server for callbacks
683
724
loop .run_until_complete (direct_server .serve ())
684
725
else :
685
726
raise
@@ -698,12 +739,11 @@ async def on_startup_async():
698
739
reload = False ,
699
740
workers = 1 ,
700
741
log_level = "info" ,
701
- # This won't be used directly, as we call on_startup in the ServerWithCallback class
702
- callback_notify = None
742
+ callback_notify = [lambda : on_startup ()] # Use a lambda to prevent immediate execution
703
743
)
704
744
705
745
server = ServerWithCallback (config )
706
- server .on_startup_callback = on_startup # Set the callback
746
+ server .on_startup_callback = on_startup # Set the callback directly
707
747
708
748
# Use asyncio.run which is more reliable
709
749
try :
@@ -714,7 +754,8 @@ async def on_startup_async():
714
754
if "'Server' object has no attribute 'start'" in str (e ):
715
755
# If we get the 'start' attribute error, use our SimpleTCPServer directly
716
756
logger .warning ("Falling back to direct SimpleTCPServer implementation" )
717
- direct_server = SimpleTCPServer (config = self .config )
757
+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
758
+ direct_server .server = server # Set reference to the server for callbacks
718
759
asyncio .run (direct_server .serve ())
719
760
else :
720
761
raise
@@ -729,13 +770,20 @@ async def on_startup_async():
729
770
if "'Server' object has no attribute 'start'" in str (e ):
730
771
# If we get the 'start' attribute error, use our SimpleTCPServer directly
731
772
logger .warning ("Falling back to direct SimpleTCPServer implementation" )
732
- direct_server = SimpleTCPServer (config = self .config )
773
+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
774
+ direct_server .server = server # Set reference to the server for callbacks
733
775
loop .run_until_complete (direct_server .serve ())
734
776
else :
735
777
raise
736
778
else :
737
779
# Re-raise other errors
738
780
raise
781
+
782
+ # If we reach here and startup hasn't completed yet, call it manually as a fallback
783
+ if not startup_complete [0 ]:
784
+ logger .warning ("Server started but startup callback wasn't triggered. Calling manually..." )
785
+ on_startup ()
786
+
739
787
except Exception as e :
740
788
logger .error (f"Server startup failed: { str (e )} " )
741
789
logger .error (traceback .format_exc ())
0 commit comments