Skip to content

Commit 3462302

Browse files
committed
Add tests for layernorm_gated and ssm_update with heads
1 parent b85b20d commit 3462302

File tree

2 files changed

+154
-2
lines changed

2 files changed

+154
-2
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import math
2+
3+
import torch
4+
import torch.nn.functional as F
5+
6+
import pytest
7+
8+
from einops import rearrange, repeat
9+
10+
from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref
11+
12+
13+
@pytest.mark.parametrize("norm_before_gate", [True, False])
14+
# @pytest.mark.parametrize("norm_before_gate", [False])
15+
@pytest.mark.parametrize("has_group", [False, True])
16+
# @pytest.mark.parametrize("has_group", [False])
17+
@pytest.mark.parametrize("is_rms_norm", [False, True])
18+
# @pytest.mark.parametrize("is_rms_norm", [True])
19+
@pytest.mark.parametrize("has_z", [False, True])
20+
# @pytest.mark.parametrize("has_z", [True])
21+
@pytest.mark.parametrize("has_bias", [False, True])
22+
# @pytest.mark.parametrize("has_bias", [False])
23+
# @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
24+
@pytest.mark.parametrize('dtype', [torch.float16])
25+
# @pytest.mark.parametrize("wtype", [torch.float32, torch.float16, torch.bfloat16])
26+
@pytest.mark.parametrize("wtype", [torch.float32])
27+
@pytest.mark.parametrize('d', [2048, 4096])
28+
# @pytest.mark.parametrize('d', [4096])
29+
def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate):
30+
if not has_z and not norm_before_gate:
31+
pytest.skip()
32+
if not norm_before_gate and not is_rms_norm: # Reference LN isn't implemented for this case yet
33+
pytest.skip()
34+
device = 'cuda'
35+
rtol, atol = (1e-5, 1e-5) if dtype == torch.float32 else (1e-2, 8e-3)
36+
group_size = None if not has_group else 64
37+
# set seed
38+
torch.random.manual_seed(0)
39+
batch = 16
40+
seqlen = 1024
41+
x = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
42+
if has_z:
43+
z = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
44+
else:
45+
z = None
46+
weight = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
47+
if has_bias:
48+
bias = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
49+
else:
50+
bias = None
51+
x_ref = x.detach().clone().requires_grad_()
52+
x_pt = x.detach().clone().requires_grad_()
53+
z_ref = z.detach().clone().requires_grad_() if z is not None else None
54+
z_pt = z.detach().clone().requires_grad_() if z is not None else None
55+
weight_ref = weight.detach().clone().requires_grad_()
56+
weight_pt = weight.detach().clone().requires_grad_()
57+
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
58+
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
59+
out = layernorm_fn(x, weight, bias, z=z, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate,
60+
is_rms_norm=is_rms_norm)
61+
if not is_rms_norm:
62+
if not has_group:
63+
out_ref = F.layer_norm(x_ref.float(), (d,), weight=weight_ref.float(), bias=bias_ref.float() if bias_ref is not None else None, eps=1e-5)
64+
out_pt = F.layer_norm(x_pt.to(wtype), (d,), weight=weight_pt, bias=bias_pt, eps=1e-5)
65+
else:
66+
out_ref = rearrange(F.layer_norm(rearrange(x_ref, "... (g d) -> ... g d", d=group_size).float(), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_ref.float()
67+
if has_bias:
68+
out_ref = out_ref + bias_ref.float()
69+
out_pt = rearrange(F.layer_norm(rearrange(x_pt, "... (g d) -> ... g d", d=group_size), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_pt
70+
if has_bias:
71+
out_pt = out_pt + bias_pt
72+
if has_z and norm_before_gate:
73+
out_ref = out_ref * F.silu(z_ref.float())
74+
out_pt = out_pt * F.silu(z_pt)
75+
else:
76+
out_ref = rms_norm_ref(x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size,
77+
norm_before_gate=norm_before_gate)
78+
out_pt = rms_norm_ref(x_pt, weight_pt, bias_pt, z=z_pt, eps=1e-5, group_size=group_size,
79+
norm_before_gate=norm_before_gate, upcast=False)
80+
print(f"Max diff = {(out - out_ref).abs().max().item()}")
81+
print(f"Max diff Pytorch = {(out_pt - out_ref).abs().max().item()}")
82+
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + atol
83+
84+
g = torch.randn_like(out)
85+
out.backward(g)
86+
out_ref.backward(g)
87+
out_pt.backward(g)
88+
print(f"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}")
89+
print(f"Max dx diff Pytorch = {(x_pt.grad - x_ref.grad).abs().max().item()}")
90+
if has_z:
91+
print(f"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}")
92+
print(f"Max dz diff Pytorch = {(z_pt.grad - z_ref.grad).abs().max().item()}")
93+
print(f"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}")
94+
print(f"Max dw diff Pytorch = {(weight_pt.grad - weight_ref.grad).abs().max().item()}")
95+
if has_bias:
96+
print(f"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}")
97+
print(f"Max db diff Pytorch = {(bias_pt.grad - bias_ref.grad).abs().max().item()}")
98+
assert (x.grad - x_ref.grad).abs().max().item() <= 2 * (x_pt.grad - x_ref.grad).abs().max().item() + atol
99+
if has_z:
100+
assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol
101+
assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * (weight_pt.grad - weight_ref.grad).abs().max().item() + atol
102+
if has_bias:
103+
assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * (bias_pt.grad - bias_ref.grad).abs().max().item() + atol

tests/ops/triton/test_selective_state_update.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@ def test_selective_state_update(dim, dstate, has_z, itype):
2424
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
2525
if itype == torch.bfloat16:
2626
rtol, atol = 1e-2, 5e-2
27-
2827
if torch.version.hip:
2928
atol *= 2
30-
3129
# set seed
3230
torch.random.manual_seed(0)
3331
batch_size = 2
@@ -51,3 +49,54 @@ def test_selective_state_update(dim, dstate, has_z, itype):
5149
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
5250
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
5351
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
52+
53+
54+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
55+
# @pytest.mark.parametrize('itype', [torch.float16])
56+
@pytest.mark.parametrize("has_z", [False, True])
57+
# @pytest.mark.parametrize('has_z', [True])
58+
@pytest.mark.parametrize("tie_hdim", [False, True])
59+
# @pytest.mark.parametrize('tie_hdim', [True])
60+
@pytest.mark.parametrize("ngroups", [1, 2, 4])
61+
# @pytest.mark.parametrize("ngroups", [2])
62+
@pytest.mark.parametrize("dstate", [16, 32, 64])
63+
# @pytest.mark.parametrize("dstate", [16])
64+
@pytest.mark.parametrize("dim", [2048, 4096])
65+
# @pytest.mark.parametrize("dim", [2048])
66+
def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):
67+
device = "cuda"
68+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
69+
if itype == torch.bfloat16:
70+
rtol, atol = 1e-2, 1e-1
71+
# set seed
72+
torch.random.manual_seed(0)
73+
batch_size = 2
74+
headdim = 64
75+
nheads = dim // headdim
76+
state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)
77+
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
78+
if not tie_hdim:
79+
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
80+
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
81+
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
82+
D = torch.randn(nheads, headdim, device=device)
83+
else:
84+
dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
85+
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
86+
A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
87+
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
88+
B = torch.randn(batch_size, ngroups, dstate, device=device)
89+
C = torch.randn(batch_size, ngroups, dstate, device=device)
90+
if has_z:
91+
z = torch.randn_like(x)
92+
else:
93+
z = None
94+
state_ref = state.detach().clone()
95+
state_og = state.detach().clone()
96+
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
97+
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
98+
99+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
100+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
101+
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
102+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)