
Commit caf5cb6

add e2e smoke test for using activation/gradient checkpointing with offload (axolotl-ai-cloud#2565)
* add e2e smoke test for using activation/gradient checkpointing with offload
* disable duplicate code check for the test
* fix relative import
* seq len too small to test this dataset with packing
* fix checkpoint patching for tests
1 parent 5dba5c8 commit caf5cb6

1 file changed: 77 additions, 0 deletions
@@ -0,0 +1,77 @@
1+
"""
2+
E2E tests for activation checkpointing
3+
"""
4+
5+
import pytest
6+
import transformers
7+
from torch.utils.checkpoint import checkpoint
8+
9+
from axolotl.cli.args import TrainerCliArgs
10+
from axolotl.common.datasets import load_datasets
11+
from axolotl.train import train
12+
from axolotl.utils.config import normalize_config, validate_config
13+
from axolotl.utils.dict import DictDefault
14+
15+
from ..utils import check_model_output_exists
16+
17+
18+
@pytest.fixture()
19+
def fix_checkpoint_after_test():
20+
yield
21+
transformers.modeling_utils.checkpoint = checkpoint
22+
23+
24+
class TestActivationCheckpointing:
25+
"""
26+
E2E tests for activation checkpointing
27+
"""
28+
29+
def test_activation_checkpointing_offload(
30+
self,
31+
temp_dir,
32+
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
33+
):
34+
# pylint: disable=duplicate-code
35+
cfg = DictDefault(
36+
{
37+
"base_model": "HuggingFaceTB/SmolLM2-135M",
38+
"sequence_len": 1024,
39+
"val_set_size": 0.0,
40+
"special_tokens": {
41+
"pad_token": "<|endoftext|>",
42+
"eos_token": "<|im_end|>",
43+
},
44+
"datasets": [
45+
{
46+
"chat_template": "chatml",
47+
"path": "mlabonne/FineTome-100k",
48+
"type": "chat_template",
49+
"split": "train[:10%]",
50+
"field_messages": "conversations",
51+
"message_field_role": "from",
52+
"message_field_content": "value",
53+
},
54+
],
55+
"num_epochs": 1,
56+
"max_steps": 5,
57+
"micro_batch_size": 1,
58+
"gradient_accumulation_steps": 1,
59+
"output_dir": temp_dir,
60+
"learning_rate": 0.00001,
61+
"optimizer": "adamw_8bit",
62+
"lr_scheduler": "cosine",
63+
"flash_attention": True,
64+
"sample_packing": True,
65+
"bf16": True,
66+
"save_safetensors": True,
67+
"gradient_checkpointing": "offload",
68+
}
69+
)
70+
71+
cfg = validate_config(cfg)
72+
normalize_config(cfg)
73+
cli_args = TrainerCliArgs()
74+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
75+
76+
train(cfg=cfg, dataset_meta=dataset_meta)
77+
check_model_output_exists(temp_dir, cfg)
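
For reference, the "offload" setting exercised above combines activation (gradient) checkpointing with moving saved activations off the GPU. A minimal sketch of the underlying idea, assuming only stock PyTorch primitives (torch.utils.checkpoint.checkpoint and torch.autograd.graph.save_on_cpu); this is illustrative and not axolotl's exact implementation:

import torch
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    """A small feed-forward block whose activations are recomputed in backward."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Recompute this block's intermediate activations during the backward
        # pass instead of keeping them in memory for the whole forward pass.
        return checkpoint(self.ff, x, use_reentrant=False)


def forward_with_offload(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Any tensors that still must be saved for backward (e.g. checkpoint
    # boundaries) are moved to CPU and copied back on demand during backward.
    with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
        return model(x)


if __name__ == "__main__":
    model = torch.nn.Sequential(Block(), Block())
    out = forward_with_offload(model, torch.randn(8, 64, requires_grad=True))
    out.sum().backward()

In axolotl, the equivalent behaviour is enabled through the gradient_checkpointing: "offload" config key, which is exactly what the test above sets.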
