diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp
index cde867cd..3dabb2b7 100644
--- a/csrc/selective_scan/selective_scan.cpp
+++ b/csrc/selective_scan/selective_scan.cpp
@@ -229,7 +229,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                    const c10::optional<at::Tensor> &D_,
                    const c10::optional<at::Tensor> &z_,
                    const c10::optional<at::Tensor> &delta_bias_,
-                   bool delta_softplus) {
+                   bool delta_softplus, const c10::optional<at::Tensor> &x) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -309,15 +309,20 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
     // at::Tensor out = torch::empty_like(u);
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = torch::empty_like(delta);
-    at::Tensor x;
-    x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
+    if (x.has_value()) {
+        auto _x = x.value();
+        TORCH_CHECK(_x.scalar_type() == weight_type);
+        TORCH_CHECK(_x.is_cuda());
+        TORCH_CHECK(_x.stride(-1) == 1);
+        CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
+    }
 
     SSMParamsBase params;
     set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
                        is_variable_B, is_variable_C,
                        u, delta, A, B, C, out, z, out_z,
                        D_.has_value() ? D_.value().data_ptr() : nullptr,
                        delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
-                       x.data_ptr(),
+                       x.value().data_ptr(),
                        has_z,
                        delta_softplus);
@@ -330,7 +335,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
             selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
         });
     });
-    std::vector<at::Tensor> result = {out, x};
+    std::vector<at::Tensor> result = {out, x.value()};
     if (has_z) { result.push_back(out_z); }
     return result;
 }
diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 80e9e37e..9f05340b 100755
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -241,10 +241,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             scan_t running_prefix;
             if constexpr (!kIsComplex) {
                 // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
-                running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
+                running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : (threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
                 // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
             } else {
-                running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
+                running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : (threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f));
                 // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
             }
             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index c3596bfe..a93af07f 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -20,7 +20,7 @@ class SelectiveScanFn(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                return_last_state=False):
+                return_last_state=False, prev_state=None):
         if u.stride(-1) != 1:
             u = u.contiguous()
         if delta.stride(-1) != 1:
@@ -39,26 +39,37 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
         if C.dim() == 3:
             C = rearrange(C, "b dstate l -> b 1 dstate l")
             ctx.squeeze_C = True
-        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
+        n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
+        x = torch.zeros(
+            (u.shape[0], u.shape[1], n_chunks, int(A.shape[1] * 2)),
+            device=u.device,
+            dtype=torch.float32,
+            requires_grad=u.requires_grad
+        )
+        x[:, :, 0, 0::2] = 1
+        if prev_state is not None:
+            x[:, :, 0, 1::2].copy_(prev_state)
+        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, x)
         ctx.delta_softplus = delta_softplus
         ctx.has_z = z is not None
         last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
         if not ctx.has_z:
-            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
+            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, prev_state)
             return out if not return_last_state else (out, last_state)
         else:
-            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
+            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, prev_state)
             out_z = rest[0]
             return out_z if not return_last_state else (out_z, last_state)
 
     @staticmethod
     def backward(ctx, dout, *args):
         if not ctx.has_z:
-            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
+            u, delta, A, B, C, D, delta_bias, x, prev_state = ctx.saved_tensors
             z = None
             out = None
         else:
-            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
+            u, delta, A, B, C, D, z, delta_bias, x, out, prev_state = ctx.saved_tensors
+        assert prev_state is None, "providing prev_state is not supported in training configuration"
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
@@ -75,21 +86,20 @@ def backward(ctx, dout, *args):
                 dD if D is not None else None,
                 dz,
                 ddelta_bias if delta_bias is not None else None,
-                None,
-                None)
+                None, None, None)
 
 
 def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
-                      return_last_state=False):
+                      return_last_state=False, prev_state=None):
     """if return_last_state is True, returns (out, last_state)
-    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
+    last_state has shape (batch, dim, dstate). Note that the gradient of the last state and prev_state (if provided) is
     not considered in the backward pass.
""" - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, prev_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, prev_state=None): """ u: r(B D L) delta: r(B D L) @@ -99,6 +109,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta D: r(D) z: r(B D L) delta_bias: r(D), fp32 + prev_state: r(B D N), fp32 out: r(B D L) last_state (optional): r(B D dstate) or c(B D dstate) @@ -121,7 +132,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta else: B = B.float() C = C.float() - x = A.new_zeros((batch, dim, dstate)) + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state ys = [] deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) if not is_variable_B: diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index 8a834b3c..10e66a71 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -35,8 +35,9 @@ @pytest.mark.parametrize("is_variable_C", [True]) # @pytest.mark.parametrize("is_variable_B", [False, True]) @pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("scan_chunks", [1,2,3]) def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, itype, wtype): + delta_softplus, return_last_state, seqlen, itype, wtype, scan_chunks): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -92,13 +93,34 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z u_ref = u.detach().clone().requires_grad_() delta_ref = delta.detach().clone().requires_grad_() delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None - out, *rest = selective_scan_fn( - u, delta, A, B, C, D, z=z, - delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state - ) - if return_last_state: - state = rest[0] + state = None + state_ref = None + outs = [] + for c in range(scan_chunks): + chunked_prompt_len = seqlen // scan_chunks + chunk_start = chunked_prompt_len * c + chunk_end = chunked_prompt_len * (c + 1) + if c == scan_chunks - 1: + chunk_end = seqlen + _B = B + if is_variable_B: + _B = B[...,chunk_start:chunk_end] + _C = C + if is_variable_B: + _C = C[...,chunk_start:chunk_end] + _z = z + if has_z: + _z = z[...,chunk_start:chunk_end] + out, *rest = selective_scan_fn( + u[...,chunk_start:chunk_end], delta[...,chunk_start:chunk_end], A, _B, _C, D, z=_z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state,prev_state=state if c > 0 else None + ) + outs.append(out) + if return_last_state: + state = rest[0] + if len(outs) > 1: + out = torch.cat(outs,dim=-1) out_ref, *rest = selective_scan_ref( u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, delta_bias=delta_bias_ref, delta_softplus=delta_softplus, @@ -115,6 +137,9 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z if return_last_state: print(f'State max diff: {(state - state_ref).abs().max().item()}') assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + if scan_chunks > 1: + # skip grad test in case of scan chunks ( 
+        return
     g = torch.randn_like(out)
     out_ref.backward(g)
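
Usage sketch (not part of the patch): the new `prev_state` argument lets a long prompt be scanned in pieces, with each call seeded by the state returned from the previous one. Internally, the caller-allocated `x` buffer stores interleaved scan pairs per chunk: the even slots (`0::2`) are initialized to the scan's multiplicative identity (1) and the odd slots (`1::2`) carry the state, which is why `last_state` is read from `x[:, :, -1, 1::2]`. The helper below is illustrative only; the function name and `chunk_len` are ours, and it assumes variable B/C with a trailing length dimension, as in the updated test.

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

def chunked_selective_scan(u, delta, A, B, C, chunk_len=512):
    """Scan u (batch, dim, L) in chunks, carrying state via prev_state.

    Forward-only: per this patch, gradients do not flow into prev_state,
    which is why the test skips the gradient check when scan_chunks > 1.
    """
    outs, state = [], None
    seqlen = u.shape[-1]
    for start in range(0, seqlen, chunk_len):
        end = min(start + chunk_len, seqlen)
        out, state = selective_scan_fn(
            u[..., start:end], delta[..., start:end], A,
            B[..., start:end], C[..., start:end],  # variable B/C: (B, G, N, L)
            return_last_state=True,  # needed to obtain the state to carry over
            prev_state=state,        # None on the first chunk
        )
        outs.append(out)
    return torch.cat(outs, dim=-1)
```

In the forward direction, the concatenated per-chunk outputs should match a single full-length call, which is exactly what the `scan_chunks` parametrization in `tests/ops/test_selective_scan.py` asserts against `selective_scan_ref`.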