@@ -42,6 +42,8 @@ def is_wandb_available():
42
42
return False
43
43
return importlib .util .find_spec ("wandb" ) is not None
44
44
45
+ def is_swanlab_available ():
46
+ return importlib .util .find_spec ("swanlab" ) is not None
45
47
46
48
def is_ray_available ():
47
49
return importlib .util .find_spec ("ray.air" ) is not None
@@ -55,6 +57,8 @@ def get_available_reporting_integrations():
55
57
integrations .append ("wandb" )
56
58
if is_tensorboardX_available ():
57
59
integrations .append ("tensorboard" )
60
+ if is_swanlab_available ():
61
+ integrations .append ("swanlab" )
58
62
59
63
return integrations
60
64
@@ -395,6 +399,90 @@ def on_save(self, args, state, control, **kwargs):
395
399
self ._wandb .log_artifact (artifact , aliases = [f"checkpoint-{ state .global_step } " ])
396
400
397
401
402
+ class SwanLabCallback (TrainerCallback ):
403
+ """
404
+ A [`TrainerCallback`] that logs metrics, media to [Swanlab](https://swanlab.com/).
405
+ """
406
+
407
+ def __init__ (self ):
408
+ has_swanlab = is_swanlab_available ()
409
+ if not has_swanlab :
410
+ raise RuntimeError ("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`." )
411
+ if has_swanlab :
412
+ import swanlab
413
+
414
+ self ._swanlab = swanlab
415
+
416
+ self ._initialized = False
417
+
418
+ def setup (self , args , state , model , ** kwargs ):
419
+ """
420
+ Setup the optional Swanlab integration.
421
+
422
+ One can subclass and override this method to customize the setup if needed.
423
+ variables:
424
+ Environment:
425
+ - **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
426
+ Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local. Set `SWANLAB_MODE="disabled"` to disable.
427
+ - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
428
+ Set this to a custom string to store results in a different project.
429
+ """
430
+
431
+ if self ._swanlab is None :
432
+ return
433
+
434
+ if args .swanlab_api_key :
435
+ self ._swanlab .login (api_key = args .swanlab_api_key )
436
+
437
+ self ._initialized = True
438
+
439
+ if state .is_world_process_zero :
440
+ logger .info (
441
+ 'Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"'
442
+ )
443
+
444
+ combined_dict = {** args .to_dict ()}
445
+
446
+ if hasattr (model , "config" ) and model .config is not None :
447
+ model_config = model .config .to_dict ()
448
+ combined_dict = {** model_config , ** combined_dict }
449
+
450
+ trial_name = state .trial_name
451
+ init_args = {}
452
+ if trial_name is not None :
453
+ init_args ["name" ] = trial_name
454
+ init_args ["group" ] = args .run_name
455
+ else :
456
+ if not (args .run_name is None or args .run_name == args .output_dir ):
457
+ init_args ["name" ] = args .run_name
458
+ init_args ["dir" ] = args .logging_dir
459
+ if self ._swanlab .run is None :
460
+ self ._swanlab .init (
461
+ project = os .getenv ("SWANLAB_PROJECT" , "PaddleNLP" ),
462
+ ** init_args ,
463
+ )
464
+ self ._swanlab .config .update (combined_dict , allow_val_change = True )
465
+
466
+ def on_train_begin (self , args , state , control , model = None , ** kwargs ):
467
+ if self ._swanlab is None :
468
+ return
469
+ if not self ._initialized :
470
+ self .setup (args , state , model , ** kwargs )
471
+
472
+ def on_train_end (self , args , state , control , model = None , tokenizer = None , ** kwargs ):
473
+ if self ._swanlab is None :
474
+ return
475
+
476
+ def on_log (self , args , state , control , model = None , logs = None , ** kwargs ):
477
+ if self ._swanlab is None :
478
+ return
479
+ if not self ._initialized :
480
+ self .setup (args , state , model )
481
+ if state .is_world_process_zero :
482
+ logs = rewrite_logs (logs )
483
+ self ._swanlab .log ({** logs , "train/global_step" : state .global_step }, step = state .global_step )
484
+
485
+
398
486
class AutoNLPCallback (TrainerCallback ):
399
487
"""
400
488
A [`TrainerCallback`] that sends the logs to [`Ray Tune`] for [`AutoNLP`]
@@ -423,6 +511,7 @@ def on_evaluate(self, args, state, control, **kwargs):
423
511
"autonlp" : AutoNLPCallback ,
424
512
"wandb" : WandbCallback ,
425
513
"tensorboard" : TensorBoardCallback ,
514
+ "swanlab" : SwanLabCallback ,
426
515
}
427
516
428
517
0 commit comments