Commit 3e519e3

patch LayerNormFn
Signed-off-by: Icey <1790571317@qq.com>
1 parent 48a6d49 commit 3e519e3

3 files changed (+9, -215 lines)
vllm_ascend/models/qwen3_next.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.fla.ops import RMSNormGated
 from vllm.model_executor.layers.fla.ops.fused_recurrent import \
     fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -63,8 +64,6 @@
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
-from vllm_ascend.ops.fla import RMSNormGated
-
 
 def torch_chunk_gated_delta_rule(
     query,
@@ -278,6 +277,8 @@ def __init__(
         self.norm = RMSNormGated(
             self.head_v_dim,
             eps=self.layer_norm_epsilon,
+            norm_before_gate=True,
+            device="npu",
         )
 
         self.out_proj = RowParallelLinear(self.value_dim,
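For context, a minimal usage sketch of the norm as it is now constructed above. The forward(x, z=None) signature, shapes, and dtype are assumptions (taken from the mamba_ssm-style gated RMS norm API, not from this commit); with norm_before_gate=True the gate is applied after normalization, i.e. rmsnorm(x) * weight * silu(z).

    # Hedged sketch only: assumes RMSNormGated.forward(x, z=None); shapes and dtype are illustrative.
    import torch
    from vllm.model_executor.layers.fla.ops import RMSNormGated

    head_v_dim = 128  # illustrative value, not taken from this commit
    norm = RMSNormGated(
        head_v_dim,
        eps=1e-6,
        norm_before_gate=True,  # normalize first, then gate: rmsnorm(x) * silu(z)
        device="npu",
    )
    x = torch.randn(4, head_v_dim, device="npu", dtype=torch.bfloat16)  # core attention output
    z = torch.randn(4, head_v_dim, device="npu", dtype=torch.bfloat16)  # gate branch
    y = norm(x, z)  # roughly rmsnorm(x) * weight * silu(z)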

vllm_ascend/ops/fla.py

Lines changed: 3 additions & 213 deletions
@@ -7,111 +7,9 @@
 # mypy: ignore-errors
 
 import torch
-import torch.nn.functional as F
 import triton
-import triton.language as tl
-from einops import rearrange
-
-
-def rms_norm_ref(
-    x,
-    weight,
-    bias,
-    z=None,
-    eps=1e-6,
-    group_size=None,
-    norm_before_gate=True,
-    upcast=True,
-):
-    dtype = x.dtype
-    #N = x.shape[-1]
-    weight = weight.float()
-    bias = bias.float() if bias is not None else None
-    if upcast:
-        x = x.float()
-        z = z.float() if z is not None else z
-    if z is not None and not norm_before_gate:
-        x = x * F.silu(z)
-    if group_size is None:
-        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-        out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
                                                                   weight)
-    else:
-        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
-        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
                              eps)
-        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
-        if bias is not None:
-            out = out + bias
-    if z is not None and norm_before_gate:
-        out *= F.silu(z)
-    return out.to(dtype)
-
-
-@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
-@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
-@triton.jit
-def _layer_norm_fwd_1pass_kernel(
-    X,  # pointer to the input
-    Y,  # pointer to the output
-    W,  # pointer to the weights
-    B,  # pointer to the biases
-    Z,  # pointer to the other branch
-    Mean,  # pointer to the mean
-    Rstd,  # pointer to the 1/std
-    stride_x_row,  # how much to increase the pointer when moving by 1 row
-    stride_y_row,
-    stride_z_row,
-    M,  # number of rows in X
-    N,  # number of columns in X
-    eps,  # epsilon to avoid division by zero
-    BLOCK_N: tl.constexpr,
-    HAS_BIAS: tl.constexpr,
-    HAS_Z: tl.constexpr,
-    NORM_BEFORE_GATE: tl.constexpr,
-    IS_RMS_NORM: tl.constexpr,
-):
-    # Map the program id to the row of X and Y it should compute.
-    row = tl.program_id(0)
-    group = tl.program_id(1)
-    X += row * stride_x_row + group * N
-    Y += row * stride_y_row + group * N
-    if HAS_Z:
-        Z += row * stride_z_row + group * N
-    if not IS_RMS_NORM:
-        Mean += group * M
-    Rstd += group * M
-    W += group * N
-    if HAS_BIAS:
-        B += group * N
-    # Compute mean and variance
-    cols = tl.arange(0, BLOCK_N)
-    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-    if HAS_Z and not NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
-        x *= z * tl.sigmoid(z)
-    if not IS_RMS_NORM:
-        mean = tl.sum(x, axis=0) / N
-        tl.store(Mean + row, mean)
-        xbar = tl.where(cols < N, x - mean, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    else:
-        xbar = tl.where(cols < N, x, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    rstd = 1 / tl.sqrt(var + eps)
-    tl.store(Rstd + row, rstd)
-    # Normalize and apply linear transformation
-    mask = cols < N
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
-    if HAS_BIAS:
-        b = tl.load(B + cols, mask=mask).to(tl.float32)
-    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-    y = x_hat * w + b if HAS_BIAS else x_hat * w
-    if HAS_Z and NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=mask).to(tl.float32)
-        y *= z * tl.sigmoid(z)
-    # Write output
-    tl.store(Y + cols, y, mask=mask)
+from vllm.model_executor.layers.fla.ops.layernorm_guard import \
+    layer_norm_fwd_kernel
 
 
 def _layer_norm_fwd(
@@ -158,7 +56,7 @@ def _layer_norm_fwd(
     num_warps = min(max(BLOCK_N // 256, 1), 8)
     grid = (M, ngroups)
     with torch.npu.device(x.device.index):
-        _layer_norm_fwd_1pass_kernel[grid](
+        layer_norm_fwd_kernel[grid](
             x,
             out,
             weight,
@@ -220,111 +118,3 @@ def forward(
             is_rms_norm=is_rms_norm,
         )
         return y.reshape(x_shape_og)
-
-
-def layernorm_fn(
-    x,
-    weight,
-    bias,
-    z=None,
-    eps=1e-6,
-    group_size=None,
-    norm_before_gate=True,
-    is_rms_norm=False,
-):
-    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
-                             norm_before_gate, is_rms_norm)
-
-
-def rmsnorm_fn(x,
-               weight,
-               bias,
-               z=None,
-               eps=1e-6,
-               group_size=None,
-               norm_before_gate=True):
-    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
-                             norm_before_gate, True)
-
-
-class LayerNorm(torch.nn.Module):
-
-    def __init__(
-        self,
-        hidden_size,
-        eps=1e-5,
-        group_size=None,
-        norm_before_gate=True,
-        device=None,
-        dtype=None,
-    ):
-        """If group_size is not None, we do GroupNorm with each group having group_size elements.
-        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
-        """
-
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(
-            torch.empty(hidden_size, **factory_kwargs))
-        self.bias = torch.nn.Parameter(
-            torch.empty(hidden_size, **factory_kwargs))
-        self.group_size = group_size
-        self.norm_before_gate = norm_before_gate
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)
-        torch.nn.init.zeros_(self.bias)
-
-    def forward(self, x, z=None):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return layernorm_fn(
-            x,
-            self.weight,
-            self.bias,
-            z=z,
-            group_size=self.group_size,
-            eps=self.eps,
-            norm_before_gate=self.norm_before_gate,
-        )
-
-
-class RMSNormGated(torch.nn.Module):
-
-    def __init__(
-        self,
-        hidden_size,
-        eps=1e-5,
-        group_size=None,
-        norm_before_gate=True,
-        device=None,
-        dtype=None,
-    ):
-        """If group_size is not None, we do GroupNorm with each group having group_size elements.
-        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
-        """
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(
-            torch.empty(hidden_size, **factory_kwargs))
-        self.register_parameter("bias", None)
-        self.group_size = group_size
-        self.norm_before_gate = norm_before_gate
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)
-
-    def forward(self, x, z=None):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return rmsnorm_fn(
-            x,
-            self.weight,
-            self.bias,
-            z=z,
-            eps=self.eps,
-            group_size=self.group_size,
-            norm_before_gate=self.norm_before_gate,
-        )
Lines changed: 3 additions & 0 deletions
@@ -1,11 +1,14 @@
 import vllm.model_executor.layers.fla.ops.fused_recurrent
+import vllm.model_executor.layers.fla.ops.layernorm_guard
 import vllm.model_executor.layers.mamba.ops.causal_conv1d
 
 from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
                                            causal_conv1d_update_npu)
+from vllm_ascend.ops.fla import LayerNormFn
 from vllm_ascend.ops.sigmoid_gating import \
     fused_recurrent_gated_delta_rule_fwd_kernel
 
 vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
 vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
 vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
+vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
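The assignment above rebinds LayerNormFn on the upstream layernorm_guard module, so any call site that resolves layernorm_guard.LayerNormFn as a module attribute at runtime picks up the NPU implementation; call sites that bound LayerNormFn via a from-import before this module ran would keep the original class, which is why the patch module needs to be imported early. A small, hypothetical sanity check (not part of this commit):

    # Hypothetical check that the monkey-patch took effect (assumes the
    # vllm_ascend patch module shown above has already been imported).
    import vllm.model_executor.layers.fla.ops.layernorm_guard as layernorm_guard
    from vllm_ascend.ops.fla import LayerNormFn

    assert layernorm_guard.LayerNormFn is LayerNormFn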
