
Commit 646c1db

shaopeng-666 and shaopeng666 authored
Add mrope op fusion (#3509)
### What this PR does / why we need it?

Add an mrope fusion op for Qwen2.5-VL. The fused mrope operator does not support Qwen3-VL yet, so it only takes effect for Qwen2.5-VL.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
1 parent 0777e2f commit 646c1db
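For context on what the fused kernel replaces: with multimodal RoPE, each channel group of the rotary dimension takes its rotation angles from one of the three position components (temporal, height, width). Below is a minimal sketch, assuming vLLM's usual `cos_sin_cache` layout (cos in the first half of the last dimension, sin in the second half); the helper name is illustrative and not part of this change:

```python
# Sketch: assembling per-token cos/sin for mrope from 3-D (t/h/w) positions.
# Assumes cos_sin_cache has shape (max_position, rotary_dim) with cos | sin halves,
# positions has shape (3, num_tokens), and sum(mrope_section) == rotary_dim // 2.
import torch


def build_mrope_cos_sin(cos_sin_cache: torch.Tensor, positions: torch.Tensor,
                        mrope_section: list):
    cos_sin = cos_sin_cache[positions]   # (3, num_tokens, rotary_dim)
    cos, sin = cos_sin.chunk(2, dim=-1)  # each (3, num_tokens, rotary_dim // 2)
    # Channel group i takes its angles from position component i
    # (temporal, height, width); the groups are then concatenated back together.
    cos = torch.cat(
        [c[i] for i, c in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
    sin = torch.cat(
        [s[i] for i, s in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
    return cos, sin                      # each (num_tokens, rotary_dim // 2)
```

The `torch_npu.npu_mrope` call added in this PR fuses this gather/section step together with the rotation itself into a single NPU kernel.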

File tree

3 files changed: +123 -4 lines

- tests/ut/ops/test_rotary_embedding.py
- vllm_ascend/ops/rotary_embedding.py
- vllm_ascend/utils.py

tests/ut/ops/test_rotary_embedding.py

Lines changed: 85 additions & 1 deletion
```diff
@@ -6,13 +6,14 @@
 from transformers.configuration_utils import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
 
 MODEL = "Qwen3-0.6B"
+MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
 MAX_NUM_BATCHED_TOKEND = 10000
 
 
@@ -376,3 +377,86 @@ def test_yarn_get_mscale(self, mock_npuplatform):
                 expected,
                 places=6,
                 msg=f"Failed for scale={scale}, mscale={mscale}")
+
+
+class TestAscendMRotaryEmbedding(unittest.TestCase):
+
+    def setUp(self):
+        # Common setup for tests
+        self.number_tokens = 3
+        self.num_head = 8
+        self.num_kvhead = 8
+        self.head_size = 128
+        self.max_position_embeddings = 128000
+        self.is_neox_style = True
+        self.rope_theta = 1000000.0
+        self.positions_1d = torch.tensor([1, 2, 3])
+        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
+
+        self.query = torch.randn(
+            (self.number_tokens, self.num_head * self.head_size),
+            dtype=torch.bfloat16)
+        self.key = torch.randn(
+            (self.number_tokens, self.num_kvhead * self.head_size),
+            dtype=torch.bfloat16)
+
+        # Qwen2.5-VL mrope section case
+        self.mrope_section = [16, 24, 24]
+
+        self.layer = MRotaryEmbedding(self.head_size,
+                                      self.head_size,
+                                      self.max_position_embeddings,
+                                      base=self.rope_theta,
+                                      is_neox_style=self.is_neox_style,
+                                      dtype=torch.bfloat16,
+                                      mrope_section=self.mrope_section)
+
+        self.mock_config = MagicMock()
+        self.mock_config.torchair_graph_config.enabled = False
+
+    def _create_vllm_config(self):
+        vllm_config = VllmConfig()
+        model_config = ModelConfig(MODEL_VL,
+                                   tokenizer=MODEL_VL,
+                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
+        model_config.hf_config = PretrainedConfig()
+        vllm_config.model_config = model_config
+        return vllm_config
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_1d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_1d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_2d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_2d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
```
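The new tests mock `torch_npu.npu_mrope`, so they run without NPU hardware. A sketch of running only the new class via the standard pytest API:

```python
# Run only the new Ascend mrope tests (plain pytest invocation, nothing vllm-specific).
import pytest

pytest.main([
    "tests/ut/ops/test_rotary_embedding.py::TestAscendMRotaryEmbedding", "-v"
])
```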

vllm_ascend/ops/rotary_embedding.py

Lines changed: 35 additions & 1 deletion
```diff
@@ -22,7 +22,7 @@
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
@@ -395,3 +395,37 @@ def forward(self,
         q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                        is_neox_style, offsets)
         return q_pe, k_pe
+
+
+class AscendMRotaryEmbedding(MRotaryEmbedding):
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ):
+        if self.mrope_section != [16, 24, 24]:
+            return super().forward_oot(positions, query, key)
+
+        import torch_npu
+        mrope_section = [0, 0, 0
+                         ] if positions.ndim == 1 else self.mrope_section
+
+        if self.cos_sin_cache.device != query.device:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.device)  # type: ignore
+
+        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.dtype)  # type: ignore
+
+        query, key = torch_npu.npu_mrope(positions,
+                                         query.contiguous(),
+                                         key.contiguous(),
+                                         self.cos_sin_cache.contiguous(),
+                                         self.head_size,
+                                         mrope_section=mrope_section,
+                                         rotary_mode='half')
+
+        return query, key
```
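For quick reference, a condensed sketch of the dispatch above. The `npu_mrope` arguments mirror the diff verbatim; the expected shapes (taken from the unit tests) and the interpretation of the all-zero section are assumptions, and the wrapper itself is illustrative:

```python
import torch
import torch_npu  # requires an Ascend NPU environment


def fused_mrope(positions: torch.Tensor,   # (num_tokens,) or (3, num_tokens)
                query: torch.Tensor,       # (num_tokens, num_heads * head_size)
                key: torch.Tensor,         # (num_tokens, num_kv_heads * head_size)
                cos_sin_cache: torch.Tensor,
                head_size: int,
                mrope_section: list):
    # 1-D positions indicate a text-only batch; the guard in forward_oot passes an
    # all-zero section in that case, which appears to select the plain-RoPE path.
    section = [0, 0, 0] if positions.ndim == 1 else mrope_section
    return torch_npu.npu_mrope(positions,
                               query.contiguous(),
                               key.contiguous(),
                               cos_sin_cache.contiguous(),
                               head_size,
                               mrope_section=section,
                               rotary_mode='half')
```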

vllm_ascend/utils.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -517,8 +517,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                          AscendReplicatedLinear,
                                          AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
-        AscendYaRNRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
+        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -528,6 +528,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "QuickGELU": AscendQuickGELU,
         "SiluAndMul": AscendSiluAndMul,
         "RotaryEmbedding": AscendRotaryEmbedding,
+        "MRotaryEmbedding": AscendMRotaryEmbedding,
         "ColumnParallelLinear": AscendColumnParallelLinear,
         "RowParallelLinear": AscendRowParallelLinear,
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
```

Comments (0)