File tree Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -77,7 +77,11 @@ def dispatch_rotary_emb_function(
7777 if current_platform .is_cuda ():
7878 return apply_rotary_emb
7979
80- if current_platform .is_rocm ():
80+ # When not running under torch.compile, use the rotary embedding
81+ # function from the flash_attn package. Otherwise, fall through to
82+ # the naive PyTorch rotary embedding implementation, which is faster
83+ # when torch.compile is enabled.
84+ if current_platform .is_rocm () and not torch .compiler .is_compiling ():
8185 if find_spec ("flash_attn" ) is not None :
8286 from flash_attn .ops .triton .rotary import apply_rotary
8387
@@ -87,11 +91,10 @@ def dispatch_rotary_emb_function(
8791 "flash_attn is not installed. Falling back to PyTorch "
8892 "implementation for rotary embeddings."
8993 )
90-
9194 if default is not None :
9295 return default
93- else :
94- return apply_rotary_emb_torch
96+
97+ return apply_rotary_emb_torch
9598
9699
97100# yarn functions
Original file line number Diff line number Diff line change @@ -370,7 +370,7 @@ def forward(
370370 cu_seqlens_k = cu_seqlens ,
371371 max_seqlen_q = max_seqlen ,
372372 max_seqlen_k = max_seqlen ,
373- dropout_p = 0 ,
373+ dropout_p = 0.0 ,
374374 causal = False ,
375375 )
376376
You can’t perform that action at this time.
0 commit comments