
Change interface to selective_state_update for continuous batching #521


Merged
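This PR lets `selective_state_update` read and write each batch row's recurrent state at an arbitrary slot of a larger state cache, selected by the new `state_batch_indices` argument, so a serving engine can admit and retire requests between decode steps without copying states around. A minimal sketch of the intended call pattern (shapes and dtypes follow the tests below; the slot assignment itself is hypothetical):

```python
import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

device, dtype = "cuda", torch.float16
total_entries, batch_size = 160, 16   # many more cache slots than active rows
dim, dstate = 2048, 16

# Persistent state cache; each live request owns one slot.
state_cache = torch.randn(total_entries, dim, dstate, dtype=dtype, device=device)
# Hypothetical scheduler output: which cache slot each batch row uses this step.
slot_ids = torch.randperm(total_entries, device=device)[:batch_size].to(torch.int32)

x = torch.randn(batch_size, dim, dtype=dtype, device=device)
dt = torch.randn(batch_size, dim, dtype=dtype, device=device)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)

# Updates state_cache[slot_ids] in place and returns this step's output.
out = selective_state_update(state_cache, x, dt, A, B, C, dt_bias=dt_bias,
                             dt_softplus=True, state_batch_indices=slot_ids)
```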
25 changes: 20 additions & 5 deletions mamba_ssm/ops/triton/selective_state_update.py
@@ -18,11 +18,12 @@
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
# Pointers to matrices
-    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
+    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, state_batch_indices_ptr,
# Matrix dimensions
batch, nheads, dim, dstate, nheads_ngroups_ratio,
# Strides
@@ -43,12 +44,20 @@ def _selective_scan_update_kernel(
HAS_DT_BIAS: tl.constexpr,
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
+    HAS_STATE_BATCH_INDICES: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
-    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+    if HAS_STATE_BATCH_INDICES:
+        state_batch_indices_ptr += pid_b
+        state_batch_idx = tl.load(state_batch_indices_ptr)
+        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
+    else:
+        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
if HAS_DT_BIAS:
@@ -118,7 +127,8 @@ def _selective_scan_update_kernel(
tl.store(out_ptrs, out, mask=offs_m < dim)


-def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
+def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False,
+                           state_batch_indices=None):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -152,7 +162,10 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None,
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
-    batch, nheads, dim, dstate = state.shape
+    _, nheads, dim, dstate = state.shape
+    batch = x.shape[0]
+    if x.shape != (batch, nheads, dim):
+        print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
@@ -166,6 +179,8 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None,
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
+    if state_batch_indices is not None:
+        assert state_batch_indices.shape == (batch,)
out = torch.empty_like(x)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
@@ -179,7 +194,7 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None,
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
-            state, x, dt, dt_bias, A, B, C, D, z, out,
+            state, x, dt, dt_bias, A, B, C, D, z, out, state_batch_indices,
batch, nheads, dim, dstate, nheads // ngroups,
state.stride(0), state.stride(1), state.stride(2), state.stride(3),
x.stride(0), x.stride(1), x.stride(2),
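Semantically, the kernel change amounts to indexing the state cache through `state_batch_indices` before each per-row update. A rough PyTorch restatement of that indexing, built on `selective_state_update_ref` (the reference implementation the tests below import from the same module); the helper name here is hypothetical and this is not the Triton code path:

```python
import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

def selective_state_update_indexed_ref(state, x, dt, A, B, C,
                                       state_batch_indices, **kwargs):
    # Gather each batch row's slot from the cache (advanced indexing copies),
    # apply the plain reference update, which modifies its `state` argument
    # in place, then scatter the updated slots back into the cache.
    gathered = state[state_batch_indices].clone()
    out = selective_state_update_ref(gathered, x, dt, A, B, C, **kwargs)
    state[state_batch_indices] = gathered
    return out
```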
99 changes: 99 additions & 0 deletions tests/ops/triton/test_selective_state_update.py
@@ -100,3 +100,102 @@ def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('itype', [torch.float16])
@pytest.mark.parametrize("has_z", [False, True])
# @pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
# @pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# @pytest.mark.parametrize("dim", [2048])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 6e-2, 6e-2
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 16

total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)

x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
if has_z:
z = torch.randn_like(x)
else:
z = None
state_ref = state[state_indices,:].detach().clone()
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
#@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize("has_z", [False, True])
# @pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize("tie_hdim", [False, True])
# @pytest.mark.parametrize('tie_hdim', [True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
# @pytest.mark.parametrize("ngroups", [2])
@pytest.mark.parametrize("dstate", [16, 32, 64])
# @pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048, 4096])
# @pytest.mark.parametrize("dim", [2048])
def test_selective_state_update_with_heads_with_batch_indices(dim, dstate, ngroups, has_z, tie_hdim, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 16
headdim = 64
nheads = dim // headdim

total_entries = 10 * batch_size
state = torch.randn(total_entries, nheads, headdim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)

x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
if not tie_hdim:
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
D = torch.randn(nheads, headdim, device=device)
else:
dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
B = torch.randn(batch_size, ngroups, dstate, device=device)
C = torch.randn(batch_size, ngroups, dstate, device=device)
if has_z:
z = torch.randn_like(x)
else:
z = None
state_ref = state[state_indices,:].detach().clone()
state_og = state[state_indices,:].detach().clone()
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
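
Both tests exercise exactly the continuous-batching pattern from the PR title: the state tensor holds 10× more slots than the batch, and each step reads and writes only a random subset of them. A schematic decode loop showing where this interface fits; the scheduler and `step_fn` here are hypothetical, not part of mamba_ssm:

```python
def decode_loop(step_fn, scheduler, num_steps):
    """Schematic continuous-batching decode loop (hypothetical scheduler API).

    `step_fn(x, slot_ids)` is assumed to call
    selective_state_update(state_cache, ..., state_batch_indices=slot_ids)
    internally, so each request's recurrent state lives in a fixed cache
    slot regardless of its position in the current batch.
    """
    for _ in range(num_steps):
        # The active request set can change between steps; the batch is
        # re-formed from whichever cache slots are live right now.
        slot_ids, x = scheduler.next_batch()   # (batch,), (batch, dim)
        out = step_fn(x, slot_ids)
        scheduler.post_step(slot_ids, out)     # emit tokens, retire finished
```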