from typing import Dict, Tuple

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined)
from vllm.platforms import current_platform

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py


# this is the segsum implementation taken from above
def segsum(x):
    """Calculates segment sum."""
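    # segsum(x)[..., i, j] equals x[..., j+1] + ... + x[..., i] for i >= j and
    # -inf for i < j; exp(segsum(.)) then gives the lower-triangular decay
    # factors used for the intra-chunk (diagonal-block) terms below.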
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
                      diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
                      diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
                  for x in (X, A, B, C))

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
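    # einsum indices: c = chunk, l/s = positions within a chunk, n = d_state,
    # p = d_head; L[b, h, c, l, s] holds the (masked) decay from position s
    # to position l within the same chunk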

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
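    # states[b, c, h, p, n] is the SSM state accumulated over chunk c alone,
    # i.e. assuming a zero state at the start of that chunk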

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
    # chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]
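    # new_states[:, z] is the state entering chunk z (index 0 corresponds to
    # initial_states); the last entry is the state after the whole sequence
    # and is returned as final_state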

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
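    # Y_off[b, c, l] reads out (through C) the state entering chunk c, decayed
    # up to position l: the off-diagonal / inter-chunk contribution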

    # Add output of intra-chunk and inter-chunk terms
    # (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state


def generate_random_inputs(batch_size,
                           seqlen,
                           n_heads,
                           d_head,
                           itype,
                           device='cuda'):

    current_platform.seed_everything(0)
    A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
    dt = F.softplus(
        torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
        4)
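    # A < 0 (negated exp) and dt > 0 (softplus), so the per-step decay
    # exp(A * dt) always lies in (0, 1)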
    X = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    B = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    C = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
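    # NOTE: these tests exercise the d_state == d_head case, which is why B
    # and C are generated with the same trailing dim as X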

    return A, dt, X, B, C


def generate_continuous_batched_examples(example_lens_by_batch,
                                         num_examples,
                                         full_length,
                                         last_taken,
                                         exhausted,
                                         n_heads,
                                         d_head,
                                         itype,
                                         device='cuda'):

    # this function generates random examples of a certain length, cuts them
    # according to "example_lens_by_batch", and feeds them in continuous
    # batches to the kernels

    # generate the full-length example
    A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
                                            d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                                  A * dt,
                                                  B,
                                                  C,
                                                  block_len=full_length // 4)
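    # Y_min / final_state_min are full-length references computed in one shot;
    # slices of Y_min are compared against the chunked kernel outputs below
    # (block_len only changes how the reference is evaluated, not its value,
    # up to numerics)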

    # internal function that outputs a cont batch of examples
    # given a tuple of lengths for each example in the batch
    # e.g., example_lens=(8, 4) means take 8 samples from first eg,
    # 4 samples from second eg, etc
    def get_continuous_batch(example_lens: Tuple[int, ...]):

        indices = []
        for i, x in enumerate(example_lens):
            c = last_taken.get(i, 0)
            indices.append((c, c + x))
            last_taken[i] = (c + x) % full_length
            exhausted[i] = last_taken[i] == 0

        return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
                              ]).unsqueeze(0) for x in (dt, X, B, C))
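    # each tensor produced by get_continuous_batch has shape
    # (1, sum(example_lens), ...): per-example slices are packed along the
    # time dimension under a single batch dim, the layout consumed by the
    # varlen (cu_seqlens) path of the kernel below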

    # internal function that maps "n" to the appropriate right boundary
    # value when forming continuous batches from examples of length given
    # by "full_length".
    # - e.g., when full_length < n < 2 * full_length, returns n - full_length
    #         when n == full_length, returns full_length (not 0)
    def end_boundary(n: int):
        return n - ((n - 1) // full_length) * full_length
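    # e.g., with full_length == 64: end_boundary(37) == 37,
    # end_boundary(64) == 64, end_boundary(65) == 1, end_boundary(128) == 64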

    IND_E = None
    for spec in example_lens_by_batch:

        # get the (maybe partial) example seen in this cont batch
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # get the metadata
        cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
        seq_idx = torch.zeros(cu_seqlens[-1],
                              dtype=torch.int32,
                              device=cu_seqlens.device)
        for i, (srt, end) in enumerate(zip(
                cu_seqlens,
                cu_seqlens[1:],
        )):
            seq_idx[srt:end] = i

        # for cont batch
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
               cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                         itype):

    # this tests the kernels on a single example (no batching)

    batch_size = 1
    # ssd_minimal_discrete requires chunk_size to divide seqlen
    # - this is only required for generating the reference seqs;
    #   it is not an operational limitation of the kernel.
    seqlen, chunk_size = seq_len_chunk_size

    A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
                                            d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                  B, C, chunk_size)

    Y, final_state = mamba_chunk_scan_combined(X,
                                               dt,
                                               A,
                                               B,
                                               C,
                                               chunk_size,
                                               D=None,
                                               return_final_states=True)

    # just test the last position in the sequence
    assert torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)

    # just test the last head
    # NOTE: in the kernel we always cast states to fp32
    assert torch.allclose(final_state[:, -1],
                          final_state_min[:, -1].to(torch.float32),
                          atol=1e-3,
                          rtol=1e-3)


@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
    "seq_len_chunk_size_cases",
    [

        # small-ish chunk_size (8)
        (64, 8, 2, [(64, 32), (64, 32)]),
        (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
        (64, 8, 2, [(8, 8), (8, 8), (8, 8)]),  # chunk size boundary
        (64, 8, 2, [(4, 4), (4, 4), (4, 4),
                    (4, 4)]),  # chunk_size larger than cont batches
        (64, 8, 5, [
            (64, 32, 16, 8, 8),
            (8, 16, 32, 16, 8),
            (8, 8, 16, 32, 16),
        ]),  # more examples with varied lengths

        # odd chunk_size
        (64, 29, 2, [(11, 4), (13, 23), (19, 22),
                     (21, 15)]),  # irregular sizes

        # large-ish chunk_size (256)
        (64, 256, 1, [(5, ), (1, ), (1, ),
                      (1, )]),  # irregular sizes with small sequences
        (64, 256, 2, [(5, 30), (1, 2), (1, 2),
                      (1, 2)]),  # irregular sizes with small sequences
    ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                     itype):

    # this tests the kernels with multiple examples in a continuous batch
    # (i.e. chunked prefill)

    seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: Dict = {}  # map: eg -> pointer to last taken sample
    exhausted: Dict = {}  # map: eg -> boolean indicating example is exhausted

    states = None
    for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
                                     C) in generate_continuous_batched_examples(
                                         cases, num_examples, seqlen,
                                         last_taken, exhausted, n_heads,
                                         d_head, itype):

        Y, new_states = mamba_chunk_scan_combined(
            X,
            dt,
            A,
            B,
            C,
            chunk_size,
            D=None,
            cu_seqlens=cu_seqlens,
            seq_idx=seq_idx,
            return_varlen_states=True,
            initial_states=states,
        )
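        # return_varlen_states=True asks the kernel for one final state per
        # packed sequence; feeding these back as initial_states on the next
        # iteration carries state across the chunked-prefill steps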

        # check each example in the continuous batch against the
        # corresponding slice of the full-length reference
        for i in range(num_examples):

            # just test one head and one head dim
            Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
            Y_min_eg = Y_min[i][:, 0, 0]
            assert torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)

        # update states
        states = new_states
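        # an example exhausted in this batch wraps around and restarts from
        # position 0 next time, so its carried state must be reset to zero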
        for i, clear in exhausted.items():
            if clear:
                states[i].fill_(0.)
                exhausted[i] = False