@@ -403,81 +403,132 @@ def on_save(self, args, state, control, **kwargs):
403403
404404class SwanLabCallback (TrainerCallback ):
405405 """
406- A [`TrainerCallback`] that logs metrics, media to [Swanlab ](https://swanlab.cn/).
406+ A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab ](https://swanlab.cn/).
407407 """
408408
409409 def __init__ (self ):
410- has_swanlab = is_swanlab_available ()
411- if not has_swanlab :
412- raise RuntimeError ("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`." )
413- if has_swanlab :
414- import swanlab
415-
416- self ._swanlab = swanlab
410+ if not is_swanlab_available ():
411+ raise RuntimeError ("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`." )
412+ import swanlab
417413
414+ self ._swanlab = swanlab
418415 self ._initialized = False
416+ self ._log_model = os .getenv ("SWANLAB_LOG_MODEL" , None )
419417
420418 def setup (self , args , state , model , ** kwargs ):
421419 """
422- Setup the optional Swanlab integration.
423-
424- One can subclass and override this method to customize the setup if needed.
425- variables:
420+ Setup the optional SwanLab (*swanlab*) integration.
421+ One can subclass and override this method to customize the setup if needed. Find more information
422+ [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
423+ You can also override the following environment variables. Find more information about environment
424+ variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
426425 Environment:
427- - **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
428- Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local. Set `SWANLAB_MODE="disabled"` to disable.
429- - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
430- Set this to a custom string to store results in a different project.
426+ - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
427+ Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
428+ checks if the user is already logged in. If not, the login process is initiated.
429+ - If a string is passed to the login interface, this environment variable is ignored.
430+ - If the user is already logged in, this environment variable takes precedence over locally stored
431+ login information.
432+ - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
433+ Set this to a custom string to store results in a different project. If not specified, the name of the current
434+ running directory is used.
435+ - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
436+ This environment variable specifies the storage path for log files when running in local mode.
437+ By default, logs are saved in a folder named swanlog under the working directory.
438+ - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
439+ SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
440+ local, cloud, and disabled. Note: Case-sensitive. Find more information
441+ [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
442+ - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
443+ SwanLab does not currently support the save mode functionality.This feature will be available in a future
444+ release
445+ - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
446+ Web address for the SwanLab cloud environment for private version (its free)
447+ - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
448+ API address for the SwanLab cloud environment for private version (its free)
431449 """
432-
433- if self ._swanlab is None :
434- return
435-
436450 self ._initialized = True
437451
438452 if state .is_world_process_zero :
439- logger .info ('Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"' )
440-
453+ logger .info ('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"' )
441454 combined_dict = {** args .to_dict ()}
442455
443456 if hasattr (model , "config" ) and model .config is not None :
444- model_config = model .config .to_dict ()
457+ model_config = model .config if isinstance ( model . config , dict ) else model . config .to_dict ()
445458 combined_dict = {** model_config , ** combined_dict }
446-
459+ if hasattr (model , "lora_config" ) and model .lora_config is not None :
460+ lora_config = model .lora_config if isinstance (model .lora_config , dict ) else model .lora_config .to_dict ()
461+ combined_dict = {** {"lora_config" : lora_config }, ** combined_dict }
447462 trial_name = state .trial_name
448463 init_args = {}
449- if trial_name is not None :
450- init_args ["name" ] = trial_name
451- init_args ["group" ] = args .run_name
452- else :
453- if not (args .run_name is None or args .run_name == args .output_dir ):
454- init_args ["name" ] = args .run_name
455- init_args ["dir" ] = args .logging_dir
464+ if trial_name is not None and args .run_name is not None :
465+ init_args ["experiment_name" ] = f"{ args .run_name } -{ trial_name } "
466+ elif args .run_name is not None :
467+ init_args ["experiment_name" ] = args .run_name
468+ elif trial_name is not None :
469+ init_args ["experiment_name" ] = trial_name
470+
471+ # new add this for experiment_name
472+ experiment_name = os .getenv ("SWANLAB_EXP_NAME" , None )
473+ if experiment_name is not None :
474+ init_args ["experiment_name" ] = experiment_name
475+
476+ init_args ["project" ] = os .getenv ("SWANLAB_PROJECT" , None )
477+ if args .logging_dir is not None :
478+ init_args ["logdir" ] = os .getenv ("SWANLAB_LOG_DIR" , args .logging_dir )
479+
456480 if self ._swanlab .get_run () is None :
457481 self ._swanlab .init (
458- project = os .getenv ("SWANLAB_PROJECT" , "PaddleNLP" ),
459482 ** init_args ,
460483 )
461- self ._swanlab .config .update (combined_dict , allow_val_change = True )
484+ # show paddlenlp logo!
485+ self ._swanlab .config ["FRAMEWORK" ] = "paddlenlp"
486+ # add config parameters (run may have been created manually)
487+ self ._swanlab .config .update (combined_dict )
462488
463489 def on_train_begin (self , args , state , control , model = None , ** kwargs ):
464- if self ._swanlab is None :
465- return
466490 if not self ._initialized :
467491 self .setup (args , state , model , ** kwargs )
468492
469- def on_train_end (self , args , state , control , model = None , tokenizer = None , ** kwargs ):
470- if self ._swanlab is None :
471- return
493+ def on_train_end (self , args , state , control , model = None , processing_class = None , ** kwargs ):
494+ if self ._log_model is not None and self ._initialized and state .is_world_process_zero :
495+ logger .warning (
496+ "SwanLab does not currently support the save mode functionality. "
497+ "This feature will be available in a future release."
498+ )
472499
473500 def on_log (self , args , state , control , model = None , logs = None , ** kwargs ):
474- if self ._swanlab is None :
475- return
501+ single_value_scalars = [
502+ "train_runtime" ,
503+ "train_samples_per_second" ,
504+ "train_steps_per_second" ,
505+ "train_loss" ,
506+ "total_flos" ,
507+ ]
508+
476509 if not self ._initialized :
477510 self .setup (args , state , model )
478511 if state .is_world_process_zero :
479- logs = rewrite_logs (logs )
480- self ._swanlab .log ({** logs , "train/global_step" : state .global_step }, step = state .global_step )
512+ for k , v in logs .items ():
513+ if k in single_value_scalars :
514+ self ._swanlab .log ({f"single_value/{ k } " : v }, step = state .global_step )
515+ non_scalar_logs = {k : v for k , v in logs .items () if k not in single_value_scalars }
516+ non_scalar_logs = rewrite_logs (non_scalar_logs )
517+ self ._swanlab .log ({** non_scalar_logs , "train/global_step" : state .global_step }, step = state .global_step )
518+
519+ def on_save (self , args , state , control , ** kwargs ):
520+ if self ._log_model is not None and self ._initialized and state .is_world_process_zero :
521+ logger .warning (
522+ "SwanLab does not currently support the save mode functionality. "
523+ "This feature will be available in a future release."
524+ )
525+
526+ def on_predict (self , args , state , control , metrics , ** kwargs ):
527+ if not self ._initialized :
528+ self .setup (args , state , ** kwargs )
529+ if state .is_world_process_zero :
530+ metrics = rewrite_logs (metrics )
531+ self ._swanlab .log (metrics )
481532
482533
483534class AutoNLPCallback (TrainerCallback ):
0 commit comments