Skip to content

Commit 2f78db6

Browse files
authored
Upgrade swanlab code (#11105)
* Upgrade swanlab code * add SWANLAB_EXP_NAME env
1 parent 5dd2855 commit 2f78db6

File tree

1 file changed

+93
-42
lines changed

1 file changed

+93
-42
lines changed

paddlenlp/trainer/integrations.py

Lines changed: 93 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -403,81 +403,132 @@ def on_save(self, args, state, control, **kwargs):
403403

404404
class SwanLabCallback(TrainerCallback):
405405
"""
406-
A [`TrainerCallback`] that logs metrics, media to [Swanlab](https://swanlab.cn/).
406+
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
407407
"""
408408

409409
def __init__(self):
410-
has_swanlab = is_swanlab_available()
411-
if not has_swanlab:
412-
raise RuntimeError("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`.")
413-
if has_swanlab:
414-
import swanlab
415-
416-
self._swanlab = swanlab
410+
if not is_swanlab_available():
411+
raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
412+
import swanlab
417413

414+
self._swanlab = swanlab
418415
self._initialized = False
416+
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
419417

420418
def setup(self, args, state, model, **kwargs):
421419
"""
422-
Setup the optional Swanlab integration.
423-
424-
One can subclass and override this method to customize the setup if needed.
425-
variables:
420+
Setup the optional SwanLab (*swanlab*) integration.
421+
One can subclass and override this method to customize the setup if needed. Find more information
422+
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
423+
You can also override the following environment variables. Find more information about environment
424+
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
426425
Environment:
427-
- **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
428-
Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local. Set `SWANLAB_MODE="disabled"` to disable.
429-
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
430-
Set this to a custom string to store results in a different project.
426+
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
427+
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
428+
checks if the user is already logged in. If not, the login process is initiated.
429+
- If a string is passed to the login interface, this environment variable is ignored.
430+
- If the user is already logged in, this environment variable takes precedence over locally stored
431+
login information.
432+
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
433+
Set this to a custom string to store results in a different project. If not specified, the name of the current
434+
running directory is used.
435+
- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
436+
This environment variable specifies the storage path for log files when running in local mode.
437+
By default, logs are saved in a folder named swanlog under the working directory.
438+
- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
439+
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
440+
local, cloud, and disabled. Note: Case-sensitive. Find more information
441+
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
442+
- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
443+
SwanLab does not currently support the save mode functionality. This feature will be available in a future
444+
release
445+
- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
446+
Web address for the SwanLab cloud environment for private version (it's free)
447+
- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
448+
API address for the SwanLab cloud environment for private version (it's free)
431449
"""
432-
433-
if self._swanlab is None:
434-
return
435-
436450
self._initialized = True
437451

438452
if state.is_world_process_zero:
439-
logger.info('Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
440-
453+
logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
441454
combined_dict = {**args.to_dict()}
442455

443456
if hasattr(model, "config") and model.config is not None:
444-
model_config = model.config.to_dict()
457+
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
445458
combined_dict = {**model_config, **combined_dict}
446-
459+
if hasattr(model, "lora_config") and model.lora_config is not None:
460+
lora_config = model.lora_config if isinstance(model.lora_config, dict) else model.lora_config.to_dict()
461+
combined_dict = {**{"lora_config": lora_config}, **combined_dict}
447462
trial_name = state.trial_name
448463
init_args = {}
449-
if trial_name is not None:
450-
init_args["name"] = trial_name
451-
init_args["group"] = args.run_name
452-
else:
453-
if not (args.run_name is None or args.run_name == args.output_dir):
454-
init_args["name"] = args.run_name
455-
init_args["dir"] = args.logging_dir
464+
if trial_name is not None and args.run_name is not None:
465+
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
466+
elif args.run_name is not None:
467+
init_args["experiment_name"] = args.run_name
468+
elif trial_name is not None:
469+
init_args["experiment_name"] = trial_name
470+
471+
# SWANLAB_EXP_NAME, when set, overrides the experiment name chosen above
472+
experiment_name = os.getenv("SWANLAB_EXP_NAME", None)
473+
if experiment_name is not None:
474+
init_args["experiment_name"] = experiment_name
475+
476+
init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
477+
if args.logging_dir is not None:
478+
init_args["logdir"] = os.getenv("SWANLAB_LOG_DIR", args.logging_dir)
479+
456480
if self._swanlab.get_run() is None:
457481
self._swanlab.init(
458-
project=os.getenv("SWANLAB_PROJECT", "PaddleNLP"),
459482
**init_args,
460483
)
461-
self._swanlab.config.update(combined_dict, allow_val_change=True)
484+
# show paddlenlp logo!
485+
self._swanlab.config["FRAMEWORK"] = "paddlenlp"
486+
# add config parameters (run may have been created manually)
487+
self._swanlab.config.update(combined_dict)
462488

463489
def on_train_begin(self, args, state, control, model=None, **kwargs):
464-
if self._swanlab is None:
465-
return
466490
if not self._initialized:
467491
self.setup(args, state, model, **kwargs)
468492

469-
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
470-
if self._swanlab is None:
471-
return
493+
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
494+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
495+
logger.warning(
496+
"SwanLab does not currently support the save mode functionality. "
497+
"This feature will be available in a future release."
498+
)
472499

473500
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
474-
if self._swanlab is None:
475-
return
501+
single_value_scalars = [
502+
"train_runtime",
503+
"train_samples_per_second",
504+
"train_steps_per_second",
505+
"train_loss",
506+
"total_flos",
507+
]
508+
476509
if not self._initialized:
477510
self.setup(args, state, model)
478511
if state.is_world_process_zero:
479-
logs = rewrite_logs(logs)
480-
self._swanlab.log({**logs, "train/global_step": state.global_step}, step=state.global_step)
512+
for k, v in logs.items():
513+
if k in single_value_scalars:
514+
self._swanlab.log({f"single_value/{k}": v}, step=state.global_step)
515+
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
516+
non_scalar_logs = rewrite_logs(non_scalar_logs)
517+
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step}, step=state.global_step)
518+
519+
def on_save(self, args, state, control, **kwargs):
520+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
521+
logger.warning(
522+
"SwanLab does not currently support the save mode functionality. "
523+
"This feature will be available in a future release."
524+
)
525+
526+
def on_predict(self, args, state, control, metrics, **kwargs):
527+
if not self._initialized:
528+
self.setup(args, state, **kwargs)
529+
if state.is_world_process_zero:
530+
metrics = rewrite_logs(metrics)
531+
self._swanlab.log(metrics)
481532

482533

483534
class AutoNLPCallback(TrainerCallback):

0 commit comments

Comments
 (0)