 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
-from .inductor_pass import enable_fake_mode
 from .fusion import empty_bf16, empty_i64
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
@@ -93,10 +93,12 @@ def pattern(
             v = operator.getitem(split_tuple, 2)
 
             # Q path: view -> (optional contiguous) -> RMS -> view back to q.shape
-            # q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+            # q_by_head= q.view(*q.shape[:-1],q.shape[-1]// self.head_dim,self.head_dim)
             # q_out = torch.empty_like(q_by_head)
             # q_by_head_contiguous = q_by_head.contiguous()
-            q_by_head = VIEW_OP(q, (*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim))
+            q_by_head = VIEW_OP(
+                q, (*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+            )
             q_out = EMPTY_LIKE_OP(q_by_head)
             q_by_head_contiguous = CONTIGUOUS_OP(q_by_head)
 
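Note on the traced Q path above: the unfused reference reshapes the flat q projection into per-head vectors, RMS-normalizes each head, and views back to the original shape. Below is a minimal eager-mode sketch of that shape manipulation, with made-up sizes and plain torch ops standing in for VIEW_OP and the RMS custom op (illustrative only, not the pass's code):

import torch

# Hypothetical sizes; the real pass reads head_dim and q_size from each attention layer.
T, num_heads, head_dim = 5, 8, 64
eps = 1e-6
q = torch.randn(T, num_heads * head_dim, dtype=torch.bfloat16)
weight = torch.ones(head_dim, dtype=torch.bfloat16)

# view -> RMS-normalize each head -> view back to q.shape
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
variance = q_by_head.float().pow(2).mean(dim=-1, keepdim=True)
q_normed_by_head = (q_by_head.float() * torch.rsqrt(variance + eps)).to(q.dtype) * weight
q_flat = q_normed_by_head.view(q.shape)
assert q_flat.shape == q.shape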
@@ -113,10 +115,12 @@ def pattern(
             q_flat = VIEW_OP(q_normed_by_head, q.shape)
 
             # K path: view -> (optional contiguous) -> RMS -> view back to k.shape
-            # k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+            # k_by_head= k.view(*k.shape[:-1],k.shape[-1]// self.head_dim,self.head_dim)
             # k_out = torch.empty_like(k_by_head)
             # k_by_head_contiguous = k_by_head.contiguous()
-            k_by_head = VIEW_OP(k, (*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim))
+            k_by_head = VIEW_OP(
+                k, (*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+            )
             k_out = EMPTY_LIKE_OP(k_by_head)
             k_by_head_contiguous = CONTIGUOUS_OP(k_by_head)
             kn = auto_functionalized(
@@ -130,7 +134,7 @@ def pattern(
 
             # k_flat = k_normed_by_head.view(k.shape)
             k_flat = VIEW_OP(k_normed_by_head, k.shape)
-
+
             # RoPE: apply to flattened q/k
             rope = auto_functionalized(
                 self.rope_op,
@@ -143,7 +147,6 @@ def pattern(
             )
             return rope[1], rope[2], v
 
-
         def replacement(
             qkv: torch.Tensor,
             positions: torch.Tensor,
@@ -155,7 +158,7 @@ def replacement(
             pos_flat = RESHAPE_OP(positions, [-1])
 
             # Run fused op (mutates qkv)
-            auto_functionalized(
+            result = auto_functionalized(
                 FUSED_QK_ROPE_OP,
                 qkv=qkv,
                 num_heads_q=self.num_heads,
@@ -169,18 +172,19 @@ def replacement(
                 is_neox=self.is_neox,
                 position_ids=pos_flat,
             )
+            result_qkv = result[1]
 
             # Split back to q,k,v and return
             split_tuple = SPLIT_SIZES_OP(
-                qkv, [self.q_size, self.kv_size, self.kv_size], -1
+                result_qkv, [self.q_size, self.kv_size, self.kv_size], -1
             )
             return (
                 operator.getitem(split_tuple, 0),
                 operator.getitem(split_tuple, 1),
                 operator.getitem(split_tuple, 2),
             )
 
-        # Sample inputs to help pattern tracing (sizes don't have to be exact at runtime)
+        # Sample inputs to help pattern tracing
         T = 5
         qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
         positions = empty_i64(T)
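The replacement above keeps the same output interface as the pattern: the fused op mutates qkv (exposed as result[1] by auto_functionalized), and the result is split back into q, k, v along the last dimension. A tiny sketch of that split with made-up sizes, using torch.split in place of the graph-level SPLIT_SIZES_OP:

import torch

# Hypothetical per-shard sizes; the pass uses the layer's q_size and kv_size.
q_size, kv_size, T = 512, 128, 5
qkv = torch.randn(T, q_size + 2 * kv_size)
q, k, v = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1)
assert q.shape == (T, q_size) and k.shape == (T, kv_size) and v.shape == (T, kv_size)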
@@ -229,9 +233,7 @@ def __init__(self, config: VllmConfig):
         )
 
         if not current_platform.is_cuda_alike():
-            logger.debug(
-                "QK Norm+RoPE fusion not enabled: unsupported platform"
-            )
+            logger.debug("QK Norm+RoPE fusion not enabled: unsupported platform")
             return
 
         # Register a pattern per attention layer, as sizes differ by shard
@@ -255,7 +257,8 @@ def __init__(self, config: VllmConfig):
                 ).register(self.patterns)
             except Exception as e:
                 logger.debug(
-                    "Skipping QkNormRopePattern registration with eps=%s is_neox=%s: %s",
+                    "Skipping QkNormRopePattern register with eps=%s "
+                    "is_neox=%s: %s",
                     epsilon,
                     neox,
                     e,