@@ -36,6 +36,52 @@ def rope_forward_oot(
36
36
key : torch .Tensor ,
37
37
offsets : Optional [torch .Tensor ] = None ,
38
38
is_neox_style_override : Optional [bool ] = None
39
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
40
+ import torch_npu
41
+ query_shape , key_shape = query .shape , key .shape
42
+ if self .cos_sin_cache .device != query .device :
43
+ self .cos_sin_cache = self .cos_sin_cache .to (query .device )
44
+ if self .cos_sin_cache .dtype != query .dtype :
45
+ self .cos_sin_cache = self .cos_sin_cache .to (query .dtype )
46
+ neox_style = self .is_neox_style
47
+ if is_neox_style_override is not None :
48
+ neox_style = is_neox_style_override
49
+ # adopt custom kernel path for rotary_embedding
50
+ if custom_rotary_embedding_enabled (query , neox_style , self .head_size ):
51
+ query , key = torch .ops ._C .rotary_embedding (
52
+ positions ,
53
+ query ,
54
+ key ,
55
+ self .head_size ,
56
+ self .cos_sin_cache ,
57
+ neox_style ,
58
+ )
59
+ return query .view (query_shape ), key .view (key_shape )
60
+ if offsets is not None :
61
+ raise NotImplementedError (
62
+ "Batched rotary embedding is currently not supported on NPU." )
63
+ else :
64
+ # TODO: Remove the contiguous in the future.
65
+ query = query .contiguous ().view (query .shape [0 ], - 1 )
66
+ key = key .contiguous ().view (key .shape [0 ], - 1 )
67
+ torch_npu ._npu_rotary_embedding (
68
+ positions ,
69
+ query ,
70
+ key ,
71
+ self .head_size ,
72
+ self .cos_sin_cache ,
73
+ neox_style ,
74
+ )
75
+ return query .view (query_shape ), key .view (key_shape )
76
+
77
+
78
+ def rope_forward_oot_npu_mrope (
79
+ self ,
80
+ positions : torch .Tensor ,
81
+ query : torch .Tensor ,
82
+ key : torch .Tensor ,
83
+ offsets : Optional [torch .Tensor ] = None ,
84
+ is_neox_style_override : Optional [bool ] = None
39
85
) -> Tuple [torch .Tensor , torch .Tensor ]:
40
86
import torch_npu
41
87
query_shape , key_shape = query .shape , key .shape
@@ -269,8 +315,10 @@ def deepseek_rope_init_func(
269
315
dtype = dtype ,
270
316
device = "npu" )
271
317
272
-
273
- RotaryEmbedding .forward_oot = rope_forward_oot
318
+ if vllm_ascend .envs .VLLM_ASCEND_ENABLE_NPU_MROPE :
319
+ RotaryEmbedding .forward_oot = rope_forward_oot_npu_mrope
320
+ else :
321
+ RotaryEmbedding .forward_oot = rope_forward_oot
274
322
275
323
# Note: we adopt the native huggingface deepseek rope initialization code from
276
324
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
0 commit comments