# Copyright (c) 2024, Albert Gu and Tri Dao.
"""Minimal implementation of SSD.

This is the same as Listing 1 from the paper.
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined


def segsum_unstable(x):
    """Naive segment sum calculation."""
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
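
# Note on stability: the naive version above forms each segment sum as the
# difference of two running cumulative sums, so entries for short segments are
# computed as the difference of two potentially large numbers and can lose
# precision. The version below masks the increments first and only then
# accumulates, so every entry is a sum over exactly the segment it represents.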

def segsum(x):
    """More stable segment sum calculation."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
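
# For intuition (an illustrative note, not part of the original listing): with a
# 1-D input x = [a0, a1, a2], segsum(x) is
#   [[0,        -inf, -inf],
#    [a1,       0,    -inf],
#    [a1 + a2,  a2,   0   ]]
# i.e. entry (i, j) holds a_{j+1} + ... + a_i for j <= i and -inf above the
# diagonal, so exp(segsum(A)) below is a lower-triangular matrix of cumulative
# decay factors.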

def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)
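    # After the rearranges: X is (batch, n_chunks, block_len, n_heads, d_head),
    # A is (batch, n_heads, n_chunks, block_len), and A_cumsum is the running sum
    # of A within each chunk (it resets at every chunk boundary).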

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
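    # Y_diag[b, c, t] = sum_{s <= t} (C_t . B_s) * exp(A_{s+1} + ... + A_t) * X_s:
    # the attention-like contribution from positions within the same chunk.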

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
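    # states[b, c] has shape (n_heads, d_head, d_state): the state each chunk would
    # produce if it started from a zero state, with position l weighted by the
    # decay remaining between l and the end of its chunk (decay_states).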

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]
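    # decay_chunk[b, h, z, c] is the total decay accumulated across chunks c+1..z,
    # so new_states[:, z] is the true SSM state at chunk boundary z (boundary 0
    # being the provided initial state). The last entry is the final state of the
    # whole sequence; the remaining entries are the states entering each chunk,
    # consumed in step 4.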

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
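    # Y_off[b, c, l] = C_l . (state entering chunk c), scaled by the cumulative
    # decay from the chunk start up to position l (exp(A_cumsum)): the contribution
    # of all previous chunks to each position.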

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state


# Simple test
def test_correctness():
    torch.manual_seed(42)

    ## Dimensions
    # Denoted (B, T, Q, D, P) in the paper
    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
    nheads = dim // headdim  # (H) in the paper
    ngroups = 1  # (G) in the paper
    dstate = 64  # (N) in the paper
    dtype = torch.float32
    device = "cuda"

    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
    dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
    A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    D = torch.randn(nheads, dtype=dtype, device=device)

    # Comparing fused version and minimal version
    y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
    y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)
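
    # A possible agreement check (illustrative only, not in the original listing;
    # tolerances are a guess):
    # torch.testing.assert_close(y, y_min, rtol=1e-2, atol=1e-2)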