From c66374112f94be3ee27c9b391961e6dfd21e7d3f Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Thu, 8 Aug 2024 21:39:51 +0000
Subject: [PATCH] Improve interface to selective_state_update for continuous
 batching

---
 .../ops/triton/selective_state_update.py      | 25 ++++-
 .../ops/triton/test_selective_state_update.py | 98 ++++++++++++++++++
 2 files changed, 117 insertions(+), 6 deletions(-)

diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py
index bc78de90..205bfd3c 100644
--- a/mamba_ssm/ops/triton/selective_state_update.py
+++ b/mamba_ssm/ops/triton/selective_state_update.py
@@ -18,11 +18,12 @@
 @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
 @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
 @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
+@triton.heuristics({"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None})
 @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
 @triton.jit
 def _selective_scan_update_kernel(
     # Pointers to matrices
-    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
+    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, state_batch_indices_ptr,
     # Matrix dimensions
     batch, nheads, dim, dstate, nheads_ngroups_ratio,
     # Strides
@@ -43,12 +44,20 @@ def _selective_scan_update_kernel(
     HAS_DT_BIAS: tl.constexpr,
     HAS_D: tl.constexpr,
     HAS_Z: tl.constexpr,
+    HAS_STATE_BATCH_INDICES: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
-    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+    if HAS_STATE_BATCH_INDICES:
+        state_batch_indices_ptr += pid_b
+        state_batch_idx = tl.load(state_batch_indices_ptr)
+        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
+    else:
+        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
     x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
     if HAS_DT_BIAS:
@@ -118,7 +127,8 @@ def _selective_scan_update_kernel(
         tl.store(out_ptrs, out, mask=offs_m < dim)
 
 
-def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
+def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False,
+                           state_batch_indices=None):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -152,7 +162,8 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None,
         z = z.unsqueeze(1)
     if dt_bias is not None and dt_bias.dim() == 1:
         dt_bias = dt_bias.unsqueeze(0)
-    batch, nheads, dim, dstate = state.shape
-    assert x.shape == (batch, nheads, dim)
+    _, nheads, dim, dstate = state.shape
+    batch = x.shape[0]
+    assert x.shape == (batch, nheads, dim), f"{state.shape} {x.shape} {batch} {nheads} {dim}"
     assert dt.shape == x.shape
     assert A.shape == (nheads, dim, dstate)
@@ -166,6 +177,8 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None,
         assert z.shape == x.shape
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, dim)
+    if state_batch_indices is not None:
+        assert state_batch_indices.shape == (batch,)
     out = torch.empty_like(x)
     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
@@ -179,7 +192,7 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None,
     tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
     with torch.cuda.device(x.device.index):
         _selective_scan_update_kernel[grid](
-            state, x, dt, dt_bias, A, B, C, D, z, out,
+            state, x, dt, dt_bias, A, B, C, D, z, out, state_batch_indices,
             batch, nheads, dim, dstate, nheads // ngroups,
             state.stride(0), state.stride(1), state.stride(2), state.stride(3),
             x.stride(0), x.stride(1), x.stride(2),
diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py
index e81807ae..55408c89 100644
--- a/tests/ops/triton/test_selective_state_update.py
+++ b/tests/ops/triton/test_selective_state_update.py
@@ -100,3 +100,101 @@ def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float16])
+@pytest.mark.parametrize("has_z", [False, True])
+# @pytest.mark.parametrize('has_z', [True])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+# @pytest.mark.parametrize("dstate", [16])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 6e-2, 6e-2
+    if torch.version.hip:
+        atol *= 2
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(batch_size, dstate, device=device)
+    C = torch.randn(batch_size, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    if has_z:
+        z = torch.randn_like(x)
+    else:
+        z = None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
+                                 dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize('itype', [torch.float32])
+@pytest.mark.parametrize("has_z", [False, True])
+# @pytest.mark.parametrize('has_z', [True])
+@pytest.mark.parametrize("tie_hdim", [False, True])
+# @pytest.mark.parametrize('tie_hdim', [True])
+@pytest.mark.parametrize("ngroups", [1, 2, 4])
+# @pytest.mark.parametrize("ngroups", [2])
+@pytest.mark.parametrize("dstate", [16, 32, 64])
+# @pytest.mark.parametrize("dstate", [16])
+@pytest.mark.parametrize("dim", [2048, 4096])
+# @pytest.mark.parametrize("dim", [2048])
+def test_selective_state_update_with_heads_with_batch_indices(dim, dstate, ngroups, has_z, tie_hdim, itype):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 1e-1, 1e-1
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 16
+    headdim = 64
+    nheads = dim // headdim
+
+    total_entries = 10 * batch_size
+    state = torch.randn(total_entries, nheads, headdim, dstate, dtype=itype, device=device)
+    state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)
+
+    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+    if not tie_hdim:
+        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
+        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
+        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
+        D = torch.randn(nheads, headdim, device=device)
+    else:
+        dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
+        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
+        A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
+        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
+    B = torch.randn(batch_size, ngroups, dstate, device=device)
+    C = torch.randn(batch_size, ngroups, dstate, device=device)
+    if has_z:
+        z = torch.randn_like(x)
+    else:
+        z = None
+    state_ref = state[state_indices, :].detach().clone()
+    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
+    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
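
Usage sketch (not part of the applied patch). With state_batch_indices, the
SSM state can live in a persistent cache holding one row per tracked
sequence, while x, dt, B, and C stay packed for only the sequences scheduled
in the current step; that is the continuous-batching pattern the new tests
exercise. The snippet below mirrors the test setup; the cache size, the
index values, and names such as state_cache are illustrative assumptions,
not values from the patch.

    import torch
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update

    device = "cuda"
    dim, dstate = 2048, 16

    # Persistent cache: one SSM state row per tracked sequence.
    total_entries = 64
    state_cache = torch.randn(total_entries, dim, dstate, device=device)

    # Hypothetical schedule: this step's batch touches cache rows 7, 42, 3, 19.
    state_indices = torch.tensor([7, 42, 3, 19], dtype=torch.int32, device=device)
    batch_size = state_indices.shape[0]

    # Per-step inputs, packed for just the scheduled sequences.
    x = torch.randn(batch_size, dim, device=device)
    dt = torch.randn(batch_size, dim, device=device)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)

    # Updates only the four selected cache rows in place and returns the
    # per-sequence outputs; rows not listed in state_indices are untouched.
    out = selective_state_update(state_cache, x, dt, A, B, C, D=D,
                                 dt_bias=dt_bias, dt_softplus=True,
                                 state_batch_indices=state_indices)

Without state_batch_indices the caller would instead have to gather the
active rows into a contiguous (batch, dim, dstate) tensor and scatter them
back after every decode step; the index argument removes that round trip.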