Commit 7388b60

feat: add swanlabcallback
1 parent 6de50e6 commit 7388b60

File tree: 3 files changed, +125 / -1 lines changed

paddlenlp/trainer/integrations.py

Lines changed: 89 additions & 0 deletions
@@ -42,6 +42,8 @@ def is_wandb_available():
         return False
     return importlib.util.find_spec("wandb") is not None

+def is_swanlab_available():
+    return importlib.util.find_spec("swanlab") is not None

 def is_ray_available():
     return importlib.util.find_spec("ray.air") is not None
@@ -55,6 +57,8 @@ def get_available_reporting_integrations():
         integrations.append("wandb")
     if is_tensorboardX_available():
         integrations.append("tensorboard")
+    if is_swanlab_available():
+        integrations.append("swanlab")

     return integrations

@@ -395,6 +399,90 @@ def on_save(self, args, state, control, **kwargs):
         self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])


+class SwanLabCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that logs metrics and media to [Swanlab](https://swanlab.com/).
+    """
+
+    def __init__(self):
+        has_swanlab = is_swanlab_available()
+        if not has_swanlab:
+            raise RuntimeError("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`.")
+        if has_swanlab:
+            import swanlab
+
+            self._swanlab = swanlab
+
+        self._initialized = False
+
+    def setup(self, args, state, model, **kwargs):
+        """
+        Setup the optional Swanlab integration.
+
+        One can subclass and override this method to customize the setup if needed.
+
+        Environment variables:
+        - **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
+            Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local.
+            Set `SWANLAB_MODE="disabled"` to disable.
+        - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
+            Set this to a custom string to store results in a different project.
+        """
+
+        if self._swanlab is None:
+            return
+
+        if args.swanlab_api_key:
+            self._swanlab.login(api_key=args.swanlab_api_key)
+
+        self._initialized = True
+
+        if state.is_world_process_zero:
+            logger.info(
+                'Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"'
+            )
+
+            combined_dict = {**args.to_dict()}
+
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+
+            trial_name = state.trial_name
+            init_args = {}
+            if trial_name is not None:
+                init_args["name"] = trial_name
+                init_args["group"] = args.run_name
+            else:
+                if not (args.run_name is None or args.run_name == args.output_dir):
+                    init_args["name"] = args.run_name
+                    init_args["dir"] = args.logging_dir
+            if self._swanlab.run is None:
+                self._swanlab.init(
+                    project=os.getenv("SWANLAB_PROJECT", "PaddleNLP"),
+                    **init_args,
+                )
+            self._swanlab.config.update(combined_dict, allow_val_change=True)
+
+    def on_train_begin(self, args, state, control, model=None, **kwargs):
+        if self._swanlab is None:
+            return
+        if not self._initialized:
+            self.setup(args, state, model, **kwargs)
+
+    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
+        if self._swanlab is None:
+            return
+
+    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        if self._swanlab is None:
+            return
+        if not self._initialized:
+            self.setup(args, state, model)
+        if state.is_world_process_zero:
+            logs = rewrite_logs(logs)
+            self._swanlab.log({**logs, "train/global_step": state.global_step}, step=state.global_step)
+
+
 class AutoNLPCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that sends the logs to [`Ray Tune`] for [`AutoNLP`]
@@ -423,6 +511,7 @@ def on_evaluate(self, args, state, control, **kwargs):
     "autonlp": AutoNLPCallback,
     "wandb": WandbCallback,
     "tensorboard": TensorBoardCallback,
+    "swanlab": SwanLabCallback,
 }


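The snippet below is a small usage sketch, not part of the commit, showing how the new wiring can be checked once `swanlab` is installed: `is_swanlab_available()` becomes `True`, `get_available_reporting_integrations()` then lists `"swanlab"`, and a training script selects the callback by listing `"swanlab"` in `report_to`. All names used here come from the modified integrations.py.

# Hedged sketch, not part of the commit: a quick check of the new integration hooks.
# Assumes `pip install swanlab` has been run; every imported name is defined in
# paddlenlp/trainer/integrations.py as changed above.
from paddlenlp.trainer.integrations import (
    SwanLabCallback,
    get_available_reporting_integrations,
    is_swanlab_available,
)

print(is_swanlab_available())                  # True once swanlab is importable
print(get_available_reporting_integrations())  # includes "swanlab" when detected

# In a training script the same callback is attached declaratively, e.g.
# TrainingArguments(output_dir="./out", report_to=["swanlab"]).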
paddlenlp/trainer/training_args.py

Lines changed: 7 additions & 1 deletion
@@ -376,7 +376,7 @@ class TrainingArguments:
             instance of `Dataset`.
         report_to (`str` or `List[str]`, *optional*, defaults to `"visualdl"`):
             The list of integrations to report the results and logs to.
-            Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
+            Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`/`"swanlab"`.
             `"none"` for no integrations.
         ddp_find_unused_parameters (`bool`, *optional*):
             When using distributed training, the value of the flag `find_unused_parameters` passed to
@@ -385,6 +385,8 @@
             Weights & Biases (WandB) API key(s) for authentication with the WandB service.
         wandb_http_proxy (`str`, *optional*):
             Weights & Biases (WandB) http proxy for connecting with the WandB service.
+        swanlab_api_key (`str`, *optional*):
+            Swanlab API key for authentication with the Swanlab service.
         resume_from_checkpoint (`str`, *optional*):
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
@@ -888,6 +890,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Weights & Biases (WandB) http proxy for connecting with the WandB service."},
     )
+    swanlab_api_key: Optional[str] = field(
+        default=None,
+        metadata={"help": "Swanlab API key for authentication with the Swanlab service."},
+    )
     resume_from_checkpoint: Optional[str] = field(
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},

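As a hedged usage sketch (not part of the commit), the new `swanlab_api_key` field and the `SWANLAB_*` environment variables documented above would typically be combined as follows; the project name and key string are placeholders, and `SWANLAB_MODE="disabled"` keeps the example offline.

# Hedged usage sketch, not part of the commit. The key string and project name are
# placeholders; nothing is uploaded because SWANLAB_MODE is set to "disabled".
import os

from paddlenlp.trainer import TrainingArguments

os.environ["SWANLAB_PROJECT"] = "my-paddlenlp-experiments"  # overrides the "PaddleNLP" default
os.environ["SWANLAB_MODE"] = "disabled"                     # "cloud" (default) / "local" / "disabled"

args = TrainingArguments(
    output_dir="./output",
    report_to=["swanlab"],               # now supported alongside visualdl/wandb/tensorboard
    swanlab_api_key="your-swanlab-key",  # optional; SwanLabCallback.setup() forwards it to swanlab.login()
    logging_steps=20,
)
print(args.swanlab_api_key)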
tests/trainer/test_trainer_visualization.py

Lines changed: 29 additions & 0 deletions
@@ -25,6 +25,7 @@
     TensorBoardCallback,
     VisualDLCallback,
     WandbCallback,
+    SwanLabCallback,
 )
 from tests.trainer.trainer_utils import RegressionModelConfig, RegressionPretrainedModel

@@ -65,6 +66,34 @@ def test_wandbcallback(self):
         os.environ.pop("WANDB_MODE", None)
         shutil.rmtree(output_dir)

+class TestSwanlabCallback(unittest.TestCase):
+    def test_swanlabcallback(self):
+        output_dir = tempfile.mkdtemp()
+        args = TrainingArguments(
+            output_dir=output_dir,
+            max_steps=200,
+            logging_steps=20,
+            run_name="test_swanlabcallback",
+            logging_dir=output_dir,
+        )
+        state = TrainerState(trial_name="PaddleNLP")
+        control = TrainerControl()
+        config = RegressionModelConfig(a=1, b=1)
+        model = RegressionPretrainedModel(config)
+        os.environ["SWANLAB_MODE"] = "disabled"
+        swanlabcallback = SwanLabCallback()
+        self.assertFalse(swanlabcallback._initialized)
+        swanlabcallback.on_train_begin(args, state, control)
+        self.assertTrue(swanlabcallback._initialized)
+        for global_step in range(args.max_steps):
+            state.global_step = global_step
+            if global_step % args.logging_steps == 0:
+                log = {"loss": 100 - 0.4 * global_step, "learning_rate": 0.1, "global_step": global_step}
+                swanlabcallback.on_log(args, state, control, logs=log)
+        swanlabcallback.on_train_end(args, state, control, model=model)
+        swanlabcallback._swanlab.finish()
+        os.environ.pop("SWANLAB_MODE", None)
+        shutil.rmtree(output_dir)
+

 class TestTensorboardCallback(unittest.TestCase):
     def test_tbcallback(self):
