Commit 586abff

refactor util registry
Signed-off-by: Icey <1790571317@qq.com>
1 parent f48b2b5 commit 586abff

2 files changed: +38 −46 lines

tests/ut/test_utils.py (5 additions, 4 deletions)

@@ -24,6 +24,7 @@
 from tests.ut.base import TestBase
 from vllm_ascend import utils
+from vllm_ascend.utils import REGISTERED_ASCEND_OPS


 class TestUtils(TestBase):
@@ -302,14 +303,14 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,

         # ascend custom op is not registered
         utils.register_ascend_customop()
-        # should call register_oot three
-        self.assertEqual(mock_customop.register_oot.call_count, 14)
+        self.assertEqual(mock_customop.register_oot.call_count,
+                         len(REGISTERED_ASCEND_OPS))
         self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

         # ascend custom op is already registered
         utils.register_ascend_customop()
-        # should not register_oot again, thus only called three in this ut
-        self.assertEqual(mock_customop.register_oot.call_count, 14)
+        self.assertEqual(mock_customop.register_oot.call_count,
+                         len(REGISTERED_ASCEND_OPS))


 class TestProfileExecuteDuration(TestBase):
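
Tying the assertion to len(REGISTERED_ASCEND_OPS) instead of the hard-coded 14 means the test no longer breaks each time an op is added to the registry. A minimal standalone sketch of the same mock-based check, using a hypothetical two-entry registry in place of the real vllm_ascend code:

    from unittest import TestCase, main
    from unittest.mock import MagicMock

    # Hypothetical stand-in for vllm_ascend.utils.REGISTERED_ASCEND_OPS.
    REGISTERED_OPS = {"RMSNorm": object, "SiluAndMul": object}

    def register_all(custom_op):
        # Mirrors register_ascend_customop(): one register_oot call per entry.
        for name, op_cls in REGISTERED_OPS.items():
            custom_op.register_oot(_decorated_op_cls=op_cls, name=name)

    class TestRegistryCount(TestCase):
        def test_call_count_tracks_registry(self):
            mock_customop = MagicMock()
            register_all(mock_customop)
            # len() keeps the expected count in sync with the registry,
            # so growing the dict never requires touching the test.
            self.assertEqual(mock_customop.register_oot.call_count,
                             len(REGISTERED_OPS))

    if __name__ == "__main__":
        main()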

vllm_ascend/utils.py (33 additions, 42 deletions)

@@ -479,6 +479,37 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
     return False


+from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
+from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.layernorm import AscendRMSNorm
+from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
+                                    AscendMergedColumnParallelLinear,
+                                    AscendQKVParallelLinear,
+                                    AscendRowParallelLinear)
+from vllm_ascend.ops.rotary_embedding import (
+    AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
+
+REGISTERED_ASCEND_OPS = {
+    "QuickGELU": AscendQuickGELU,
+    "SiluAndMul": AscendSiluAndMul,
+    "RotaryEmbedding": AscendRotaryEmbedding,
+    "ColumnParallelLinear": AscendColumnParallelLinear,
+    "RowParallelLinear": AscendRowParallelLinear,
+    "MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
+    "QKVParallelLinear": AscendQKVParallelLinear,
+    "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
+    "VocabParallelEmbedding": AscendVocabParallelEmbedding,
+    "ParallelLMHead": AscendParallelLMHead,
+    "LogitsProcessor": AscendLogitsProcessor,
+    "RMSNorm": AscendRMSNorm,
+    "FusedMoE": AscendFusedMoE,
+    "MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
+}
+
+
 def register_ascend_customop():
     """Register Ascend CustomOP

@@ -490,48 +521,8 @@ def register_ascend_customop():
         return
     from vllm.model_executor.custom_op import CustomOp

-    from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-    from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
-                                        AscendMergedColumnParallelLinear,
-                                        AscendQKVParallelLinear,
-                                        AscendRowParallelLinear)
-    from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
-    from vllm_ascend.ops.vocab_parallel_embedding import (
-        AscendLogitsProcessor, AscendParallelLMHead,
-        AscendVocabParallelEmbedding)
-    CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
-    CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
-                          name="SiluAndMul")
-    CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding,
-                          name="RotaryEmbedding")
-    CustomOp.register_oot(_decorated_op_cls=AscendColumnParallelLinear,
-                          name="ColumnParallelLinear")
-    CustomOp.register_oot(_decorated_op_cls=AscendRowParallelLinear,
-                          name="RowParallelLinear")
-    CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
-                          name="MergedColumnParallelLinear")
-    CustomOp.register_oot(_decorated_op_cls=AscendQKVParallelLinear,
-                          name="QKVParallelLinear")
-    CustomOp.register_oot(
-        _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
-        name="DeepseekScalingRotaryEmbedding")
-    CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
-                          name="VocabParallelEmbedding")
-    CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead,
-                          name="ParallelLMHead")
-    CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor,
-                          name="LogitsProcessor")
-
-    from vllm_ascend.ops.layernorm import AscendRMSNorm
-    CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
-
-    from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
-    CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
-
-    from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
-    CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention,
-                          name="MultiHeadLatentAttention")
+    for name, op_cls in REGISTERED_ASCEND_OPS.items():
+        CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)

     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
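
The net effect: the op-name-to-class mapping now lives in one module-level dict, and registration collapses to a loop, so adding an Ascend op is a single dict entry. A minimal self-contained sketch of this dict-driven, idempotent registration pattern, with hypothetical op classes and a stub standing in for vLLM's CustomOp.register_oot:

    # Sketch only: stand-in classes and a stub registry, not the real vLLM API.
    class AscendRMSNorm: ...
    class AscendSiluAndMul: ...

    REGISTERED_OPS = {
        "RMSNorm": AscendRMSNorm,
        "SiluAndMul": AscendSiluAndMul,
    }

    _OOT_REGISTRY: dict[str, type] = {}  # what register_oot would populate
    _IS_REGISTERED = False

    def register_oot(_decorated_op_cls, name):
        # Stub: record the out-of-tree override under its vLLM op name.
        _OOT_REGISTRY[name] = _decorated_op_cls

    def register_customops():
        global _IS_REGISTERED
        if _IS_REGISTERED:        # idempotent: a second call is a no-op,
            return                # matching the guard in the real function
        for name, op_cls in REGISTERED_OPS.items():
            register_oot(_decorated_op_cls=op_cls, name=name)
        _IS_REGISTERED = True     # set last, once every op is registered

    register_customops()
    register_customops()          # no-op; nothing is registered twice
    assert len(_OOT_REGISTRY) == len(REGISTERED_OPS)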
