Skip to content

Commit d28e1b0

Browse files
committed
add cu_seqlens support and ensure numerical equality
1 parent ce59dae commit d28e1b0

File tree

6 files changed

+193
-28
lines changed

6 files changed

+193
-28
lines changed

csrc/selective_scan/selective_scan.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
7979
void* delta_bias_ptr,
8080
void* x_ptr,
8181
bool has_z,
82-
bool delta_softplus) {
82+
bool delta_softplus,
83+
void* cu_seqlens_ptr,
84+
const int cu_seqlens_size) {
8385

8486
// Reset the parameters
8587
memset(&params, 0, sizeof(params));
@@ -109,6 +111,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
109111
params.x_ptr = x_ptr;
110112
params.z_ptr = has_z ? z.data_ptr() : nullptr;
111113
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
114+
115+
params.cu_seqlens_ptr = cu_seqlens_ptr;
116+
params.cu_seqlens_size = cu_seqlens_size;
117+
112118
// All strides are in elements, not bytes.
113119
params.A_d_stride = A.stride(0);
114120
params.A_dstate_stride = A.stride(1);
@@ -173,15 +179,17 @@ void set_ssm_params_bwd(SSMParamsBwd &params,
173179
void* ddelta_bias_ptr,
174180
bool has_z,
175181
bool delta_softplus,
176-
bool recompute_out_z) {
182+
bool recompute_out_z,
183+
void* cu_seqlens_ptr,
184+
const int cu_seqlens_size) {
177185
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
178186
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
179187
u, delta, A, B, C, has_z ? out : dout,
180188
has_z ? z : dout,
181189
// If not recompute_out_z, pass dout instead of out_z.
182190
// This won't be used by the bwd kernel
183191
recompute_out_z ? out_z : dout,
184-
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
192+
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, cu_seqlens_ptr, cu_seqlens_size);
185193
if (!recompute_out_z) { params.out_z_ptr = nullptr; }
186194

187195
// Set the pointers and strides.
@@ -229,7 +237,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
229237
const c10::optional<at::Tensor> &D_,
230238
const c10::optional<at::Tensor> &z_,
231239
const c10::optional<at::Tensor> &delta_bias_,
232-
bool delta_softplus) {
240+
bool delta_softplus,
241+
const c10::optional<at::Tensor> &cu_seqlens_) {
233242
auto input_type = u.scalar_type();
234243
auto weight_type = A.scalar_type();
235244
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -319,7 +328,9 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
319328
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
320329
x.data_ptr(),
321330
has_z,
322-
delta_softplus);
331+
delta_softplus,
332+
cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr,
333+
cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0);
323334

324335
// Otherwise the kernel will be launched from cuda:0 device
325336
// Cast to char to avoid compiler warning about narrowing
@@ -346,7 +357,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
346357
const c10::optional<at::Tensor> &out_,
347358
c10::optional<at::Tensor> &dz_,
348359
bool delta_softplus,
349-
bool recompute_out_z) {
360+
bool recompute_out_z,
361+
const c10::optional<at::Tensor> &cu_seqlens_) {
350362
auto input_type = u.scalar_type();
351363
auto weight_type = A.scalar_type();
352364
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -474,7 +486,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
474486
dout, du, ddelta, dA, dB, dC, dz,
475487
D_.has_value() ? dD.data_ptr() : nullptr,
476488
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
477-
has_z, delta_softplus, recompute_out_z);
489+
has_z, delta_softplus, recompute_out_z,
490+
cu_seqlens_.has_value() ? cu_seqlens_.value().data_ptr() : nullptr,
491+
cu_seqlens_.has_value() ? cu_seqlens_.value().size(0) : 0);
478492

479493
// Otherwise the kernel will be launched from cuda:0 device
480494
// Cast to char to avoid compiler warning about narrowing

