@@ -7,111 +7,9 @@
 # mypy: ignore-errors

 import torch
-import torch.nn.functional as F
 import triton
-import triton.language as tl
-from einops import rearrange
-
-
-def rms_norm_ref(
-    x,
-    weight,
-    bias,
-    z=None,
-    eps=1e-6,
-    group_size=None,
-    norm_before_gate=True,
-    upcast=True,
-):
-    dtype = x.dtype
-    #N = x.shape[-1]
-    weight = weight.float()
-    bias = bias.float() if bias is not None else None
-    if upcast:
-        x = x.float()
-        z = z.float() if z is not None else z
-    if z is not None and not norm_before_gate:
-        x = x * F.silu(z)
-    if group_size is None:
-        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-        out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
-                                                                   weight)
-    else:
-        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
-        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
-                              eps)
-        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
-        if bias is not None:
-            out = out + bias
-    if z is not None and norm_before_gate:
-        out *= F.silu(z)
-    return out.to(dtype)
-
-
-@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
-@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
-@triton.jit
-def _layer_norm_fwd_1pass_kernel(
-    X,  # pointer to the input
-    Y,  # pointer to the output
-    W,  # pointer to the weights
-    B,  # pointer to the biases
-    Z,  # pointer to the other branch
-    Mean,  # pointer to the mean
-    Rstd,  # pointer to the 1/std
-    stride_x_row,  # how much to increase the pointer when moving by 1 row
-    stride_y_row,
-    stride_z_row,
-    M,  # number of rows in X
-    N,  # number of columns in X
-    eps,  # epsilon to avoid division by zero
-    BLOCK_N: tl.constexpr,
-    HAS_BIAS: tl.constexpr,
-    HAS_Z: tl.constexpr,
-    NORM_BEFORE_GATE: tl.constexpr,
-    IS_RMS_NORM: tl.constexpr,
-):
-    # Map the program id to the row of X and Y it should compute.
-    row = tl.program_id(0)
-    group = tl.program_id(1)
-    X += row * stride_x_row + group * N
-    Y += row * stride_y_row + group * N
-    if HAS_Z:
-        Z += row * stride_z_row + group * N
-    if not IS_RMS_NORM:
-        Mean += group * M
-    Rstd += group * M
-    W += group * N
-    if HAS_BIAS:
-        B += group * N
-    # Compute mean and variance
-    cols = tl.arange(0, BLOCK_N)
-    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-    if HAS_Z and not NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
-        x *= z * tl.sigmoid(z)
-    if not IS_RMS_NORM:
-        mean = tl.sum(x, axis=0) / N
-        tl.store(Mean + row, mean)
-        xbar = tl.where(cols < N, x - mean, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    else:
-        xbar = tl.where(cols < N, x, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    rstd = 1 / tl.sqrt(var + eps)
-    tl.store(Rstd + row, rstd)
-    # Normalize and apply linear transformation
-    mask = cols < N
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
-    if HAS_BIAS:
-        b = tl.load(B + cols, mask=mask).to(tl.float32)
-    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-    y = x_hat * w + b if HAS_BIAS else x_hat * w
-    if HAS_Z and NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=mask).to(tl.float32)
-        y *= z * tl.sigmoid(z)
-    # Write output
-    tl.store(Y + cols, y, mask=mask)
+from vllm.model_executor.layers.fla.ops.layernorm_guard import \
+    layer_norm_fwd_kernel


 def _layer_norm_fwd(
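For reference, the deleted rms_norm_ref and the deleted one-pass Triton kernel both compute an RMS/layer norm with an optional SiLU gate, applied either after normalization (norm_before_gate=True, i.e. norm(x) * silu(z)) or before it (norm(x * silu(z))). Below is a minimal PyTorch sketch of that contract, which the shared layer_norm_fwd_kernel is expected to preserve; it covers only the single-group, no-bias case, and gated_rms_norm_ref is a hypothetical name, not part of this diff.

import torch
import torch.nn.functional as F

def gated_rms_norm_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    # Reference semantics of the deleted rms_norm_ref (single group, no bias).
    dtype = x.dtype
    x = x.float()
    if z is not None and not norm_before_gate:
        # Gate first, then normalize: norm(x * silu(z))
        x = x * F.silu(z.float())
    rstd = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + eps)
    out = x * rstd * weight.float()
    if z is not None and norm_before_gate:
        # Normalize first, then gate: norm(x) * silu(z)
        out = out * F.silu(z.float())
    return out.to(dtype)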
@@ -158,7 +56,7 @@ def _layer_norm_fwd(
     num_warps = min(max(BLOCK_N // 256, 1), 8)
     grid = (M, ngroups)
     with torch.npu.device(x.device.index):
-        _layer_norm_fwd_1pass_kernel[grid](
+        layer_norm_fwd_kernel[grid](
             x,
             out,
             weight,
@@ -220,111 +118,3 @@ def forward(
         is_rms_norm=is_rms_norm,
     )
     return y.reshape(x_shape_og)
-
-
-def layernorm_fn(
-    x,
-    weight,
-    bias,
-    z=None,
-    eps=1e-6,
-    group_size=None,
-    norm_before_gate=True,
-    is_rms_norm=False,
-):
-    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
-                             norm_before_gate, is_rms_norm)
-
-
-def rmsnorm_fn(x,
-               weight,
-               bias,
-               z=None,
-               eps=1e-6,
-               group_size=None,
-               norm_before_gate=True):
-    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
-                             norm_before_gate, True)
-
-
-class LayerNorm(torch.nn.Module):
-
-    def __init__(
-        self,
-        hidden_size,
-        eps=1e-5,
-        group_size=None,
-        norm_before_gate=True,
-        device=None,
-        dtype=None,
-    ):
-        """If group_size is not None, we do GroupNorm with each group having group_size elements.
-        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
-        """
-
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(
-            torch.empty(hidden_size, **factory_kwargs))
-        self.bias = torch.nn.Parameter(
-            torch.empty(hidden_size, **factory_kwargs))
-        self.group_size = group_size
-        self.norm_before_gate = norm_before_gate
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)
-        torch.nn.init.zeros_(self.bias)
-
-    def forward(self, x, z=None):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return layernorm_fn(
-            x,
-            self.weight,
-            self.bias,
-            z=z,
-            group_size=self.group_size,
-            eps=self.eps,
-            norm_before_gate=self.norm_before_gate,
-        )
-
-
-class RMSNormGated(torch.nn.Module):
-
-    def __init__(
-        self,
-        hidden_size,
-        eps=1e-5,
-        group_size=None,
-        norm_before_gate=True,
-        device=None,
-        dtype=None,
-    ):
-        """If group_size is not None, we do GroupNorm with each group having group_size elements.
-        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
-        """
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        self.weight = torch.nn.Parameter(
-            torch.empty(hidden_size, **factory_kwargs))
-        self.register_parameter("bias", None)
-        self.group_size = group_size
-        self.norm_before_gate = norm_before_gate
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)
-
-    def forward(self, x, z=None):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        return rmsnorm_fn(
-            x,
-            self.weight,
-            self.bias,
-            z=z,
-            eps=self.eps,
-            group_size=self.group_size,
-            norm_before_gate=self.norm_before_gate,
-        )
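The deleted layernorm_fn / rmsnorm_fn wrappers and the LayerNorm / RMSNormGated modules were thin dispatchers into LayerNormFn.apply; after this change only the autograd function and _layer_norm_fwd remain in this file, with the kernel sourced from layernorm_guard. For callers that depended on the removed module, an equivalent can be rebuilt on the reference sketch above. This is a hypothetical stand-in (single group, no bias), not part of vLLM's API.

class GatedRMSNorm(torch.nn.Module):
    # Mirrors the deleted RMSNormGated: learnable weight, no bias,
    # optional SiLU gate applied before or after normalization.
    def __init__(self, hidden_size, eps=1e-5, norm_before_gate=True,
                 device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.norm_before_gate = norm_before_gate
        self.weight = torch.nn.Parameter(
            torch.ones(hidden_size, device=device, dtype=dtype))

    def forward(self, x, z=None):
        # norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)),
        # via the gated_rms_norm_ref sketch defined earlier.
        return gated_rms_norm_ref(x, self.weight, z=z, eps=self.eps,
                                  norm_before_gate=self.norm_before_gate)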