Skip to content

Commit c956a30

Browse files
Mamba2 changes from #10909
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
1 parent 92e793d commit c956a30

File tree

8 files changed

+2827
-1
lines changed

8 files changed

+2827
-1
lines changed

tests/kernels/test_mamba_ssm_ssd.py

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
from typing import Dict, Tuple
2+
3+
import pytest
4+
import torch
5+
import torch.nn.functional as F
6+
from einops import rearrange, repeat
7+
8+
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
9+
mamba_chunk_scan_combined)
10+
from vllm.platforms import current_platform
11+
12+
# Added by the IBM Team, 2024
13+
14+
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
15+
16+
17+
# this is the segsum implementation taken from above
18+
def segsum(x):
19+
"""Calculates segment sum."""
20+
T = x.size(-1)
21+
x = repeat(x, "... d -> ... d e", e=T)
22+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
23+
diagonal=-1)
24+
x = x.masked_fill(~mask, 0)
25+
x_segsum = torch.cumsum(x, dim=-2)
26+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
27+
diagonal=0)
28+
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
29+
return x_segsum
30+
31+
32+
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
33+
"""
34+
Arguments:
35+
X: (batch, length, n_heads, d_head)
36+
A: (batch, length, n_heads)
37+
B: (batch, length, n_heads, d_state)
38+
C: (batch, length, n_heads, d_state)
39+
Return:
40+
Y: (batch, length, n_heads, d_head)
41+
"""
42+
assert X.dtype == A.dtype == B.dtype == C.dtype
43+
assert X.shape[1] % block_len == 0
44+
45+
# Rearrange into blocks/chunks
46+
X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
47+
for x in (X, A, B, C))
48+
49+
A = rearrange(A, "b c l h -> b h c l")
50+
A_cumsum = torch.cumsum(A, dim=-1)
51+
52+
# 1. Compute the output for each intra-chunk (diagonal blocks)
53+
L = torch.exp(segsum(A))
54+
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
55+
56+
# 2. Compute the state for each intra-chunk
57+
# (right term of low-rank factorization of off-diagonal blocks; B terms)
58+
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
59+
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
60+
61+
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
62+
# chunk boundaries
63+
# (middle term of factorization of off-diag blocks; A terms)
64+
if initial_states is None:
65+
initial_states = torch.zeros_like(states[:, :1])
66+
states = torch.cat([initial_states, states], dim=1)
67+
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
68+
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
69+
states, final_state = new_states[:, :-1], new_states[:, -1]
70+
71+
# 4. Compute state -> output conversion per chunk
72+
# (left term of low-rank factorization of off-diagonal blocks; C terms)
73+
state_decay_out = torch.exp(A_cumsum)
74+
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
75+
76+
# Add output of intra-chunk and inter-chunk terms
77+
# (diagonal and off-diagonal blocks)
78+
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
79+
return Y, final_state
80+
81+
82+
def generate_random_inputs(batch_size,
83+
seqlen,
84+
n_heads,
85+
d_head,
86+
itype,
87+
device='cuda'):
88+
89+
current_platform.seed_everything(0)
90+
A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
91+
dt = F.softplus(
92+
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
93+
4)
94+
X = torch.randn((batch_size, seqlen, n_heads, d_head),
95+
dtype=itype,
96+
device=device)
97+
B = torch.randn((batch_size, seqlen, n_heads, d_head),
98+
dtype=itype,
99+
device=device)
100+
C = torch.randn((batch_size, seqlen, n_heads, d_head),
101+
dtype=itype,
102+
device=device)
103+
104+
return A, dt, X, B, C
105+
106+
107+
def generate_continous_batched_examples(example_lens_by_batch,
108+
num_examples,
109+
full_length,
110+
last_taken,
111+
exhausted,
112+
n_heads,
113+
d_head,
114+
itype,
115+
device='cuda'):
116+
117+
# this function generates a random examples of certain length
118+
# and then cut according to "example_lens_by_batch" and feed
119+
# them in continuous batches to the kernels
120+
121+
# generate the full-length example
122+
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
123+
d_head, itype)
124+
125+
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
126+
A * dt,
127+
B,
128+
C,
129+
block_len=full_length // 4)
130+
131+
# internal function that outputs a cont batch of examples
132+
# given a tuple of lengths for each example in the batch
133+
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
134+
# 4 examples from second eg, etc
135+
def get_continuous_batch(example_lens: Tuple[int, ...]):
136+
137+
indices = []
138+
for i, x in enumerate(example_lens):
139+
c = last_taken.get(i, 0)
140+
indices.append((c, c + x))
141+
last_taken[i] = (c + x) % full_length
142+
exhausted[i] = last_taken[i] == 0
143+
144+
return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
145+
]).unsqueeze(0) for x in (dt, X, B, C))
146+
147+
# internal function that maps "n" to the appropriate right boundary
148+
# value when forming continuous batches from examples of length given
149+
# by "full_length".
150+
# - e.g., when n > full_length, returns n % full_length
151+
# when n == full_length, returns full_length
152+
def end_boundary(n: int):
153+
return n - ((n - 1) // full_length) * full_length
154+
155+
IND_E = None
156+
for spec in example_lens_by_batch:
157+
158+
# get the (maybe partial) example seen in this cont batch
159+
dt2, X2, B2, C2 = get_continuous_batch(spec)
160+
161+
# get the metadata
162+
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
163+
sed_idx = torch.zeros(cu_seqlens[-1],
164+
dtype=torch.int32,
165+
device=cu_seqlens.device)
166+
for i, (srt, end) in enumerate(zip(
167+
cu_seqlens,
168+
cu_seqlens[1:],
169+
)):
170+
sed_idx[srt:end] = i
171+
172+
# for cont batch
173+
if IND_E is None:
174+
IND_S = [0 for _ in range(len(spec))]
175+
else:
176+
IND_S = [x % full_length for x in IND_E]
177+
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
178+
179+
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
180+
cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
181+
182+
183+
@pytest.mark.parametrize("itype",
184+
[torch.float32, torch.float16, torch.bfloat16])
185+
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
186+
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
187+
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
188+
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
189+
itype):
190+
191+
# this tests the kernels on a single example (no batching)
192+
193+
# set seed
194+
batch_size = 1 # batch_size
195+
# ssd_minimal_discrete requires chunk_size divide seqlen
196+
# - this is only required for generating the reference seqs,
197+
# it is not an operational limitation.
198+
seqlen, chunk_size = seq_len_chunk_size
199+
200+
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
201+
d_head, itype)
202+
203+
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
204+
B, C, chunk_size)
205+
206+
Y, final_state = mamba_chunk_scan_combined(X,
207+
dt,
208+
A,
209+
B,
210+
C,
211+
chunk_size,
212+
D=None,
213+
return_final_states=True)
214+
215+
# just test the last in sequence
216+
torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
217+
218+
# just test the last head
219+
# NOTE, in the kernel we always cast states to fp32
220+
torch.allclose(final_state[:, -1],
221+
final_state_min[:, -1].to(torch.float32),
222+
atol=1e-3,
223+
rtol=1e-3)
224+
225+
226+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
227+
@pytest.mark.parametrize("n_heads", [4, 8, 13])
228+
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
229+
@pytest.mark.parametrize(
230+
"seq_len_chunk_size_cases",
231+
[
232+
233+
# small-ish chunk_size (8)
234+
(64, 8, 2, [(64, 32), (64, 32)]),
235+
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
236+
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
237+
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
238+
(4, 4)]), # chunk_size larger than cont batches
239+
(64, 8, 5, [
240+
(64, 32, 16, 8, 8),
241+
(8, 16, 32, 16, 8),
242+
(8, 8, 16, 32, 16),
243+
]), # mode examples with varied lengths
244+
245+
# odd chunk_size
246+
(64, 29, 2, [(11, 4), (13, 23), (19, 22),
247+
(21, 15)]), # irregular sizes
248+
249+
# large-ish chunk_size (256)
250+
(64, 256, 1, [(5, ), (1, ), (1, ),
251+
(1, )]), # irregular sizes with small sequences
252+
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
253+
(1, 2)]), # irregular sizes with small sequences
254+
])
255+
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
256+
itype):
257+
258+
# this test with multiple examples in a continuous batch
259+
# (i.e. chunked prefill)
260+
261+
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
262+
263+
# hold state during the cutting process so we know if an
264+
# example has been exhausted and needs to cycle
265+
last_taken: Dict = {} # map: eg -> pointer to last taken sample
266+
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted
267+
268+
states = None
269+
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
270+
C) in generate_continous_batched_examples(
271+
cases, num_examples, seqlen,
272+
last_taken, exhausted, n_heads,
273+
d_head, itype):
274+
275+
Y, new_states = mamba_chunk_scan_combined(
276+
X,
277+
dt,
278+
A,
279+
B,
280+
C,
281+
chunk_size,
282+
D=None,
283+
cu_seqlens=cu_seqlens,
284+
seq_idx=sed_idx,
285+
return_varlen_states=True,
286+
initial_states=states,
287+
)
288+
289+
# just test the last in sequence
290+
for i in range(num_examples):
291+
292+
# just test one dim and dstate
293+
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
294+
Y_min_eg = Y_min[i][:, 0, 0]
295+
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
296+
297+
# update states
298+
states = new_states
299+
for i, clear in exhausted.items():
300+
if clear:
301+
states[i].fill_(0.)
302+
exhausted[i] = False

0 commit comments

Comments
 (0)