Skip to content

Commit e016840

Browse files
test: add a unit test for additional_callbacks param of train()
Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>
1 parent 110c26e commit e016840

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/test_sft_trainer.py

+19
Original file line numberDiff line numberDiff line change
@@ -616,3 +616,22 @@ def test_bad_torch_dtype():
616616

617617
with pytest.raises(ValueError):
618618
sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_PT_ARGS)
619+
620+
621+
def test_run_with_additional_callbacks():
622+
"""Ensure that train() can work with additional_callbacks"""
623+
# Third Party
624+
from transformers.trainer_callback import TrainerCallback
625+
626+
with tempfile.TemporaryDirectory() as tempdir:
627+
train_args = copy.deepcopy(TRAIN_ARGS)
628+
train_args.output_dir = tempdir
629+
model_args = copy.deepcopy(MODEL_ARGS)
630+
631+
sft_trainer.train(
632+
model_args,
633+
DATA_ARGS,
634+
train_args,
635+
PEFT_PT_ARGS,
636+
additional_callbacks=[TrainerCallback()],
637+
)

0 commit comments

Comments
 (0)