csrc/selective_scan/selective_scan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct SSMParamsBase {
3333

3434
bool delta_softplus;
3535

36+
int cu_seqlens_size;
37+
3638
index_t A_d_stride;
3739
index_t A_dstate_stride;
3840
index_t B_batch_stride;
@@ -66,6 +68,8 @@ struct SSMParamsBase {
6668
void *__restrict__ x_ptr;
6769
void *__restrict__ z_ptr;
6870
void *__restrict__ out_z_ptr;
71+
72+
void *__restrict__ cu_seqlens_ptr;
6973
};
7074

7175
struct SSMParamsBwd: public SSMParamsBase {

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
136136
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
137137
float dD_val = 0;
138138
float ddelta_bias_val = 0;
139+
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride
139140

140141
constexpr int kChunkSize = kNThreads * kNItems;
141142
u += (params.n_chunks - 1) * kChunkSize;
@@ -245,7 +246,22 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
245246
#pragma unroll
246247
for (int i = 0; i < kNItems; ++i) {
247248
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
249+
250+
// Reset A bar for cumulative sequences (Real)
251+
int left = 1;
252+
int right = params.cu_seqlens_size - 2;
253+
while (left <= right) {
254+
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
255+
delta_a_exp = 0.f;
256+
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
257+
left = ((left + right) >> 1) + 1;
258+
} else {
259+
right = ((left + right) >> 1) - 1;
260+
}
261+
}
262+
248263
thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
264+
249265
if (i == 0) {
250266
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
251267
} else {
@@ -332,6 +348,21 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
332348
for (int i = 0; i < kNItems; ++i) {
333349
// Pytorch's implementation of complex exp (which calls thrust) is very slow
334350
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
351+
352+
// Reset A bar for cumulative sequences (Complex)
353+
int left = 1;
354+
int right = params.cu_seqlens_size - 2;
355+
while (left <= right) {
356+
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
357+
delta_a_exp.real_ = 0.f;
358+
delta_a_exp.imag_ = 0.f;
359+
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
360+
left = ((left + right) >> 1) + 1;
361+
} else {
362+
right = ((left + right) >> 1) - 1;
363+
}
364+
}
365+
335366
weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
336367
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
337368
if (i == 0) {

csrc/selective_scan/selective_scan_fwd_kernel.cuh

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
107107
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
108108
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
109109
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
110+
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride
110111

111112
float D_val[kNRows] = {0};
112113
if (params.D_ptr != nullptr) {
@@ -215,6 +216,20 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
215216
if constexpr (!kIsComplex) {
216217
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
217218
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
219+
220+
// Reset A bar for cumulative sequences (Real)
221+
int left = 1;
222+
int right = params.cu_seqlens_size - 2;
223+
while (left <= right) {
224+
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
225+
thread_data[i].x = 0.f;
226+
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
227+
left = ((left + right) >> 1) + 1;
228+
} else {
229+
right = ((left + right) >> 1) - 1;
230+
}
231+
}
232+
218233
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
219234
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
220235
thread_data[i] = make_float2(1.f, 0.f);
@@ -225,6 +240,21 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
225240
complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
226241
weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
227242
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
243+
244+
// Reset A bar for cumulative sequences (Complex)
245+
int left = 1;
246+
int right = params.cu_seqlens_size - 2;
247+
while (left <= right) {
248+
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
249+
thread_data[i].x = 0.f;
250+
thread_data[i].y = 0.f;
251+
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
252+
left = ((left + right) >> 1) + 1;
253+
} else {
254+
right = ((left + right) >> 1) - 1;
255+
}
256+
}
257+
228258
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
229259
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
230260
thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);

mamba_ssm/modules/mamba_simple.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from einops import rearrange, repeat
1212

13-
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
13+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, selective_scan_ref
1414

1515
try:
1616
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -116,7 +116,7 @@ def __init__(
116116

117117
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
118118

119-
def forward(self, hidden_states, inference_params=None):
119+
def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
120120
"""
121121
hidden_states: (B, L, D)
122122
Returns: same shape as hidden_states
@@ -157,9 +157,22 @@ def forward(self, hidden_states, inference_params=None):
157157
self.D.float(),
158158
delta_bias=self.dt_proj.bias.float(),
159159
delta_softplus=True,
160+
cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
160161
)
161162
else:
162163
x, z = xz.chunk(2, dim=1)
164+
165+
# (Optional Step1 for cu_seqlens): Insert (d_conv - 1) zeros at each interior sequence boundary for conv1d ops in cumulative sequences
166+
if cu_seqlens is not None:
167+
padded_x = x
168+
count = 0
169+
for idx in cu_seqlens[0][1:-1].tolist():
170+
padded_idx = idx + count*(self.d_conv - 1)
171+
padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
172+
count = count + 1
173+
x = padded_x
174+
assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2]
175+
163176
# Compute short convolution
164177
if conv_state is not None:
165178
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
@@ -175,6 +188,17 @@ def forward(self, hidden_states, inference_params=None):
175188
bias=self.conv1d.bias,
176189
activation=self.activation,
177190
)
191+
192+
# (Optional Step2 for cu_seqlens): Mask out the padded positions of the conv1d output in cumulative sequences
193+
if cu_seqlens is not None:
194+
mask = []
195+
for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist():
196+
mask.extend([True] * seq_len)
197+
mask.extend([False] * (self.d_conv - 1))
198+
mask = mask[:-(self.d_conv - 1)]
199+
assert x.shape[2] == len(mask)
200+
x = x[:, :, mask]
201+
assert x.shape[2] == z.shape[2]
178202

179203
# We're careful here about the layout, to avoid extra transposes.
180204
# We want dt to have d as the slowest moving dimension
@@ -185,6 +209,7 @@ def forward(self, hidden_states, inference_params=None):
185209
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
186210
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
187211
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
212+
188213
assert self.activation in ["silu", "swish"]
189214
y = selective_scan_fn(
190215
x,
@@ -197,6 +222,7 @@ def forward(self, hidden_states, inference_params=None):
197222
delta_bias=self.dt_proj.bias.float(),
198223
delta_softplus=True,
199224
return_last_state=ssm_state is not None,
225+
cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
200226
)
201227
if ssm_state is not None:
202228
y, last_state = y

0 commit comments

Comments
 (0)