[0.10.2][bugfix] fix torchair & mtp problems introduced by hybrid kv_cache #2935
Conversation
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces fixes for torchair and MTP problems related to hybrid kv_cache, and adds support for the Qwen3Next model. While the refactoring for hybrid kv_cache appears solid, the implementation for the new Qwen3Next model has critical flaws. It incorrectly uses Triton kernels, which are specific to NVIDIA GPUs and will not run on Ascend NPUs, leading to runtime failures. There are also significant performance bottlenecks in the model's Python-based operations, and a bug in eagle_proposer.py caused by an outdated function call.
attn_metadata_i = self.runner.attn_metadata_builder.build(
    common_attn_metadata, self.runner.get_model())
The call to self.runner.attn_metadata_builder.build is missing the common_prefix_len argument. The signature for this method has been updated in this PR to build(self, common_prefix_len: int, common_attn_metadata: AscendCommonAttentionMetadata, model: nn.Module). This will cause a TypeError at runtime. A value for common_prefix_len should be provided, likely 0 as seen in other parts of the codebase.
Suggested change:
attn_metadata_i = self.runner.attn_metadata_builder.build(
    0, common_attn_metadata, self.runner.get_model())
attn_metadata = self.runner.attn_metadata_builder.build(
    common_attn_metadata, self.runner.model)
This call to self.runner.attn_metadata_builder.build is also missing the common_prefix_len argument, which will cause a TypeError. Please add the missing argument, likely with a value of 0.
Suggested change:
attn_metadata = self.runner.attn_metadata_builder.build(
    0, common_attn_metadata, self.runner.model)
@triton.jit()
def _causal_conv1d_update_kernel(
    # Pointers to matrices
    x_ptr,  # (batch, dim, seqlen)
    w_ptr,  # (dim, width)
    bias_ptr,
    conv_state_ptr,
    cache_seqlens_ptr,  # circular buffer
    conv_state_indices_ptr,
    num_accepted_tokens_ptr,
    intermediate_conv_window_ptr,
    o_ptr,  # (batch, dim, seqlen)
    # Matrix dimensions
    batch: int,
    dim: tl.constexpr,
    seqlen: tl.constexpr,
    state_len: tl.constexpr,
    num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
    # Strides
    stride_x_seq: tl.constexpr,
    stride_x_dim: tl.constexpr,
    stride_x_token: tl.constexpr,
    stride_w_dim: tl.constexpr,
    stride_w_width: tl.constexpr,
    stride_conv_state_seq: tl.constexpr,
    stride_conv_state_dim: tl.constexpr,
    stride_conv_state_tok: tl.constexpr,
    stride_state_indices: tl.constexpr,
    stride_inter_seq: tl.constexpr,
    stride_inter_step: tl.constexpr,
    stride_inter_dim: tl.constexpr,
    stride_inter_win: tl.constexpr,
    stride_o_seq: tl.constexpr,
    stride_o_dim: tl.constexpr,
    stride_o_token: tl.constexpr,
    # others
    pad_slot_id: tl.constexpr,
    # Meta-parameters
    HAS_BIAS: tl.constexpr,
    KERNEL_WIDTH: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    NP2_STATELEN: tl.constexpr,
    USE_PAD_SLOT: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SAVE_INTERMEDIATE: tl.constexpr,
):
    # ruff: noqa: E501
    idx_seq = tl.program_id(0)
    if idx_seq >= batch:
        return

    # [BLOCK_N,] elements along the feature-dimension (channel)
    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    if IS_CONTINUOUS_BATCHING:
        # mask = idx_seq < batch
        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
                                         idx_seq * stride_state_indices).to(
                                             tl.int64)
    else:
        conv_state_batch_coord = idx_seq
    if USE_PAD_SLOT:  # noqa
        if conv_state_batch_coord == pad_slot_id:
            # not processing as this is not the actual sequence
            return

    if IS_SPEC_DECODING:
        # The rolling of conv state:
        #
        # Before forward, the conv_state is:
        # [history1, history2, ..., historyM].
        #
        # After forward, the conv_state becomes:
        # [history2, ..., historyM, draft1, draft2, ..., draftN].
        #
        # After acceptance, it becomes:
        #
        # - accept 1 tokens: [history2, ..., historyM, draft1]
        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
        # - and so on.
        conv_state_token_offset = tl.load(num_accepted_tokens_ptr +
                                          idx_seq) - 1
    else:
        conv_state_token_offset = 0

    # STEP 1: READ init_state data
    conv_states_base = (conv_state_ptr +
                        (conv_state_batch_coord * stride_conv_state_seq) +
                        (idx_feats * stride_conv_state_dim))
    mask_w = idx_feats < dim

    prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
    if KERNEL_WIDTH >= 2:
        conv_states_ptrs = prior_tokens  # [BLOCK_N]
        col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 3:
        conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok  # [BLOCK_N]
        col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH >= 4:
        conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
        col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
    if KERNEL_WIDTH == 5:
        conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
        #col3 = tl.load(conv_states_ptrs, mask_w, 0.0)

    # STEP 2: assume state_len > seqlen
    idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]

    # The conv_state updates works in a sliding window manner,
    # at each forward pass, the tokens are shift by 1, so we
    # load since idx_tokens + 1.
    conv_state_ptrs_source = (
        conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
        conv_state_token_offset * stride_conv_state_tok +
        (idx_feats * stride_conv_state_dim)[None, :] +
        ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
    )  # [BLOCK_M, BLOCK_N]
    mask = ((conv_state_batch_coord < num_cache_lines)
            & ((idx_tokens + seqlen) < state_len)[:, None]
            & (idx_feats < dim)[None, :])
    conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

    VAL = state_len - seqlen
    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
                                                 )  # [BLOCK_N]

    x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
              )  # [BLOCK_M, BLOCK_N]

    mask_x = ((idx_tokens - VAL >= 0)[:, None]
              & (idx_tokens - VAL < seqlen)[:, None]
              & (idx_feats < dim)[None, :]
              )  # token-index # token-index # feature-index
    loaded_x = tl.load(x_ptrs, mask_x, 0.0)
    tl.debug_barrier()

    new_conv_state = tl.where(mask, conv_state, loaded_x)

    conv_state_base = (conv_state_ptr +
                       (conv_state_batch_coord * stride_conv_state_seq) +
                       (idx_feats * stride_conv_state_dim))  # [BLOCK_N,]
    conv_state_ptrs_target = (conv_state_base +
                              (idx_tokens * stride_conv_state_tok)[:, None]
                              )  # [BLOCK_M, BLOCK_N]
    mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
    tl.store(conv_state_ptrs_target, new_conv_state, mask)

    # STEP 3: init accumulator
    if HAS_BIAS:
        bias = bias_ptr + idx_feats
        mask_bias = idx_feats < dim
        acc_preload = tl.load(bias, mask=mask_bias,
                              other=0.0).to(tl.float32)  # [BLOCK_N]
    else:
        acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)

    # STEP 4:
    # PRE-LOAD WEIGHTS
    # first kernel column, configured for weights to handle BLOCK_N features in range
    w_base = w_ptr + (idx_feats * stride_w_dim)  # [BLOCK_N,]
    mask_w = idx_feats < dim
    if KERNEL_WIDTH >= 2:
        w_ptrs = w_base + (0 * stride_w_width)  # [BLOCK_N] tensor
        w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
        w_ptrs = w_base + (1 * stride_w_width)  # [BLOCK_N] tensor
        w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 3:
        w_ptrs = w_base + (2 * stride_w_width)  # [BLOCK_N] tensor
        w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
    if KERNEL_WIDTH >= 4:
        w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
        w_col3 = tl.load(w_ptrs, mask_w, other=0.0)

    x_base_1d = x_base  # starting of chunk [BLOCK_N]
    mask_x_1d = idx_feats < dim

    # STEP 5: compute each token
    for idx_token in tl.static_range(seqlen):
        acc = acc_preload

        matrix_w = w_col0
        matrix_x = col0
        for j in tl.static_range(KERNEL_WIDTH):
            if KERNEL_WIDTH == 2:
                if j == 1:  # KERNEL_WIDTH-1:
                    matrix_w = w_col1
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 3:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
            elif KERNEL_WIDTH == 4:
                if j == 1:
                    matrix_w = w_col1
                    matrix_x = col1
                elif j == 2:
                    matrix_w = w_col2
                    matrix_x = col2
                elif j == 3:
                    matrix_w = w_col3
                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

            acc += matrix_x * matrix_w  # [BLOCK_N]

            if KERNEL_WIDTH == 2:
                col0 = matrix_x
            elif KERNEL_WIDTH == 3:
                col0 = col1
                col1 = matrix_x
            elif KERNEL_WIDTH == 4:
                col0 = col1
                col1 = col2
                col2 = matrix_x

        if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
        # mask_1d = (idx_token < seqlen) & (
        #     idx_feats < dim
        # )  # token-index # feature-index
        maskL = idx_feats < dim
        maskR = tl.full(maskL.shape, False, tl.int1)
        mask_1d = tl.where(idx_token < seqlen, maskL, maskR)

        o_ptrs = (o_ptr + (idx_seq) * stride_o_seq +
                  idx_token * stride_o_token + (idx_feats * stride_o_dim))

        tl.store(o_ptrs, acc, mask=mask_1d)

        if SAVE_INTERMEDIATE:
            # Save the window state after consuming this token
            # Layout: [seq(cache line), step, dim, win(K-1)]
            base_ptr = (intermediate_conv_window_ptr +
                        conv_state_batch_coord * stride_inter_seq +
                        idx_token * stride_inter_step +
                        idx_feats * stride_inter_dim)
            if KERNEL_WIDTH >= 2:
                tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
            if KERNEL_WIDTH >= 3:
                tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
            if KERNEL_WIDTH >= 4:
                tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
This file defines a Triton kernel _causal_conv1d_update_kernel, which is specific to NVIDIA GPUs and is not compatible with Ascend NPUs. The function causal_conv1d_update_npu calls this kernel. Since this op is used by the new Qwen3Next model, it will cause runtime failures on Ascend hardware. The implementation needs to be adapted to use Ascend-compatible kernels (e.g., using CANN/TBE) instead of Triton.
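For reference, a minimal device-agnostic sketch of the single-token conv-state update in plain PyTorch, illustrating one possible interim fallback until an Ascend kernel is available. The function name and the (batch, dim) single-token decode layout are assumptions for illustration, not the API introduced by this PR:

import torch
import torch.nn.functional as F
from typing import Optional


def causal_conv1d_update_torch(
    x: torch.Tensor,            # (batch, dim) - one new token per sequence
    conv_state: torch.Tensor,   # (batch, dim, state_len), updated in place
    weight: torch.Tensor,       # (dim, width) depthwise filter
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
) -> torch.Tensor:
    width = weight.shape[1]
    # Shift the rolling window left by one token and append the new token.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    # Depthwise causal convolution over the last `width` cached tokens.
    out = torch.einsum("bdw,dw->bd",
                       conv_state[:, :, -width:].to(weight.dtype), weight)
    if bias is not None:
        out = out + bias
    if activation == "silu":
        out = F.silu(out)
    return out.to(x.dtype)

This sketch covers only the decode path without speculative-decoding state rollback or padded cache slots; those would still need to be handled in a production kernel.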
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
            if not is_rms_norm else None)
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    with torch.npu.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            M,
            group_size,
            eps,
            BLOCK_N=BLOCK_N,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=num_warps,
        )
    return out, mean, rstd


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return y.reshape(x_shape_og)


def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, is_rms_norm)


def rmsnorm_fn(x,
               weight,
               bias,
               z=None,
               eps=1e-6,
               group_size=None,
               norm_before_gate=True):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, True)


class LayerNorm(torch.nn.Module):

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(
            torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(
            torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )


class RMSNormGated(torch.nn.Module):

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(
            torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )


@triton.jit
def fused_gdn_gating_kernel(
    g,
    A_log,
    a,
    dt_bias,
    seq_len,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    mask = head_off < NUM_HEADS
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    # If the model is loaded in fp16, without the .float() here, A might be -inf
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    softplus_x = tl.where(beta * x <= threshold,
                          (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
This file defines Triton kernels (_layer_norm_fwd_1pass_kernel and fused_gdn_gating_kernel) which are designed for NVIDIA GPUs and are incompatible with Ascend NPUs. These kernels are used in RMSNormGated and fused_gdn_gating respectively, which are then used by the Qwen3Next model. This will cause runtime failures on Ascend hardware. These kernels must be replaced with Ascend-compatible implementations.
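For context, a plain-PyTorch sketch of the math these kernels compute, shown only to illustrate what an Ascend-friendly interim fallback could look like. The function names and argument layouts below are assumptions for illustration (and the gated RMSNorm variant assumes group_size=None), not APIs from this PR:

import torch
import torch.nn.functional as F
from typing import Optional


def gdn_gating_torch(A_log: torch.Tensor, a: torch.Tensor,
                     dt_bias: torch.Tensor, beta: float = 1.0,
                     threshold: float = 20.0) -> torch.Tensor:
    # Same formula as fused_gdn_gating_kernel: g = -exp(A_log) * softplus(a + dt_bias).
    # Compute in fp32 so exp()/softplus() do not overflow when the model runs in fp16.
    x = a.float() + dt_bias.float()
    g = -torch.exp(A_log.float()) * F.softplus(x, beta=beta, threshold=threshold)
    return g.to(a.dtype)


def rms_norm_gated_torch(x: torch.Tensor, weight: torch.Tensor,
                         z: Optional[torch.Tensor] = None, eps: float = 1e-5,
                         norm_before_gate: bool = True) -> torch.Tensor:
    # norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)).
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    x_f = x.float()
    y = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + eps) * weight.float()
    if z is not None and norm_before_gate:
        y = y * F.silu(z.float())
    return y.to(x.dtype)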
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # p_q = q + (bos * H + i_h) * K + o_k
    # p_k = k + (bos * H + i_h) * K + o_k
    # p_v = v + (bos * HV + i_hv) * V + o_v
    # if IS_BETA_HEADWISE:
    #     p_beta = beta + (bos * HV + i_hv) * V + o_v
    # else:
    #     p_beta = beta + bos * HV + i_hv
    # p_g = g + bos * HV + i_hv
    # p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t
        p_g = g + bos * HV + i_hv + HV * i_t
        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
        # b_h *= tl.exp(b_g)
        b_h *= exp(b_g)
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_final_state_token
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # p_q += H * K
        # p_k += H * K
        # p_o += HV * V
        # p_v += HV * V
        # p_g += HV
        # p_beta += HV * (V if IS_BETA_HEADWISE else 1)
for b_idx in range(batch_size):
    start, end = non_spec_query_start_loc[
        b_idx], non_spec_query_start_loc[b_idx + 1]
    cur_q = query_non_spec[:, start:end, ...]
    cur_k = key_non_spec[:, start:end, ...]
    cur_v = value_non_spec[:, start:end, ...]
    cur_g = g_non_spec[:, start:end, ...]
    cur_b = beta_non_spec[:, start:end, ...]
    cur_state = initial_state[b_idx].unsqueeze(0)

    (
        cur_core_attn_out_non_spec,
        cur_last_recurrent_state,
    ) = torch_chunk_gated_delta_rule(
        query=cur_q,
        key=cur_k,
        value=cur_v,
        g=cur_g,
        beta=cur_b,
        initial_state=cur_state,
        output_final_state=True,
        use_qk_l2norm_in_kernel=True,
    )

    core_attn_out.append(cur_core_attn_out_non_spec)
    last_recurrent_state.append(cur_last_recurrent_state)
The loop for b_idx in range(batch_size): around the call to torch_chunk_gated_delta_rule iterates over the batch dimension. This is a significant performance bottleneck, especially for prefill where batch sizes can be large, and will lead to very slow performance. This logic should be vectorized to process the entire batch at once.
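One possible direction, sketched below: split the packed (1, total_tokens, ...) layout into per-sequence segments, pad them to a common length, and make a single batched call instead of batch_size separate calls. Whether torch_chunk_gated_delta_rule produces correct results for padded tail tokens in a (batch, max_len, ...) input is an assumption that would need to be verified for this codebase; the helper name is hypothetical.

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_varlen(packed: torch.Tensor, query_start_loc: torch.Tensor) -> torch.Tensor:
    """Convert a packed (1, total_tokens, ...) tensor into (batch, max_len, ...)."""
    seq_lens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    segments = torch.split(packed[0], seq_lens, dim=0)
    return pad_sequence(segments, batch_first=True)  # zero-padded tail tokens


# Usage sketch (replacing the per-sequence loop with one batched call):
# padded_q = pad_varlen(query_non_spec, non_spec_query_start_loc)
# padded_k = pad_varlen(key_non_spec, non_spec_query_start_loc)
# padded_v = pad_varlen(value_non_spec, non_spec_query_start_loc)
# padded_g = pad_varlen(g_non_spec, non_spec_query_start_loc)
# padded_b = pad_varlen(beta_non_spec, non_spec_query_start_loc)
# core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
#     query=padded_q, key=padded_k, value=padded_v, g=padded_g, beta=padded_b,
#     initial_state=initial_state, output_final_state=True,
#     use_qk_l2norm_in_kernel=True)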
for i in range(len(seqlens)):
    x_s = splits[i]
    if cache_indices[i] == PAD_SLOT_ID:
        continue
    out_ref_b.append(
        causal_conv1d_ref(
            x_s,
            weight,
            bias,
            activation=activation,
            return_final_states=True,
            final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
            initial_states=conv_states[cache_indices[i]]
            if has_initial_state[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
The function causal_conv1d_fn contains a Python loop for i in range(len(seqlens)): which iterates over sequences in a batch. This is a reference implementation (causal_conv1d_ref is called inside) and will be very slow. Using this in a performance-critical path like the model forward pass will cause a major performance bottleneck. This should be replaced with a batched, optimized kernel implementation for Ascend.
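For illustration, a hedged sketch of a batched causal depthwise conv1d built from torch.nn.functional primitives, as one possible alternative to looping over sequences with causal_conv1d_ref. It assumes equal-length (padded) sequences in a (batch, dim, seqlen) layout; handling per-sequence initial and final conv states, as the reference loop does, would still need extra care. The function name is hypothetical.

import torch
import torch.nn.functional as F
from typing import Optional


def causal_conv1d_batched(
    x: torch.Tensor,       # (batch, dim, seqlen)
    weight: torch.Tensor,  # (dim, width) depthwise filter
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
) -> torch.Tensor:
    dim, width = weight.shape
    # Left-pad by width-1 so each output position only sees current and past tokens.
    x_padded = F.pad(x, (width - 1, 0))
    out = F.conv1d(x_padded, weight.unsqueeze(1), bias=bias, groups=dim)
    return F.silu(out) if activation == "silu" else out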
bf8f5bb to 83e290a
Co-authored-by: hust17yixuan <303660421@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com>
83e290a to 12a0c58
This pull request has conflicts, please resolve those before we can evaluate the pull request.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?