Skip to content

Commit 34264a1

Browse files
committed
1
1 parent bbf2634 commit 34264a1

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

vllm_ascend/envs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@
133133
# value to False to disable the optimized model.
134134
"USE_OPTIMIZED_MODEL":
135135
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
136+
# VLLM_ASCEND_ENABLE_NPU_MROPE:
137+
# 0: using npu_rotary_embedding.
138+
# 1: using npu_mrope.
139+
# Just a temporary plan; it will be removed once npu_mrope supports aclgraph mode.
140+
"VLLM_ASCEND_ENABLE_NPU_MROPE":
141+
lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_NPU_MROPE', '0'))),
136142
}
137143

138144
# end-env-vars-definition

vllm_ascend/ops/rotary_embedding.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,52 @@ def rope_forward_oot(
3636
key: torch.Tensor,
3737
offsets: Optional[torch.Tensor] = None,
3838
is_neox_style_override: Optional[bool] = None
39+
) -> Tuple[torch.Tensor, torch.Tensor]:
40+
import torch_npu
41+
query_shape, key_shape = query.shape, key.shape
42+
if self.cos_sin_cache.device != query.device:
43+
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
44+
if self.cos_sin_cache.dtype != query.dtype:
45+
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
46+
neox_style = self.is_neox_style
47+
if is_neox_style_override is not None:
48+
neox_style = is_neox_style_override
49+
# adopt custom kernel path for rotary_embedding
50+
if custom_rotary_embedding_enabled(query, neox_style, self.head_size):
51+
query, key = torch.ops._C.rotary_embedding(
52+
positions,
53+
query,
54+
key,
55+
self.head_size,
56+
self.cos_sin_cache,
57+
neox_style,
58+
)
59+
return query.view(query_shape), key.view(key_shape)
60+
if offsets is not None:
61+
raise NotImplementedError(
62+
"Batched rotary embedding is currently not supported on NPU.")
63+
else:
64+
# TODO: Remove the contiguous() calls in the future.
65+
query = query.contiguous().view(query.shape[0], -1)
66+
key = key.contiguous().view(key.shape[0], -1)
67+
torch_npu._npu_rotary_embedding(
68+
positions,
69+
query,
70+
key,
71+
self.head_size,
72+
self.cos_sin_cache,
73+
neox_style,
74+
)
75+
return query.view(query_shape), key.view(key_shape)
76+
77+
78+
def rope_forward_oot_npu_mrope(
79+
self,
80+
positions: torch.Tensor,
81+
query: torch.Tensor,
82+
key: torch.Tensor,
83+
offsets: Optional[torch.Tensor] = None,
84+
is_neox_style_override: Optional[bool] = None
3985
) -> Tuple[torch.Tensor, torch.Tensor]:
4086
import torch_npu
4187
query_shape, key_shape = query.shape, key.shape
@@ -269,8 +315,10 @@ def deepseek_rope_init_func(
269315
dtype=dtype,
270316
device="npu")
271317

272-
273-
RotaryEmbedding.forward_oot = rope_forward_oot
318+
if vllm_ascend.envs.VLLM_ASCEND_ENABLE_NPU_MROPE:
319+
RotaryEmbedding.forward_oot = rope_forward_oot_npu_mrope
320+
else:
321+
RotaryEmbedding.forward_oot = rope_forward_oot
274322

275323
# Note: we adopt the native huggingface deepseek rope initialization code from
276324
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for

0 commit comments

Comments
 (0)