diff --git a/run_train.sh b/run_train.sh
index 8aaf55de28..a1522719f3 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -10,8 +10,11 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
-NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+# NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3}
+export LOG_RANK=${LOG_RANK:-0,1,2,3}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index e0189c9bb3..87acd2393f 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -626,7 +626,7 @@ class Comm:
     init_timeout_seconds: int = 300
     """Timeout for communication operations, during initialization and first train step."""
 
-    train_timeout_seconds: int = 100
+    train_timeout_seconds: int = 30
     """
     Timeout for communication operations after the first train step --
     usually a tighter bound than during initialization.
diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 567df051cf..08a885c161 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Callable, Literal
+from typing import Callable, Literal, Optional
+import threading
 
 import torch
 import torch.nn as nn
@@ -22,6 +23,86 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 
+
+class HookCoordinator:
+    def __init__(self):
+        # Barrier for the two threads (forward and backward) to synchronize on.
+        # This ensures we always alternate: one compute op and one comm op execute together.
+        self._execution_barrier = threading.Barrier(2)
+
+        self._coordination_enabled = False
+        self._cycle_count = 0
+        self._num_layers = None
+
+    def barrier(self):
+        """Synchronize the two threads; no-op when coordination is disabled."""
+        if not self.is_coordination_enabled():
+            return
+
+        try:
+            self._execution_barrier.wait()
+            print("Both threads ready, proceeding")
+        except threading.BrokenBarrierError:
+            print("Barrier broken - one thread has finished!")
+
+    def enable_coordination(self, num_layers: Optional[int] = None):
+        if num_layers is not None and num_layers > 0:
+            self._coordination_enabled = True
+            self._cycle_count = 0
+
+            # Reset barrier
+            self._execution_barrier = threading.Barrier(2)
+
+            self._num_layers = num_layers
+            print(f"Compute/Comm hook coordination ENABLED with {num_layers} MoE layers")
+
+    def disable_coordination(self):
+        self._coordination_enabled = False
+        self._cycle_count = 0
+        self._execution_barrier.abort()  # Break the barrier to unblock any waiting thread
+        print("[COORDINATION] Compute/Comm hook coordination DISABLED")
+
+    def check_should_continue_coordination(self):
+        if self._num_layers is not None and self._cycle_count >= self._num_layers:
+            print("[COORDINATION] Reached target number of cycles, disabling coordination")
+            return False
+        return True
+
+    def is_coordination_enabled(self):
+        return self._coordination_enabled
+
+
+# Global coordinator
+_hook_coordinator = HookCoordinator()
+
+
+class SyncHook(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, hook_name=""):
+        ctx.hook_name = hook_name
+        # handle the edge case at the transformer-layer boundary
+        if _hook_coordinator._coordination_enabled and hook_name == "D":
+            _hook_coordinator._cycle_count += 1
+            print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
+            if not _hook_coordinator.check_should_continue_coordination():
+                _hook_coordinator.disable_coordination()
+            return x
+
+        _hook_coordinator.barrier()
+
+        if _hook_coordinator.is_coordination_enabled():
+            print(f"[FORWARD] finished {hook_name}_fwd")
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        # Edge case: skip the initial barrier; all subsequent backward hooks will wait on it.
+        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
+            return grad_output, None
+
+        _hook_coordinator.barrier()
+        if _hook_coordinator.is_coordination_enabled():
+            print(f"[BACKWARD] finished {hook_name}_bwd")
+        return grad_output, None
+
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -89,7 +170,6 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
         ep_size = device_mesh.shape[0]
-
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
             num_tokens_per_expert_group = all_to_all_single(
@@ -155,13 +235,39 @@ def _token_combine(self, mod, routed_output, device_mesh):
         return routed_output
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        return distribute_module(
-            module,
+        """
+        Hooks are called in the order they are registered:
+            A, dispatch, B (pre-hooks)
+            C, combine, D (post-hooks)
+        """
+        inner_wrapped_module = self._wrap_with_inner_hooks(module)
+        distributed_module = distribute_module(
+            inner_wrapped_module,
             device_mesh,
             partition_fn=ExpertParallel._partition_fn,
             input_fn=self._token_dispatch,
             output_fn=self._token_combine,
         )
+        final_module = self._wrap_with_outer_hooks(distributed_module)
+        return final_module
+
+    def _wrap_with_inner_hooks(self, module):
+        def inner_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "A"),) + input[1:]
+        def inner_post_hook(module, input, output):
+            return SyncHook.apply(output, "C")
+        module.register_forward_pre_hook(inner_pre_hook)
+        module.register_forward_hook(inner_post_hook)
+        return module
+
+    def _wrap_with_outer_hooks(self, module):
+        def outer_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "B"),) + input[1:]
+        def outer_post_hook(module, input, output):
+            return SyncHook.apply(output, "D")
+        module.register_forward_pre_hook(outer_pre_hook)
+        module.register_forward_hook(outer_post_hook)
+        return module
 
 
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index d0478cc961..990e66f61d 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -11,7 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.models.llama3.infra.pipeline import pipeline_llama, pipeline_llama_tracer
 from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
@@ -32,7 +32,8 @@
 deepseekv3_configs = {
     "debugmodel": DeepSeekV3ModelArgs(
         vocab_size=2000,
-        dim=256,
+        # needs at least dim 8?
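+        # (unverified assumption: the grouped-mm expert kernels appear to need dim to be a
+        # multiple of at least 8, so the debug dim is kept tiny but aligned)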
+        dim=16,
         inter_dim=1024,
         moe_inter_dim=256,
         n_layers=6,
diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
index 4314e9905f..9e67848402 100644
--- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
 print_args = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 1
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
 
@@ -30,17 +30,18 @@ lr = 8e-4
 eps = 1e-8
 
 [lr_scheduler]
-warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
+warmup_steps = 0 # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
 decay_type = "linear"
 min_lr_factor = 0.0
 
 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 4
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"
 
 [parallelism]
 data_parallel_replicate_degree = 1
@@ -48,10 +49,10 @@ data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py
index 8741b2eef4..c18dcda8f3 100644
--- a/torchtitan/models/llama3/infra/pipeline.py
+++ b/torchtitan/models/llama3/infra/pipeline.py
@@ -25,6 +25,9 @@
     pipeline_module_split,
 )
 
+from torch.distributed.pipelining import SplitPoint, pipeline
+from torch.distributed.pipelining.stage import _PipelineStage
+
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
 
@@ -148,3 +151,75 @@ def pipeline_llama(
         has_last_stage = True
 
     return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
+def pipeline_llama_tracer(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_args: BaseModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    # if job_config.model.norm_type == "fused_rmsnorm":
+    #     # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+    #     # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    #     raise NotImplementedError(
+    #         "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+    #     )
+
+    pp_mesh = parallel_dims.world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = model_args.n_layers // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input
+    input_shape = (job_config.training.local_batch_size, job_config.training.seq_len)
+    assert hasattr(model_args, "vocab_size")
+    input_ids = torch.randint(
+        model_args.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model, mb_args=(input_ids,), split_spec=split_spec
+    )
+    model = pipe.get_stage_module(stage_idx)
+    stage = _PipelineStage(
+        stage_module=model,
+        stage_index=pp_rank,
+        pipe_info=pipe.pipe_info,
+        device=device,
+        group=pp_mesh.get_group(),
+    )
+    # the tracer path builds exactly one stage / model chunk per rank
+    model_parts = [model]
+    stages = [stage]
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 8be14ecbf0..f6d4a888b8 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -65,7 +65,7 @@ def init_weights(self, init_std: float = 0.02):
 
 # TODO: keeping this for-loop implementation for comparison
 # and readability, may remove later
-@expert_parallel
+@expert_parallel  # COMMUNICATION: this decorator handles all-to-all dispatch/combine in distributed settings
 def _run_experts_for_loop(
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -101,7 +101,7 @@ def _run_experts_for_loop(
     return out
 
 
-@expert_parallel
+@expert_parallel  # COMMUNICATION: this decorator handles all-to-all dispatch/combine in distributed settings
 def _run_experts_grouped_mm(
     w1: torch.Tensor,
     w2: torch.Tensor,
@@ -370,6 +370,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, slen, dim = x.shape
         x = x.view(-1, dim)
 
+        # ================================
+        # COMPUTE PHASE 1: ROUTING
+        # ================================
+        # Determine which tokens go to which experts based on routing scores
+
         # top_scores and selected_experts_indices shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
         (
@@ -400,6 +405,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_tokens_per_expert,
         ) = self.reorderer(top_scores, selected_experts_indices)
 
+        # Prepare tokens for expert dispatch
         # shape (bs*slen*top_k, dim)
         token_indices_experts_sorted = token_indices_experts_sorted.reshape(
             -1, 1
@@ -414,21 +420,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 * top_scores_experts_sorted.reshape(-1, 1)
             ).to(x.dtype)
 
+        # ===================================================================
+        # COMMUNICATION PHASE: EXPERT DISPATCH
+        # ===================================================================
+        # DISPATCH: send tokens to their experts (potentially across different ranks).
+        # In distributed settings this call includes all-to-all communication;
+        # the @expert_parallel decorator handles the communication automatically.
+
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(routed_input, num_tokens_per_expert)
 
+        # ================================
+        # COMPUTE PHASE 2: COMBINE
+        # ================================
+        # Combine expert outputs back into the final result
+
         if not self.score_before_experts:
             routed_output = (
                 routed_output.to(torch.float32)
                 * top_scores_experts_sorted.reshape(-1, 1)
             ).to(x.dtype)
 
-        # shared expert
+        # shared expert computation (runs locally on the original input)
         if self.shared_experts is not None:
             out = self.shared_experts(x)
         else:
             out = torch.zeros_like(x)
 
+        # COMBINE: aggregate expert outputs back to their original token positions.
+        # This scatter_add combines the expert-processed tokens.
         out = out.scatter_add(
             dim=0, index=token_indices_experts_sorted, src=routed_output
         )
diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py
index 0e851d335a..f71de78d66 100644
--- a/torchtitan/tools/profiling.py
+++ b/torchtitan/tools/profiling.py
@@ -15,7 +15,7 @@
 from torchtitan.tools.logging import logger
 
 # the number of warmup steps before the active step in each profiling cycle
-WARMUP = 3
+WARMUP = 0
 # how much memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 7c49000774..c6d3246cc7 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -12,6 +12,7 @@
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.profiler import record_function
 
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
@@ -32,7 +33,17 @@
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
-
+from torch.distributed.pipelining.schedules import (
+    _Action,
+    _PipelineContext,
+    _PipelineScheduleRuntime,
+    _PipelineStageBase,
+    _wait_batch_p2p,
+    FORWARD,
+    OVERLAP_F_B,
+)
+import concurrent.futures
+import threading
 
 class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     # core configs
@@ -432,6 +443,11 @@ def forward_backward_step(
         )
 
         if parallel_dims.pp_enabled:
+            # register custom runtime functions on the pipeline schedule
+            assert isinstance(self.pp_schedule, _PipelineScheduleRuntime)
+            # self.pp_schedule.register_custom_function(FORWARD, forward_callback)
+            self.pp_schedule.register_custom_function(OVERLAP_F_B, overlap_callback)
+
             # Pipeline Parallel forward / backward inside step() call
             with self.train_context(optional_context_parallel_ctx):
                 targets, losses = (
@@ -485,15 +501,17 @@ def train_step(
                 loss = self.forward_backward_step(input_dict, labels)
                 accumulated_losses.append(loss.detach())
 
-        grad_norm = dist_utils.clip_grad_norm_(
-            [p for m in self.model_parts for p in m.parameters()],
-            self.job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=(
-                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
-            ),
-            ep_enabled=parallel_dims.ep_enabled,
-        )
+        # TODO: parameters are not DTensors here; I'm not sure why, so grad clipping is disabled for now.
+        # grad_norm = dist_utils.clip_grad_norm_(
+        #     [p for m in self.model_parts for p in m.parameters()],
+        #     self.job_config.training.max_norm,
+        #     foreach=True,
+        #     pp_mesh=(
+        #         parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
+        #     ),
+        #     ep_enabled=parallel_dims.ep_enabled,
+        # )
+        grad_norm = torch.tensor([0.0], device=self.device)
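+        # (placeholder so metrics logging still receives a grad-norm value while clipping is disabled; see TODO above)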
         self.checkpointer.maybe_wait_for_staging()
         self.optimizers.step()
         self.lr_schedulers.step()
@@ -636,7 +654,159 @@ def close(self) -> None:
         if self.metrics_processor:
             self.metrics_processor.close()
 
+
+def _count_moe_modules(model):
+    """Count the MoE modules in a model."""
+    from torchtitan.models.moe import MoE
+
+    moe_count = 0
+    for _, module in model.named_modules():
+        if isinstance(module, MoE):
+            moe_count += 1
+    return moe_count
+
+
+def overlap_callback(action: _Action, ctx: _PipelineContext):
+    """Custom callback for the OVERLAP_F_B computation that mimics the original implementation."""
+    schedule = ctx.schedule_ref
+    assert isinstance(schedule, _PipelineScheduleRuntime)
+    stage_index_to_stage: dict[int, _PipelineStageBase] = {
+        stage.stage_index: stage for stage in schedule._stages
+    }
+    assert action.sub_actions is not None
+    fwd_action = action.sub_actions[0]
+    bwd_action = action.sub_actions[1]
+
+    # Get stages
+    forward_stage_index = fwd_action.stage_index
+    forward_mb_index = fwd_action.microbatch_index
+    assert forward_mb_index is not None
+    backward_stage_index = bwd_action.stage_index
+    backward_stage = stage_index_to_stage[backward_stage_index]
+
+    # Forward setup
+    arg_mbs = ctx.arg_mbs
+    kwarg_mbs = ctx.kwarg_mbs
+    fwd_recv_ops = schedule.fwd_recv_ops
+    forward_stage = stage_index_to_stage[forward_stage_index]
+    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
+    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage
+
+    # Backward setup
+    backward_is_next_stage_on_this_rank = (
+        backward_stage.stage_index + 1 in stage_index_to_stage
+    )
+    backward_is_prev_stage_on_this_rank = (
+        backward_stage.stage_index - 1 in stage_index_to_stage
+    )
+    backward_mb_index = bwd_action.microbatch_index
+    assert backward_mb_index is not None
+    bwd_recv_ops = schedule.bwd_recv_ops
+
+    print(f"overlap_callback begin {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}", "=" * 80, torch.distributed.get_rank())
+
+    # PP communication ========================================================
+
+    # Fwd receives
+    if (
+        not forward_stage.is_first
+        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
+        and not forward_is_prev_stage_on_this_rank
+    ):
+        assert (
+            forward_stage_index,
+            forward_mb_index,
+        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
+        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))
+
+    # Bwd receives
+    if (
+        not backward_stage.is_last
+        # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
+        and not backward_is_next_stage_on_this_rank
+    ):
+        assert (
+            backward_stage_index,
+            backward_mb_index,
+        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
+        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))
+
+    # PP computation ========================================================
+    def forward_backward_overlapped():
+        from torchtitan.distributed.expert_parallel import _hook_coordinator
+
+        # TODO: num_layers is needed in case the two stages have different layer counts;
+        # we need to ensure there is no coordination beyond the smaller stage, otherwise the barrier would hang.
+        min_num_layers = min(_count_moe_modules(forward_stage.submod), _count_moe_modules(backward_stage.submod))
+        _hook_coordinator.enable_coordination(num_layers=min_num_layers)
+        if _hook_coordinator.is_coordination_enabled():
+            print("Coordination is active")
+        else:
+            print("Coordination is disabled")
+
+        main_cuda_stream = torch.cuda.current_stream()
+
+        def run_backward():
+            # Set the backward thread to use the same stream as forward
+            torch.cuda.set_stream(main_cuda_stream)
+            print(f"BACKWARD {backward_stage_index} {torch.cuda.current_stream()}")
+            # Backward ========================================================
+            with record_function(f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"):
+                loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
+                schedule.backward_counter[backward_stage_index] += 1
+                last_backward = (
+                    schedule.backward_counter[backward_stage_index] == schedule._n_microbatches
+                )
+                backward_stage.backward_one_chunk(
+                    backward_mb_index,
+                    loss=loss,
+                    full_backward=True,
+                    last_backward=last_backward,
+                )
+                grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
+                if last_backward:
+                    backward_stage.scale_grads(grad_scale_factor)
+
+            if backward_is_prev_stage_on_this_rank:
+                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                    backward_stage.get_local_bwd_output(backward_mb_index),
+                    backward_mb_index,
+                )
+
+        # Forward ========================================================
+        def run_forward():
+            print(f"FORWARD {forward_stage_index} {torch.cuda.current_stream()}")
+            output = forward_stage.forward_one_chunk(
+                forward_mb_index,
+                arg_mbs[forward_mb_index],
+                kwarg_mbs[forward_mb_index],
+            )
+            schedule._maybe_compute_loss(
+                forward_stage, output, ctx.target_mbs, forward_mb_index
+            )
+            if forward_is_next_stage_on_this_rank:
+                stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
+                    output, forward_mb_index
+                )
+
+        # Run forward and backward in parallel: backward on a side thread, forward on the main thread
+        # if _hook_coordinator.is_coordination_enabled():
+        thread = threading.Thread(target=run_backward, daemon=True)
+        thread.start()
+        run_forward()
+        thread.join()
+        # with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        #     forward_future = executor.submit(run_forward)
+        #     backward_future = executor.submit(run_backward)
+
+        #     # Wait for both to complete simultaneously
+        #     done, not_done = concurrent.futures.wait([forward_future, backward_future])
+        #     output = forward_future.result()
+        # else:
+        #     run_forward()
+        #     run_backward()
+
+        _hook_coordinator.disable_coordination()
+
+    forward_backward_overlapped()
+    print(f"overlap_callback end {forward_stage_index}:{forward_mb_index}, {backward_stage_index}:{backward_mb_index}", "=" * 80)
+
+
+import fbvscode
+fbvscode.attach_debugger()
+
 
 if __name__ == "__main__":
     init_logger()
     config_manager = ConfigManager()