|
22 | 22 | from vllm.model_executor.layers.rotary_embedding import (
|
23 | 23 | DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
24 | 24 |
|
| 25 | +import vllm_ascend.envs as envs |
25 | 26 | from vllm_ascend.platform import CUSTOM_OP_ENABLED
|
26 | 27 |
|
27 | 28 |
|
@@ -75,6 +76,52 @@ def rope_forward_oot(
|
75 | 76 | return query.view(query_shape), key.view(key_shape)
|
76 | 77 |
|
77 | 78 |
|
def rope_forward_oot_npu_mrope(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    is_neox_style_override: Optional[bool] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding to ``query`` and ``key`` using ``npu_mrope``.

    The cos/sin cache stored on ``self`` is first moved to the device and
    dtype of ``query``. If ``custom_rotary_embedding_enabled`` accepts the
    input, the custom ``torch.ops._C.rotary_embedding`` op is used and the
    function returns early; otherwise ``torch_npu.npu_mrope`` is called on
    the tensors flattened to two dimensions.

    Args:
        positions: Forwarded unchanged to the rotary kernel.
        query: Query tensor; the result is viewed back to its input shape.
        key: Key tensor; the result is viewed back to its input shape.
        offsets: Must be ``None`` when the ``npu_mrope`` path is taken.
        is_neox_style_override: When not ``None``, used instead of
            ``self.is_neox_style`` for this call.

    Returns:
        The rotated ``(query, key)`` pair, each viewed to its original shape.

    Raises:
        NotImplementedError: If ``offsets`` is given and the custom kernel
            path was not taken.
    """
    import torch_npu

    original_q_shape = query.shape
    original_k_shape = key.shape

    # Keep the cached cos/sin table on the same device and in the same dtype
    # as the query before handing it to either kernel.
    cache = self.cos_sin_cache
    if cache.device != query.device:
        cache = cache.to(query.device)
        self.cos_sin_cache = cache
    if cache.dtype != query.dtype:
        cache = cache.to(query.dtype)
        self.cos_sin_cache = cache

    default_style = self.is_neox_style
    use_neox = (default_style
                if is_neox_style_override is None else is_neox_style_override)

    # adopt custom kernel path for rotary_embedding
    if custom_rotary_embedding_enabled(query, use_neox, self.head_size):
        rotated_q, rotated_k = torch.ops._C.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            use_neox,
        )
        return rotated_q.view(original_q_shape), rotated_k.view(
            original_k_shape)

    # NOTE(review): this check is only reached when the custom kernel was not
    # used, so ``offsets`` is not inspected on the custom-kernel path above.
    if offsets is not None:
        raise NotImplementedError(
            "Batched rotary embedding is currently not supported on NPU.")

    # TODO: Remove the contiguous in the future.
    flat_q = query.contiguous().view(query.shape[0], -1)
    flat_k = key.contiguous().view(key.shape[0], -1)
    mode = "half" if use_neox else "interleave"
    rotated_q, rotated_k = torch_npu.npu_mrope(
        positions,
        flat_q,
        flat_k,
        self.cos_sin_cache,
        self.head_size,
        mrope_section=[0, 0, 0],
        rotary_mode=mode)
    return rotated_q.view(original_q_shape), rotated_k.view(original_k_shape)
| 123 | + |
| 124 | + |
78 | 125 | def native_rope_deepseek_forward(self,
|
79 | 126 | positions: torch.Tensor,
|
80 | 127 | query: torch.Tensor,
|
@@ -270,7 +317,10 @@ def deepseek_rope_init_func(
|
270 | 317 | device="npu")
|
271 | 318 |
|
272 | 319 |
|
273 |
| -RotaryEmbedding.forward_oot = rope_forward_oot |
# Pick the out-of-tree RotaryEmbedding forward once, at import time: the
# npu_mrope-based variant when VLLM_ASCEND_ENABLE_NPU_MROPE is truthy,
# otherwise the existing rope_forward_oot.
RotaryEmbedding.forward_oot = (rope_forward_oot_npu_mrope
                               if envs.VLLM_ASCEND_ENABLE_NPU_MROPE else
                               rope_forward_oot)
274 | 324 |
|
275 | 325 | # Note: we adopt the native huggingface deepseek rope initialization code from
|
276 | 326 | # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
|
|
0 commit comments