Skip to content

Add unit tests for dmoe, context parallelism, and muP #1358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions tests/distributed/test_context_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2025, EleutherAI
# Licensed under the Apache 2.0 license.

"""
Unit‑tests for context‑parallelism

We patch `megatron.mpu.get_context_parallel_*` so that we don't have to set up
distributed on the GitHub CI runner. The patched helpers pretend that a
2‑way context‑parallel world is running; we then verify that:

1. `zigzag_data` returns the correct slice for each (fake) rank.
2. `RotaryEmbedding` builds `cos_cached` / `sin_cached` using the same
zig‑zag time‑indices.

"""

import torch
import pytest
import megatron.mpu as mpu
from megatron.mpu.data import zigzag_data
from megatron.model.positional_embeddings import RotaryEmbedding


@pytest.mark.parametrize("rank", [0, 1])
def test_zigzag_and_rotary(monkeypatch, rank):
    """
    Check zig-zag sharding and the rotary cache on a faked 2-way group.

    The context-parallel world size is patched to 2 and the context-parallel
    rank to the parametrized ``rank``. The sequence is cut into
    ``2 * world_size`` chunks; the test expects rank 0 to hold the first and
    last chunk and rank 1 the second and second-to-last chunk, both for
    ``zigzag_data`` and for the time indices behind ``cos_cached`` /
    ``sin_cached``.
    """
    # NOTE(review): the stubs are installed on the `megatron.mpu` package
    # namespace; confirm that `zigzag_data` and `RotaryEmbedding` resolve the
    # helpers through that namespace.
    monkeypatch.setattr(mpu, "get_context_parallel_world_size", lambda: 2)
    monkeypatch.setattr(mpu, "get_context_parallel_rank", lambda: rank)

    # Chunk positions (front, back) owned by this rank in the zig-zag layout.
    front, back = (0, -1) if rank == 0 else (1, -2)

    # --- zigzag_data -------------------------------------------------------
    seq_dim = 1
    x = torch.arange(16).view(2, 8)  # shape: (batch=2, seq=8)

    pieces = torch.chunk(x, 2 * 2, dim=seq_dim)  # 4 chunks of length 2
    expected = torch.cat((pieces[front], pieces[back]), dim=seq_dim)

    out = zigzag_data(x, seq_dim=seq_dim)
    assert torch.equal(out, expected), "zig‑zag sharding mismatch"

    # --- RotaryEmbedding cache ---------------------------------------------
    dim = 8
    rope = RotaryEmbedding(
        dim=dim,
        max_seq_len=8,
        base=10_000,
        precision=torch.float32,
        zigzag=True,
    )

    # Rebuild the time indices the cache is expected to be based on by
    # applying the same chunk selection to 0..7.
    t_pieces = torch.chunk(torch.arange(8), 2 * 2)
    expected_t = torch.cat((t_pieces[front], t_pieces[back]))

    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (10_000 ** exponents)
    freqs = torch.einsum("i,j->ij", expected_t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    references = {"cos_cached": emb.cos(), "sin_cached": emb.sin()}

    # Shapes first, then values, so a shape mismatch is reported as such.
    for attr, ref in references.items():
        assert getattr(rope, attr).shape == ref.shape
    for attr, ref in references.items():
        assert torch.allclose(getattr(rope, attr), ref, atol=1e-6)

107 changes: 107 additions & 0 deletions tests/model/test_dmoe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2025, EleutherAI
# Licensed under the Apache 2.0 license.

"""
▸ Part 1 – expert‑token helper utilities:
* `get_expert_tokens_for_rank`
* `get_expert_token_counts_for_rank`

▸ Part 2 – lightweight router (`TopKTokenChoiceRouter`)
* shape & range of returned weights / indices
* determinism under identical input

"""

import types
import torch
import pytest
import importlib


@pytest.fixture(autouse=True)
def patch_mpu(monkeypatch):
    """
    Fake a 2-way model-parallel world size for every test in this module.

    Replaces ``megatron.mpu.get_model_parallel_world_size`` with a stub that
    returns 2. ``raising=False`` lets the patch succeed even if the attribute
    does not exist yet. The rank is not set here; each test chooses its own.
    """
    import megatron.mpu as mpu

    def _fake_world_size():
        return 2

    monkeypatch.setattr(
        mpu,
        "get_model_parallel_world_size",
        _fake_world_size,
        raising=False,
    )
    yield


def _set_rank(monkeypatch, rank: int):
    """
    Make ``megatron.mpu.get_model_parallel_rank`` return ``rank``.

    The patch goes through ``monkeypatch`` so pytest undoes it when the
    calling test finishes; ``raising=False`` tolerates a missing attribute.
    """
    import megatron.mpu as mpu

    def _fake_rank():
        return rank

    monkeypatch.setattr(mpu, "get_model_parallel_rank", _fake_rank, raising=False)


# Part 1 – expert‑token split / gather helpers
@pytest.mark.parametrize("rank", [0, 1])
def test_expert_token_helpers(monkeypatch, rank):
    """
    Compare the per-rank expert-token helpers with hand-computed values.

    Six routed rows are spread over four experts as [2, 1, 0, 3]. With the
    world size faked to 2, the test expects rank 0 to get the rows and counts
    of experts 0-1 and rank 1 those of experts 2-3.
    """
    from megatron.mpu.initialize import (
        get_expert_token_counts_for_rank,
        get_expert_tokens_for_rank,
    )

    _set_rank(monkeypatch, rank)

    tokens_per_expert = torch.tensor([2, 1, 0, 3])  # len == num_experts
    routed = torch.arange(6 * 3).view(6, 3)  # shape (6, 3)

    # Cumulative counts are [2, 3, 3, 6], so experts 0-1 cover rows 0..2 and
    # experts 2-3 cover rows 3..5.
    (first_row, stop_row), want_counts = (
        ((0, 3), [2, 1]) if rank == 0 else ((3, 6), [0, 3])
    )
    want_slice = routed[first_row:stop_row]

    out_tokens = get_expert_tokens_for_rank(routed, tokens_per_expert)
    out_counts = get_expert_token_counts_for_rank(tokens_per_expert)

    assert torch.equal(out_tokens, want_slice)
    assert out_counts.tolist() == want_counts


# Part 2 – Top‑K token‑choice router
def _dummy_args(num_experts=8, top_k=2, hidden_size=16):
    """
    Build a minimal argument namespace for constructing the router under test.

    Jitter is disabled (``moe_jitter_eps`` is None) and parameters use
    float32.
    """
    fields = {
        "hidden_size": hidden_size,
        "moe_num_experts": num_experts,
        "moe_top_k": top_k,
        "moe_jitter_eps": None,
        "params_dtype": torch.float32,  # keep everything on CPU
    }
    return types.SimpleNamespace(**fields)


@pytest.mark.parametrize("top_k", [1, 2])
def test_router_shapes_and_range(top_k):
    """
    Check output shapes, value ranges and repeatability of the top-k router.

    The router is fed a ``(seq, batch, hidden)`` input and must return two
    tensors of shape ``(seq * batch, top_k)``: routing weights in ``[0, 1]``
    and expert indices in ``[0, num_experts)``. A second call in eval mode on
    the same input must reproduce the first result exactly.
    """
    router_module = importlib.import_module("megatron.model.router")
    router_cls = router_module.TopKTokenChoiceRouter

    args = _dummy_args(num_experts=5, top_k=top_k, hidden_size=32)
    router = router_cls(args, init_method=torch.nn.init.uniform_)

    seq, bs = 4, 3
    # Draw the input from a private, seeded generator so a failing run can be
    # reproduced without touching the global RNG state.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(seq, bs, args.hidden_size, generator=gen)

    w, idx = router(x)

    expected_shape = (seq * bs, top_k)
    assert w.shape == expected_shape
    assert idx.shape == expected_shape

    # Indices must be valid expert ids: check both ends of the range, since a
    # negative index would otherwise slip through.
    assert torch.all(idx >= 0)
    assert torch.all(idx < args.moe_num_experts)

    # Routing weights must lie in [0, 1].
    assert torch.all(w >= 0) and torch.all(w <= 1)

    # Identical input must give identical output (jitter is disabled in the
    # dummy args; the second call runs in eval mode).
    router.eval()
    w2, idx2 = router(x)
    assert torch.equal(w, w2)
    assert torch.equal(idx, idx2)

80 changes: 80 additions & 0 deletions tests/model/test_mup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2025, EleutherAI
# Licensed under the Apache 2.0 license.

import types
import torch
import pytest

from megatron.model.utils import get_params_for_weight_decay_optimization
from megatron.learning_rates import AnnealingLR


class TinyNet(torch.nn.Module):
    """Smallest module with one linear layer and one layer norm.

    The tests feed it to the param-group builder; the linear layer is meant
    to land in the weight-decay group and the layer norm in the no-decay one.
    """

    def __init__(self):
        super().__init__()
        width = 4
        # Registration order (lin, then norm) is kept stable on purpose.
        self.lin = torch.nn.Linear(width, width)
        self.norm = torch.nn.LayerNorm(width)


@pytest.fixture(scope="module")
def dummy_args():
    """
    Provide a module-scoped stand-in for the argument object.

    Only ``weight_decay`` (0.1) is set.
    """
    # NOTE(review): assumes the param-group builder reads no other attribute
    # from the argument object - confirm against its implementation.
    namespace = types.SimpleNamespace()
    namespace.weight_decay = 0.1
    return namespace


def _new_scheduler(optimizer, use_mup, width_mult):
    """
    Build an ``AnnealingLR`` whose ``get_lr`` always returns 0.02.

    Pinning the base learning rate keeps the tests independent of the exact
    schedule math, so only the muP scaling is exercised.

    Args:
        optimizer: optimizer whose param groups the scheduler updates.
        use_mup: forwarded to ``AnnealingLR`` as ``use_mup``.
        width_mult: forwarded to ``AnnealingLR`` as ``mup_width_multiplier``.

    Returns:
        The constructed scheduler with ``get_lr`` overridden on that instance.
    """
    sched = AnnealingLR(
        optimizer,
        start_lr=0.0,
        max_lr=0.02,
        min_lr=0.0,
        warmup_iter=0,
        total_iters=1,
        decay_style="constant",
        use_checkpoint_lr_scheduler=False,
        override_lr_scheduler=False,
        use_mup=use_mup,
        mup_width_multiplier=width_mult,
    )
    # Override on the instance, not on the class: assigning to
    # ``AnnealingLR.get_lr`` would never be undone and would leak the stub
    # into every other test that runs later in the same session.
    sched.get_lr = lambda: 0.02
    return sched


def test_param_groups_have_lr_adjust(dummy_args):
    """
    The builder must return two param groups, each tagged with ``lr_adjust``.

    A group counts as tagged when its ``lr_adjust`` entry is present and
    truthy.
    """
    param_groups = get_params_for_weight_decay_optimization(TinyNet(), dummy_args)

    assert len(param_groups) == 2

    untagged = [
        position
        for position, group in enumerate(param_groups)
        if not group.get("lr_adjust", False)
    ]
    assert not untagged, (
        "Every param‑group returned by the builder must carry lr_adjust=True "
        "so muP knows to divide its LR."
    )


@pytest.mark.parametrize("use_mup,expected_factor", [(True, 4.0), (False, 1.0)])
def test_scheduler_scales_learning_rate(monkeypatch, dummy_args, use_mup, expected_factor):
    """
    One scheduler step must set each group's LR to ``0.02 / expected_factor``.

    With ``use_mup`` enabled and a width multiplier of 4.0 the pinned base LR
    of 0.02 is expected to be divided by 4; with it disabled the LR is
    expected to stay at 0.02.
    """
    model = TinyNet()
    groups = get_params_for_weight_decay_optimization(model, dummy_args)

    # Plain SGD is enough here; only the param-group LRs are inspected.
    optimizer = torch.optim.SGD(groups, lr=0.0)
    sched = _new_scheduler(optimizer, use_mup=use_mup, width_mult=4.0)

    sched.step()

    lrs = [group["lr"] for group in optimizer.param_groups]
    expected_lr = 0.02 / expected_factor
    for position in (0, 1):
        assert pytest.approx(lrs[position], rel=1e-7) == expected_lr

Loading