# Ops used in the pattern (assume kernels are built and available)
RMS_OP = torch.ops._C.rms_norm.default
-CONTIGUOUS_OP = torch.ops.aten.contiguous.default
-# Some graphs canonicalize `.contiguous()` into a clone with memory_format
-CLONE_OP = torch.ops.aten.clone.default
ROPE_OPS: list[torch._ops.OpOverload] = [
    torch.ops._C.rotary_embedding.default,
-    torch.ops.vllm.flashinfer_rotary_embedding.default,
+    # torch.ops.vllm.flashinfer_rotary_embedding.default,
]
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
SPLIT_SIZES_OP = torch.ops.aten.split_with_sizes.default
RESHAPE_OP = torch.ops.aten.reshape.default
EMPTY_LIKE_OP = torch.ops.aten.empty_like.default
+VIEW_OP = torch.ops.aten.view.default
+CONTIGUOUS_OP = torch.ops.aten.contiguous.default
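+# Calling these overloads directly keeps the pattern trace pinned to exactly these
+# nodes instead of whatever the equivalent Tensor methods would decompose to.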


class QkNormRopePattern:
@@ -85,18 +84,22 @@ def pattern(
        cos_sin_cache: torch.Tensor,
    ):
        # split qkv -> q,k,v
-        # split_tuple = SPLIT_SIZES_OP(
-        #     qkv, [self.q_size, self.kv_size, self.kv_size], -1
-        # )
-        # q = operator.getitem(split_tuple, 0)
-        # k = operator.getitem(split_tuple, 1)
-        # v = operator.getitem(split_tuple, 2)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        split_tuple = SPLIT_SIZES_OP(
+            qkv, [self.q_size, self.kv_size, self.kv_size], -1
+        )
+        q = operator.getitem(split_tuple, 0)
+        k = operator.getitem(split_tuple, 1)
+        v = operator.getitem(split_tuple, 2)
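+        # Tensor.split decomposes to aten.split_with_sizes plus getitem nodes in the
+        # traced graph, so the pattern spells that decomposition out explicitly.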

        # Q path: view -> (optional contiguous) -> RMS -> view back to q.shape
-        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
-        q_out = torch.empty_like(q_by_head)
-        q_by_head_contiguous = q_by_head.contiguous()
+        # q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        # q_out = torch.empty_like(q_by_head)
+        # q_by_head_contiguous = q_by_head.contiguous()
+        q_by_head = VIEW_OP(q, (*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim))
+        q_out = EMPTY_LIKE_OP(q_by_head)
+        q_by_head_contiguous = CONTIGUOUS_OP(q_by_head)
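+        # rms_norm writes into `result` in place, so in the graph it appears wrapped in
+        # auto_functionalized; element 1 of the returned tuple is the normalized tensor.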
+
        qn = auto_functionalized(
            RMS_OP,
            result=q_out,
@@ -105,14 +108,17 @@ def pattern(
            epsilon=self.eps,
        )
        q_normed_by_head = qn[1]
-        # RMS_OP(result=q_out, input=q_by_head_contiguous, weight=q_weight, epsilon=self.eps)
-        # q_normed_by_head = q_out
-        q_flat = q_normed_by_head.view(q.shape)
+
+        # q_flat = q_normed_by_head.view(q.shape)
+        q_flat = VIEW_OP(q_normed_by_head, q.shape)

        # K path: view -> (optional contiguous) -> RMS -> view back to k.shape
-        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
-        k_out = torch.empty_like(k_by_head)
-        k_by_head_contiguous = k_by_head.contiguous()
+        # k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        # k_out = torch.empty_like(k_by_head)
+        # k_by_head_contiguous = k_by_head.contiguous()
+        k_by_head = VIEW_OP(k, (*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim))
+        k_out = EMPTY_LIKE_OP(k_by_head)
+        k_by_head_contiguous = CONTIGUOUS_OP(k_by_head)
        kn = auto_functionalized(
            RMS_OP,
            result=k_out,
@@ -121,9 +127,9 @@ def pattern(
            epsilon=self.eps,
        )
        k_normed_by_head = kn[1]
-        # RMS_OP(result=k_out, input=k_by_head_contiguous, weight=k_weight, epsilon=self.eps)
-        # k_normed_by_head = k_out
-        k_flat = k_normed_by_head.view(k.shape)
+
+        # k_flat = k_normed_by_head.view(k.shape)
+        k_flat = VIEW_OP(k_normed_by_head, k.shape)

        # RoPE: apply to flattened q/k
        rope = auto_functionalized(
@@ -136,15 +142,6 @@ def pattern(
            is_neox=self.is_neox,
        )
        return rope[1], rope[2], v
-        # self.rope_op(
-        #     positions=positions,
-        #     query=q_flat,
-        #     key=k_flat,
-        #     head_size=self.head_dim,
-        #     cos_sin_cache=cos_sin_cache,
-        #     is_neox=self.is_neox
-        # )
-        # return q_flat, k_flat, v


    def replacement(
@@ -244,36 +241,25 @@ def __init__(self, config: VllmConfig):
                "QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
            )
            return
-
        layer_name, layer = next(iter(attn_layers.items()))

-        # Derive parameters from the layer to avoid combinatorial loops
-        eps = getattr(getattr(layer, "q_norm", None), "variance_epsilon", None)
-        if not isinstance(eps, float):
-            eps = 1e-6  # fallback default
-
-        rope_mod = getattr(layer, "rotary_emb", None)
-        use_flashinfer = getattr(rope_mod, "use_flashinfer", False)
-        rope_op = (
-            torch.ops.vllm.flashinfer_rotary_embedding.default
-            if use_flashinfer
-            else torch.ops._C.rotary_embedding.default
-        )
-        is_neox = getattr(rope_mod, "is_neox_style", True)
-
-        try:
-            QkNormRopePattern(
-                layer,
-                eps=eps,
-                rope_op=rope_op,
-                is_neox=is_neox,
-            ).register(self.patterns)
-        except Exception as e:
-            logger.debug(
-                "Skipping pattern registration for layer %s: %s",
-                layer_name,
-                e,
-            )
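+        # eps, neox style and rope op are constants baked into the traced pattern, so
+        # register one variant per combination; variants that never match simply don't fire.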
+        for epsilon in [1e-5, 1e-6]:
+            for neox in [True, False]:
+                for rope_op in ROPE_OPS:
+                    try:
+                        QkNormRopePattern(
+                            layer=layer,
+                            eps=epsilon,
+                            rope_op=rope_op,
+                            is_neox=neox,
+                        ).register(self.patterns)
+                    except Exception as e:
+                        logger.debug(
+                            "Skipping QkNormRopePattern registration with eps=%s is_neox=%s: %s",
+                            epsilon,
+                            neox,
+                            e,
+                        )

        # Dump patterns for debugging if enabled
        self.dump_patterns(config, self.patterns)