7 changes: 5 additions & 2 deletions run_train.sh
@@ -10,8 +10,11 @@ set -ex
# use envs as local overwrites for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
# NGPU=${NGPU:-"8"}
NGPU=${NGPU:-"4"}
# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
# export LOG_RANK=${LOG_RANK:-0,1,2,3}
export LOG_RANK=${LOG_RANK:-0,1,2,3}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}

2 changes: 1 addition & 1 deletion torchtitan/config/job_config.py
@@ -626,7 +626,7 @@ class Comm:
    init_timeout_seconds: int = 300
    """Timeout for communication operations, during initialization and first train step."""

    train_timeout_seconds: int = 100
    train_timeout_seconds: int = 30
    """
    Timeout for communication operations after the first train step --
    usually a tighter bound than during initialization.
114 changes: 110 additions & 4 deletions torchtitan/distributed/expert_parallel.py
@@ -5,7 +5,8 @@
# LICENSE file in the root directory of this source tree.


from typing import Callable, Literal
from typing import Callable, Literal, Optional
import threading

import torch
import torch.nn as nn
@@ -22,6 +23,86 @@
)
from torch.distributed.tensor.parallel import ParallelStyle

class HookCoordinator:
    def __init__(self):
        # Barrier for 2 threads (forward and backward) to synchronize
        # This ensures that we always alternate at executing one compute and one comm op together
        self._execution_barrier = threading.Barrier(2)

        self._coordination_enabled = False
        self._cycle_count = 0
        self._num_layers = None

    def barrier(self):
        """Barrier for 2 threads to synchronize"""
        if not self.is_coordination_enabled():
            return

        try:
            self._execution_barrier.wait()
            print(f"Both threads ready, proceeding")
        except threading.BrokenBarrierError:
            print(f"Barrier broken - one thread has finished!")

    def enable_coordination(self, num_layers: Optional[int] = None):
        if num_layers is not None and num_layers > 0:
            self._coordination_enabled = True
            self._cycle_count = 0

            # Reset barrier
            self._execution_barrier = threading.Barrier(2)

            self._num_layers = num_layers
            print(f"Compute/Comm hook coordination ENABLED with {num_layers} MoE layers")

    def disable_coordination(self):
        self._coordination_enabled = False
        self._cycle_count = 0
        self._execution_barrier.abort()  # Break barrier to unblock threads
        print("[COORDINATION] Compute/Comm hook coordination DISABLED")

    def check_should_continue_coordination(self):
        if self._num_layers is not None and self._cycle_count >= self._num_layers:
            print("[COORDINATION] Reached target number of cycles, disabling coordination")
            return False
        return True

    def is_coordination_enabled(self):
        return self._coordination_enabled

# Global coordinator
_hook_coordinator = HookCoordinator()

class SyncHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name
        # handle edge case for transformer level boundary
        if _hook_coordinator._coordination_enabled and hook_name == "D":
            _hook_coordinator._cycle_count += 1
            print(f"[FORWARD] cycle count: {_hook_coordinator._cycle_count}", "=" * 40)
            if not _hook_coordinator.check_should_continue_coordination():
                _hook_coordinator.disable_coordination()
                return x

        _hook_coordinator.barrier()

        if _hook_coordinator.is_coordination_enabled():
            print(f"[FORWARD] finished {hook_name}_fwd")
        return x

    @staticmethod
    def backward(ctx, grad_output):
        hook_name = ctx.hook_name

        # Edge case, skip initial barrier, all subsequent backward hooks will acquire
        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
            return grad_output, None

        _hook_coordinator.barrier()
        if _hook_coordinator.is_coordination_enabled():
            print(f"[BACKWARD] finished {hook_name}_bwd")
        return grad_output, None
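
To see the hand-off HookCoordinator implements in isolation: a `threading.Barrier(2)` makes a forward thread and a backward thread rendezvous before every step, so one thread's compute can overlap the other thread's communication. A minimal standalone sketch (not part of this diff; the worker names are illustrative):

```python
import threading

barrier = threading.Barrier(2)

def worker(name: str, steps: int) -> None:
    for i in range(steps):
        # Both threads must arrive here before either proceeds,
        # mirroring HookCoordinator.barrier() at each hook boundary.
        barrier.wait()
        print(f"{name} runs step {i}")

fwd = threading.Thread(target=worker, args=("forward", 3))
bwd = threading.Thread(target=worker, args=("backward", 3))
fwd.start()
bwd.start()
fwd.join()
bwd.join()
```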

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -89,7 +170,6 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        # annotate module input placements/sharding with input_layouts
        routed_input, num_tokens_per_expert = inputs
        ep_size = device_mesh.shape[0]

        # generate the input splits and output splits for all-to-all
        with torch.no_grad():
            num_tokens_per_expert_group = all_to_all_single(
@@ -155,13 +235,39 @@ def _token_combine(self, mod, routed_output, device_mesh):
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
        """
        hooks are called in the order they are registered:
            A, dispatch, B (pre hooks)
            C, combine, D (post hooks)
        """
        inner_wrapped_module = self._wrap_with_inner_hooks(module)
        distributed_module = distribute_module(
            inner_wrapped_module,
            device_mesh,
            partition_fn=ExpertParallel._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
        final_module = self._wrap_with_outer_hooks(distributed_module)
        return final_module

    def _wrap_with_inner_hooks(self, module):
        def inner_pre_hook(module, input):
            return (SyncHook.apply(input[0], "A"),) + input[1:]
        def inner_post_hook(module, input, output):
            return SyncHook.apply(output, "C")
        module.register_forward_pre_hook(inner_pre_hook)
        module.register_forward_hook(inner_post_hook)
        return module

    def _wrap_with_outer_hooks(self, module):
        def outer_pre_hook(module, input):
            return (SyncHook.apply(input[0], "B"),) + input[1:]
        def outer_post_hook(module, input, output):
            return SyncHook.apply(output, "D")
        module.register_forward_pre_hook(outer_pre_hook)
        module.register_forward_hook(outer_post_hook)
        return module


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
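
The `_apply` docstring above relies on `nn.Module` forward hooks firing in registration order (pre hooks: A, dispatch, B; post hooks: C, combine, D). A standalone check of that ordering (not part of this diff; hook names are illustrative):

```python
import torch
import torch.nn as nn

mod = nn.Linear(4, 4)
order = []

# Pre hooks and post hooks each run in the order they were registered.
for name in ["A", "dispatch", "B"]:
    mod.register_forward_pre_hook(lambda m, inp, name=name: order.append(f"pre:{name}"))
for name in ["C", "combine", "D"]:
    mod.register_forward_hook(lambda m, inp, out, name=name: order.append(f"post:{name}"))

mod(torch.randn(2, 4))
print(order)  # ['pre:A', 'pre:dispatch', 'pre:B', 'post:C', 'post:combine', 'post:D']
```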
5 changes: 3 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -11,7 +11,7 @@
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.llama3.infra.pipeline import pipeline_llama
from torchtitan.models.llama3.infra.pipeline import pipeline_llama, pipeline_llama_tracer
from torchtitan.models.moe import MoEArgs

from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -32,7 +32,8 @@
deepseekv3_configs = {
    "debugmodel": DeepSeekV3ModelArgs(
        vocab_size=2000,
        dim=256,
        # needs at least dim 8?
        dim=16,
        inter_dim=1024,
        moe_inter_dim=256,
        n_layers=6,
21 changes: 11 additions & 10 deletions torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
print_args = false

[profiling]
enable_profiling = false
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profile_freq = 1
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

@@ -30,28 +30,29 @@ lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
warmup_steps = 0 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
local_batch_size = 4
seq_len = 4
max_norm = 1.0 # grad norm clipping
steps = 10
steps = 6
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
# dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
pipeline_parallel_schedule = "1F1B"
pipeline_parallel_degree = 2
expert_parallel_degree = 2
context_parallel_degree = 1
expert_parallel_degree = 1
pipeline_parallel_schedule = "DualPipeV"
expert_tensor_parallel_degree = 1

[checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
mode = "none" # ["none", "selective", "full"]
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
75 changes: 75 additions & 0 deletions torchtitan/models/llama3/infra/pipeline.py
@@ -25,6 +25,9 @@
    pipeline_module_split,
)

from torch.distributed.pipelining import SplitPoint, pipeline
from torch.distributed.pipelining.stage import _PipelineStage

from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
from torchtitan.tools.logging import logger

@@ -148,3 +151,75 @@ def pipeline_llama(
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_llama_tracer(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: torch.device,
    model_args: BaseModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
):
    assert (
        parallel_dims.pp_enabled
    ), "can't apply pipeline parallelism if it is not enabled"

    # if job_config.model.norm_type == "fused_rmsnorm":
    #     # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
    #     # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
    #     raise NotImplementedError(
    #         "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
    #     )
    pp_mesh = parallel_dims.world_mesh["pp"]
    pp_rank = pp_mesh.get_local_rank()
    stage_idx = pp_mesh.get_local_rank()
    layers_per_rank = model_args.n_layers // parallel_dims.pp
    split_spec = {
        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
        for i in range(1, parallel_dims.pp)
    }
    # Get example input
    input_shape = (job_config.training.local_batch_size, job_config.training.seq_len)
    assert hasattr(model_args, "vocab_size")
    input_ids = torch.randint(
        model_args.vocab_size, input_shape, dtype=torch.int64, device="meta"
    )

    # Create a pipeline representation from the model
    pipe = pipeline(
        model, mb_args=(input_ids,), split_spec=split_spec
    )
    model = pipe.get_stage_module(stage_idx)
    stage = _PipelineStage(
        stage_module=model,
        stage_index=pp_rank,
        pipe_info=pipe.pipe_info,
        device=device,
        group=pp_mesh.get_group(),
    )
    # The tracer builds a single stage per PP rank; collect it into lists so the
    # SPMD loop and schedule builder below have `stages` / `model_parts` defined.
    stages = [stage]
    model_parts = [model]

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage
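
For reference, the tracer frontend used above can be exercised outside of torchtitan. A minimal sketch (not part of this diff) of `torch.distributed.pipelining.pipeline` splitting a toy model at a fully qualified module name, analogous to the `f"layers.{i * layers_per_rank}"` split_spec; the toy model and dimensions are illustrative, and building runtime stages (e.g. `_PipelineStage`) would additionally require an initialized process group:

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import pipeline, SplitPoint

class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = ToyModel()
example_input = torch.randn(2, 16)

# Split into 2 stages: the second stage begins at "layers.2".
pipe = pipeline(
    model,
    mb_args=(example_input,),
    split_spec={"layers.2": SplitPoint.BEGINNING},
)
print(pipe.num_stages)             # 2
stage0 = pipe.get_stage_module(0)  # submodule containing layers.0 and layers.1
```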