Commit 9a88181

zhuhaoran committed
update qknorm_rope fusion pass and its unit test
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
1 parent d2154f8 commit 9a88181

2 files changed: +47, -62 lines changed


tests/compile/test_qk_norm_rope_fusion.py

Lines changed: 1 addition & 2 deletions
@@ -24,7 +24,6 @@
 from vllm.attention import Attention, AttentionType
 from vllm.platforms import current_platform
 
-
 class QKNormRoPETestModel(torch.nn.Module):
     """A minimal model that exercises the unfused Q/K RMSNorm + RoPE pattern.
 
@@ -105,7 +104,7 @@ def test_qk_norm_rope_fusion(dtype, T, num_heads, num_kv_heads, head_dim):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            custom_ops=["+rms_norm", "+rotary_embedding", "+fused_qk_norm_rope"],
+            custom_ops=["+rms_norm", "+rotary_embedding"],
             pass_config=PassConfig(
                 enable_qk_norm_rope_fusion=True,
                 enable_noop=True,
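
In assembled form, the compilation config this test now builds looks roughly like the sketch below. This is a standalone illustration, not part of the diff, and the import locations are assumed rather than taken from the hunk; the point of the change is that "+fused_qk_norm_rope" no longer needs to be force-enabled via custom_ops, since the fusion is driven by PassConfig.enable_qk_norm_rope_fusion alone.

    from vllm.config import (  # import paths assumed; the test's own imports apply
        CompilationConfig,
        CompilationMode,
        PassConfig,
        VllmConfig,
    )

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+rms_norm", "+rotary_embedding"],  # no "+fused_qk_norm_rope"
            pass_config=PassConfig(
                enable_qk_norm_rope_fusion=True,
                enable_noop=True,
            ),
        )
    )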

vllm/compilation/qk_norm_rope_fusion.py

Lines changed: 46 additions & 60 deletions
@@ -24,17 +24,16 @@
 
 # Ops used in the pattern (assume kernels are built and available)
 RMS_OP = torch.ops._C.rms_norm.default
-CONTIGUOUS_OP = torch.ops.aten.contiguous.default
-# Some graphs canonicalize `.contiguous()` into a clone with memory_format
-CLONE_OP = torch.ops.aten.clone.default
 ROPE_OPS: list[torch._ops.OpOverload] = [
     torch.ops._C.rotary_embedding.default,
-    torch.ops.vllm.flashinfer_rotary_embedding.default,
+    # torch.ops.vllm.flashinfer_rotary_embedding.default,
 ]
 FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
 SPLIT_SIZES_OP = torch.ops.aten.split_with_sizes.default
 RESHAPE_OP = torch.ops.aten.reshape.default
 EMPTY_LIKE_OP = torch.ops.aten.empty_like.default
+VIEW_OP = torch.ops.aten.view.default
+CONTIGUOUS_OP = torch.ops.aten.contiguous.default
 
 
 class QkNormRopePattern:
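
For orientation, these *_OP constants are plain torch OpOverload handles; calling them directly is equivalent to the usual tensor-method spelling, which is why the pattern below is written in terms of them (the traced pattern graph then contains the same aten calls as the compiled model graph). A minimal standalone sketch of that correspondence, using toy sizes not taken from the file:

    import torch

    SPLIT_SIZES_OP = torch.ops.aten.split_with_sizes.default
    VIEW_OP = torch.ops.aten.view.default
    EMPTY_LIKE_OP = torch.ops.aten.empty_like.default
    CONTIGUOUS_OP = torch.ops.aten.contiguous.default

    head_dim = 4
    qkv = torch.randn(2, 8 + 4 + 4)  # toy [num_tokens, q_size + 2 * kv_size]
    q, k, v = SPLIT_SIZES_OP(qkv, [8, 4, 4], -1)  # same result as qkv.split([8, 4, 4], dim=-1)

    q_by_head = VIEW_OP(q, (*q.shape[:-1], q.shape[-1] // head_dim, head_dim))
    assert torch.equal(q_by_head, q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim))
    assert EMPTY_LIKE_OP(q_by_head).shape == q_by_head.shape
    assert CONTIGUOUS_OP(q_by_head).is_contiguous()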
@@ -85,18 +84,22 @@ def pattern(
         cos_sin_cache: torch.Tensor,
     ):
         # split qkv -> q,k,v
-        # split_tuple = SPLIT_SIZES_OP(
-        #     qkv, [self.q_size, self.kv_size, self.kv_size], -1
-        # )
-        # q = operator.getitem(split_tuple, 0)
-        # k = operator.getitem(split_tuple, 1)
-        # v = operator.getitem(split_tuple, 2)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        split_tuple = SPLIT_SIZES_OP(
+            qkv, [self.q_size, self.kv_size, self.kv_size], -1
+        )
+        q = operator.getitem(split_tuple, 0)
+        k = operator.getitem(split_tuple, 1)
+        v = operator.getitem(split_tuple, 2)
 
         # Q path: view -> (optional contiguous) -> RMS -> view back to q.shape
-        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
-        q_out = torch.empty_like(q_by_head)
-        q_by_head_contiguous = q_by_head.contiguous()
+        # q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        # q_out = torch.empty_like(q_by_head)
+        # q_by_head_contiguous = q_by_head.contiguous()
+        q_by_head = VIEW_OP(q, (*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim))
+        q_out = EMPTY_LIKE_OP(q_by_head)
+        q_by_head_contiguous = CONTIGUOUS_OP(q_by_head)
+
         qn = auto_functionalized(
             RMS_OP,
             result=q_out,
@@ -105,14 +108,17 @@ def pattern(
             epsilon=self.eps,
         )
         q_normed_by_head = qn[1]
-        # RMS_OP(result=q_out, input=q_by_head_contiguous, weight=q_weight, epsilon=self.eps)
-        # q_normed_by_head = q_out
-        q_flat = q_normed_by_head.view(q.shape)
+
+        # q_flat = q_normed_by_head.view(q.shape)
+        q_flat = VIEW_OP(q_normed_by_head, q.shape)
 
         # K path: view -> (optional contiguous) -> RMS -> view back to k.shape
-        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
-        k_out = torch.empty_like(k_by_head)
-        k_by_head_contiguous = k_by_head.contiguous()
+        # k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        # k_out = torch.empty_like(k_by_head)
+        # k_by_head_contiguous = k_by_head.contiguous()
+        k_by_head = VIEW_OP(k, (*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim))
+        k_out = EMPTY_LIKE_OP(k_by_head)
+        k_by_head_contiguous = CONTIGUOUS_OP(k_by_head)
         kn = auto_functionalized(
             RMS_OP,
             result=k_out,
@@ -121,9 +127,9 @@ def pattern(
             epsilon=self.eps,
         )
         k_normed_by_head = kn[1]
-        # RMS_OP(result=k_out, input=k_by_head_contiguous, weight=k_weight, epsilon=self.eps)
-        # k_normed_by_head = k_out
-        k_flat = k_normed_by_head.view(k.shape)
+
+        # k_flat = k_normed_by_head.view(k.shape)
+        k_flat = VIEW_OP(k_normed_by_head, k.shape)
 
         # RoPE: apply to flattened q/k
         rope = auto_functionalized(
@@ -136,15 +142,6 @@ def pattern(
             is_neox=self.is_neox,
         )
         return rope[1], rope[2], v
-        # self.rope_op(
-        #     positions=positions,
-        #     query=q_flat,
-        #     key=k_flat,
-        #     head_size=self.head_dim,
-        #     cos_sin_cache=cos_sin_cache,
-        #     is_neox=self.is_neox
-        # )
-        # return q_flat, k_flat, v
 
 
     def replacement(
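
To make the matched subgraph concrete, here is a rough eager-mode sketch of the unfused computation the pattern function describes: split qkv, RMS-normalize Q and K per head with their own weights, and view back to the flat layout before RoPE is applied (RoPE itself is elided). The helper name and shapes below are hypothetical; the actual pass matches the _C.rms_norm custom op, not this hand-written norm.

    import torch

    def unfused_qk_norm(qkv, q_weight, k_weight, q_size, kv_size, head_dim, eps=1e-6):
        # Split the fused projection into query/key/value slices.
        q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

        def rms_norm_by_head(x, weight):
            # View to [..., num_heads, head_dim], RMS-normalize the last dim,
            # scale by the per-head-dim weight, then view back to x.shape.
            x_by_head = x.view(*x.shape[:-1], x.shape[-1] // head_dim, head_dim)
            variance = x_by_head.float().pow(2).mean(dim=-1, keepdim=True)
            normed = (x_by_head.float() * torch.rsqrt(variance + eps)).to(x.dtype)
            return (normed * weight).view(x.shape)

        return rms_norm_by_head(q, q_weight), rms_norm_by_head(k, k_weight), v

    q_flat, k_flat, v = unfused_qk_norm(
        torch.randn(3, 16), torch.ones(4), torch.ones(4), q_size=8, kv_size=4, head_dim=4
    )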
@@ -244,36 +241,25 @@ def __init__(self, config: VllmConfig):
                 "QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
             )
             return
-
         layer_name, layer = next(iter(attn_layers.items()))
 
-        # Derive parameters from the layer to avoid combinatorial loops
-        eps = getattr(getattr(layer, "q_norm", None), "variance_epsilon", None)
-        if not isinstance(eps, float):
-            eps = 1e-6  # fallback default
-
-        rope_mod = getattr(layer, "rotary_emb", None)
-        use_flashinfer = getattr(rope_mod, "use_flashinfer", False)
-        rope_op = (
-            torch.ops.vllm.flashinfer_rotary_embedding.default
-            if use_flashinfer
-            else torch.ops._C.rotary_embedding.default
-        )
-        is_neox = getattr(rope_mod, "is_neox_style", True)
-
-        try:
-            QkNormRopePattern(
-                layer,
-                eps=eps,
-                rope_op=rope_op,
-                is_neox=is_neox,
-            ).register(self.patterns)
-        except Exception as e:
-            logger.debug(
-                "Skipping pattern registration for layer %s: %s",
-                layer_name,
-                e,
-            )
+        for epsilon in [1e-5, 1e-6]:
+            for neox in [True, False]:
+                for rope_op in ROPE_OPS:
+                    try:
+                        QkNormRopePattern(
+                            layer=layer,
+                            eps=epsilon,
+                            rope_op=rope_op,
+                            is_neox=neox,
+                        ).register(self.patterns)
+                    except Exception as e:
+                        logger.debug(
+                            "Skipping QkNormRopePattern registration with eps=%s is_neox=%s: %s",
+                            epsilon,
+                            neox,
+                            e,
+                        )
 
         # Dump patterns for debugging if enabled
         self.dump_patterns(config, self.patterns)
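
A note on the registration change above: instead of deriving eps, is_neox, and the rope op from the discovered attention layer, the pass now registers one pattern per (eps, is_neox, rope_op) combination and relies on the try/except to skip any combination that fails to trace. With the flashinfer rope op commented out of ROPE_OPS, that is a small cartesian product; a self-contained sketch of the enumeration (placeholder strings stand in for the real op handles):

    import itertools

    eps_values = [1e-5, 1e-6]
    neox_values = [True, False]
    rope_ops = ["torch.ops._C.rotary_embedding.default"]  # flashinfer variant is commented out

    combos = list(itertools.product(eps_values, neox_values, rope_ops))
    print(len(combos))  # 4 pattern registrations attempted per pass instance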
