"""
E2E tests for activation checkpointing
"""

import pytest
import transformers
from torch.utils.checkpoint import checkpoint

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists


@pytest.fixture()
def fix_checkpoint_after_test():
    yield
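    # Training with the "offload" gradient-checkpointing path patches the
    # checkpoint function used by transformers.modeling_utils; restore the
    # stock torch.utils.checkpoint.checkpoint so later tests are unaffected.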
    transformers.modeling_utils.checkpoint = checkpoint


class TestActivationCheckpointing:
    """
    E2E tests for activation checkpointing
    """

    def test_activation_checkpointing_offload(
        self,
        temp_dir,
        fix_checkpoint_after_test,  # pylint: disable=unused-argument,redefined-outer-name
    ):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "val_set_size": 0.0,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                    "eos_token": "<|im_end|>",
                },
                "datasets": [
                    {
                        "chat_template": "chatml",
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "split": "train[:10%]",
                        "field_messages": "conversations",
                        "message_field_role": "from",
                        "message_field_content": "value",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 5,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "sample_packing": True,
                "bf16": True,
                "save_safetensors": True,
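                # "offload" checkpoints activations and parks the saved
                # tensors in host memory rather than GPU memory, trading
                # transfer overhead for lower peak VRAM.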
                "gradient_checkpointing": "offload",
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
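

# Illustrative sketch (not exercised by the test): activation offloading in
# its simplest form. torch.autograd.graph.save_on_cpu parks tensors saved for
# backward in pinned host memory and copies them back when backward runs, the
# same memory-for-bandwidth trade that gradient_checkpointing="offload" makes
# above. The helper is hypothetical; axolotl's real implementation differs.
def _offload_forward_sketch(block, inputs):
    import torch  # local import: the tests above do not need torch directly

    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        return block(inputs)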