Commit cf67619

lint: fix lint error
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
1 parent 9a88181 commit cf67619

File tree: 3 files changed (+24, -25 lines)

tests/compile/test_qk_norm_rope_fusion.py

Lines changed: 7 additions & 9 deletions
@@ -5,12 +5,13 @@
 import torch

 from tests.compile.backend import LazyInitPass, TestBackend
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.compilation.qk_norm_rope_fusion import (
     FUSED_QK_ROPE_OP,
-    QKNormRoPEFusionPass,
     RMS_OP,
+    QKNormRoPEFusionPass,
 )
 from vllm.config import (
     CompilationConfig,
@@ -21,9 +22,9 @@
 )
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.attention import Attention, AttentionType
 from vllm.platforms import current_platform

+
 class QKNormRoPETestModel(torch.nn.Module):
     """A minimal model that exercises the unfused Q/K RMSNorm + RoPE pattern.
@@ -73,15 +74,11 @@ def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
         # Unfused baseline: split, per-head RMS, then RoPE
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-        q_by_head = q.view(
-            *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
-        )
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
         q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)

-        k_by_head = k.view(
-            *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
-        )
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)

@@ -93,7 +90,8 @@ def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
 @pytest.mark.parametrize("T", [17])
 @pytest.mark.parametrize("num_heads, num_kv_heads, head_dim", [(16, 2, 128)])
 @pytest.mark.skipif(
-    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm",
+    not current_platform.is_cuda_alike(),
+    reason="Only test on CUDA and ROCm",
 )
 def test_qk_norm_rope_fusion(dtype, T, num_heads, num_kv_heads, head_dim):
     torch.set_default_device("cuda")
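
The q_by_head / k_by_head edits above are pure line-length reflows; the per-head reshape itself is unchanged. For context, a rough standalone sketch of that unfused baseline in plain PyTorch (torch.nn.RMSNorm, available since PyTorch 2.4, stands in for vLLM's RMSNorm; the 16/2/128 head configuration is the one the test parametrizes):

    import torch

    T, num_heads, num_kv_heads, head_dim = 17, 16, 2, 128
    q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

    qkv = torch.randn(T, q_size + 2 * kv_size)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # View [T, heads * head_dim] as [T, heads, head_dim] so RMSNorm acts on each
    # head's last dimension independently, then view back to the flat layout.
    q_norm = torch.nn.RMSNorm(head_dim, eps=1e-6)  # stand-in for vLLM's RMSNorm
    k_norm = torch.nn.RMSNorm(head_dim, eps=1e-6)

    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    q = q_norm(q_by_head).view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    k = k_norm(k_by_head).view(k.shape)

This view -> RMSNorm -> view chain, followed by RoPE, is what the QKNormRoPEFusionPass pattern in the next file matches and rewrites to FUSED_QK_ROPE_OP.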

vllm/compilation/qk_norm_rope_fusion.py

Lines changed: 17 additions & 14 deletions
@@ -15,8 +15,8 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

-from .inductor_pass import enable_fake_mode
 from .fusion import empty_bf16, empty_i64
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)
@@ -93,10 +93,12 @@ def pattern(
         v = operator.getitem(split_tuple, 2)

         # Q path: view -> (optional contiguous) -> RMS -> view back to q.shape
-        # q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        # q_by_head=q.view(*q.shape[:-1],q.shape[-1]//self.head_dim,self.head_dim)
         # q_out = torch.empty_like(q_by_head)
         # q_by_head_contiguous = q_by_head.contiguous()
-        q_by_head = VIEW_OP(q, (*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim))
+        q_by_head = VIEW_OP(
+            q, (*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        )
         q_out = EMPTY_LIKE_OP(q_by_head)
         q_by_head_contiguous = CONTIGUOUS_OP(q_by_head)

@@ -113,10 +115,12 @@ def pattern(
         q_flat = VIEW_OP(q_normed_by_head, q.shape)

         # K path: view -> (optional contiguous) -> RMS -> view back to k.shape
-        # k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        # k_by_head=k.view(*k.shape[:-1],k.shape[-1]//self.head_dim,self.head_dim)
         # k_out = torch.empty_like(k_by_head)
         # k_by_head_contiguous = k_by_head.contiguous()
-        k_by_head = VIEW_OP(k, (*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim))
+        k_by_head = VIEW_OP(
+            k, (*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        )
         k_out = EMPTY_LIKE_OP(k_by_head)
         k_by_head_contiguous = CONTIGUOUS_OP(k_by_head)
         kn = auto_functionalized(
@@ -130,7 +134,7 @@ def pattern(

         # k_flat = k_normed_by_head.view(k.shape)
         k_flat = VIEW_OP(k_normed_by_head, k.shape)
-
+
         # RoPE: apply to flattened q/k
         rope = auto_functionalized(
             self.rope_op,
@@ -143,7 +147,6 @@ def pattern(
         )
         return rope[1], rope[2], v

-
     def replacement(
         qkv: torch.Tensor,
         positions: torch.Tensor,
@@ -155,7 +158,7 @@ def replacement(
         pos_flat = RESHAPE_OP(positions, [-1])

         # Run fused op (mutates qkv)
-        auto_functionalized(
+        result = auto_functionalized(
             FUSED_QK_ROPE_OP,
             qkv=qkv,
             num_heads_q=self.num_heads,
@@ -169,18 +172,19 @@ def replacement(
             is_neox=self.is_neox,
             position_ids=pos_flat,
         )
+        result_qkv = result[1]

         # Split back to q,k,v and return
         split_tuple = SPLIT_SIZES_OP(
-            qkv, [self.q_size, self.kv_size, self.kv_size], -1
+            result_qkv, [self.q_size, self.kv_size, self.kv_size], -1
         )
         return (
             operator.getitem(split_tuple, 0),
             operator.getitem(split_tuple, 1),
             operator.getitem(split_tuple, 2),
         )

-        # Sample inputs to help pattern tracing (sizes don't have to be exact at runtime)
+        # Sample inputs to help pattern tracing
         T = 5
         qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
         positions = empty_i64(T)
@@ -229,9 +233,7 @@ def __init__(self, config: VllmConfig):
         )

         if not current_platform.is_cuda_alike():
-            logger.debug(
-                "QK Norm+RoPE fusion not enabled: unsupported platform"
-            )
+            logger.debug("QK Norm+RoPE fusion not enabled: unsupported platform")
             return

         # Register a pattern per attention layer, as sizes differ by shard
@@ -255,7 +257,8 @@ def __init__(self, config: VllmConfig):
                 ).register(self.patterns)
             except Exception as e:
                 logger.debug(
-                    "Skipping QkNormRopePattern registration with eps=%s is_neox=%s: %s",
+                    "Skipping QkNormRopePattern register with eps=%s "
+                    "is_neox=%s: %s",
                     epsilon,
                     neox,
                     e,
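
The one behavioral change in this file is in replacement: the final split now reads result[1] — the functionalized copy of the mutated qkv that auto_functionalized returns — instead of the original qkv input, mirroring how pattern already consumes rope[1] and rope[2]. Everything else is import ordering and line-length cleanup. For orientation, here is a minimal toy sketch of how a pattern/replacement pair like this gets registered with Inductor's pattern matcher (generic torch._inductor.pattern_matcher helpers and a made-up x * 2 + x * 2 pattern, purely illustrative, not the pass's real registration code):

    import torch
    from torch._inductor.pattern_matcher import (
        PatternMatcherPass,
        fwd_only,
        register_replacement,
    )

    patterns = PatternMatcherPass()

    def pattern(x: torch.Tensor):
        # Subgraph to search for: the same input scaled twice and summed.
        return x * 2.0 + x * 2.0

    def replacement(x: torch.Tensor):
        # Functionally equivalent subgraph the matcher rewrites each match to.
        return x * 4.0

    # Sample inputs only guide tracing of the pattern; they need not match the
    # runtime sizes (the pass's empty_bf16/empty_i64 inputs play the same role).
    sample_inputs = [torch.empty(5, 8)]
    register_replacement(pattern, replacement, sample_inputs, fwd_only, patterns)

The QkNormRopePattern(...).register(self.patterns) calls above feed one such pattern per attention layer into the same machinery.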

vllm/config/compilation.py

Lines changed: 0 additions & 2 deletions
@@ -631,8 +631,6 @@ def __post_init__(self) -> None:

         if self.backend == "":
             self.backend = current_platform.simple_compile_backend
-
-

     def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
         """
