From 9d5bc4f7fd616eebf5ce1b53ac2d12df3cb6fec6 Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Tue, 16 Sep 2025 14:10:44 -0700
Subject: [PATCH] [Example] export example

---
 run_train.sh                               |  6 +-
 torchtitan/models/deepseek_v3/__init__.py  |  8 +-
 .../train_configs/debug_model.toml         | 19 ++---
 torchtitan/models/llama3/infra/pipeline.py | 78 ++++++++++++++++++++++
 torchtitan/train.py                        |  3 +-
 5 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/run_train.sh b/run_train.sh
index 8aaf55de28..48c2d3770f 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -10,8 +10,10 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
-NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+# NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
+export LOG_RANK=${LOG_RANK:-0,1,2,3}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
 
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index d0478cc961..8e84e9b7af 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -11,7 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.models.llama3.infra.pipeline import pipeline_llama, pipeline_llama_tracer
 from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
@@ -32,10 +32,10 @@
 deepseekv3_configs = {
     "debugmodel": DeepSeekV3ModelArgs(
         vocab_size=2000,
-        dim=256,
+        dim=4,
         inter_dim=1024,
         moe_inter_dim=256,
-        n_layers=6,
+        n_layers=16,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
@@ -166,7 +166,7 @@
         model_cls=DeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=pipeline_llama,
+        pipelining_fn=pipeline_llama_tracer,
         build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
index 4314e9905f..908d064c73 100644
--- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
 print_args = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 5
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
 
@@ -36,11 +36,12 @@ decay_type = "linear"
 min_lr_factor = 0.0
 
 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 10
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"
 
 [parallelism]
 data_parallel_replicate_degree = 1
@@ -48,10 +49,10 @@ data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py
index 8741b2eef4..c18dcda8f3 100644
--- a/torchtitan/models/llama3/infra/pipeline.py
+++ b/torchtitan/models/llama3/infra/pipeline.py
@@ -25,6 +25,9 @@
     pipeline_module_split,
 )
 
+from torch.distributed.pipelining import SplitPoint, pipeline
+from torch.distributed.pipelining.stage import _PipelineStage
+
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
 
@@ -148,3 +151,78 @@ def pipeline_llama(
             has_last_stage = True
 
     return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
+def pipeline_llama_tracer(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_args: BaseModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    # if job_config.model.norm_type == "fused_rmsnorm":
+    #     # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+    #     # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    #     raise NotImplementedError(
+    #         "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+    #     )
+    pp_mesh = parallel_dims.world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = model_args.n_layers // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input
+    input_shape = (job_config.training.local_batch_size, job_config.training.seq_len)
+    assert hasattr(model_args, "vocab_size")
+    input_ids = torch.randint(
+        model_args.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model, mb_args=(input_ids,), split_spec=split_spec
+    )
+    model = pipe.get_stage_module(stage_idx)
+    stage = _PipelineStage(
+        stage_module=model,
+        stage_index=pp_rank,
+        pipe_info=pipe.pipe_info,
+        device=device,
+        group=pp_mesh.get_group(),
+    )
+    # Tracer path: one stage per PP rank; wrap in lists to mirror pipeline_llama.
+    stages = [stage]
+    model_parts = [model]
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 7c49000774..c97eef5766 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -636,7 +636,8 @@ def close(self) -> None:
         if self.metrics_processor:
             self.metrics_processor.close()
 
-
+import fbvscode
+fbvscode.attach_debugger()
 if __name__ == "__main__":
     init_logger()
     config_manager = ConfigManager()
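
Reviewer note (not part of the patch): pipeline_llama_tracer relies on the tracer frontend in torch.distributed.pipelining, where pipeline() traces the module and cuts it at the fully qualified submodule names listed in split_spec. With the debugmodel config in this patch (n_layers = 16, pipeline_parallel_degree = 2), layers_per_rank is 8 and the split spec becomes {"layers.8": SplitPoint.BEGINNING}. The sketch below shows the same mechanism on a made-up toy model; the names and sizes are illustrative only, not torchtitan code.

# Minimal sketch of tracer-based pipeline splitting, assuming a recent PyTorch
# that ships torch.distributed.pipelining. ToyModel is hypothetical.
import torch
import torch.nn as nn
from torch.distributed.pipelining import pipeline, SplitPoint


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
example_input = torch.randn(2, 8)  # one microbatch, used only for tracing

# Cut before "layers.2": stage 0 owns layers 0-1, stage 1 owns layers 2-3,
# analogous to the f"layers.{i * layers_per_rank}" keys in pipeline_llama_tracer.
split_spec = {"layers.2": SplitPoint.BEGINNING}
pipe = pipeline(model, mb_args=(example_input,), split_spec=split_spec)

# Each pipeline rank keeps only its own stage module, as the patch does with
# pipe.get_stage_module(stage_idx).
stage0 = pipe.get_stage_module(0)
stage1 = pipe.get_stage_module(1)
print(sum(p.numel() for p in stage0.parameters()),
      sum(p.numel() for p in stage1.parameters()))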
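
Usage note (assumption, based on the run_train.sh defaults edited above): with NGPU=4 and LOG_RANK=0,1,2,3 patched in, the example would presumably be launched by pointing CONFIG_FILE at the DeepSeek-V3 debug config, e.g. CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh. The import fbvscode / fbvscode.attach_debugger() lines added to torchtitan/train.py appear to assume an internal environment where that module is installed.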