@@ -479,6 +479,37 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
     return False
 
 
+from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
+from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.layernorm import AscendRMSNorm
+from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
+                                    AscendMergedColumnParallelLinear,
+                                    AscendQKVParallelLinear,
+                                    AscendRowParallelLinear)
+from vllm_ascend.ops.rotary_embedding import (
+    AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
+from vllm_ascend.ops.vocab_parallel_embedding import (
+    AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
+
+REGISTERED_ASCEND_OPS = {
+    "QuickGELU": AscendQuickGELU,
+    "SiluAndMul": AscendSiluAndMul,
+    "RotaryEmbedding": AscendRotaryEmbedding,
+    "ColumnParallelLinear": AscendColumnParallelLinear,
+    "RowParallelLinear": AscendRowParallelLinear,
+    "MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
+    "QKVParallelLinear": AscendQKVParallelLinear,
+    "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
+    "VocabParallelEmbedding": AscendVocabParallelEmbedding,
+    "ParallelLMHead": AscendParallelLMHead,
+    "LogitsProcessor": AscendLogitsProcessor,
+    "RMSNorm": AscendRMSNorm,
+    "FusedMoE": AscendFusedMoE,
+    "MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
+}
+
+
 def register_ascend_customop():
     """Register Ascend CustomOP
@@ -490,48 +521,8 @@ def register_ascend_customop():
         return
     from vllm.model_executor.custom_op import CustomOp
 
-    from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-    from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
-                                        AscendMergedColumnParallelLinear,
-                                        AscendQKVParallelLinear,
-                                        AscendRowParallelLinear)
-    from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
-    from vllm_ascend.ops.vocab_parallel_embedding import (
-        AscendLogitsProcessor, AscendParallelLMHead,
-        AscendVocabParallelEmbedding)
-    CustomOp.register_oot(_decorated_op_cls=AscendQuickGELU, name="QuickGELU")
-    CustomOp.register_oot(_decorated_op_cls=AscendSiluAndMul,
-                          name="SiluAndMul")
-    CustomOp.register_oot(_decorated_op_cls=AscendRotaryEmbedding,
-                          name="RotaryEmbedding")
-    CustomOp.register_oot(_decorated_op_cls=AscendColumnParallelLinear,
-                          name="ColumnParallelLinear")
-    CustomOp.register_oot(_decorated_op_cls=AscendRowParallelLinear,
-                          name="RowParallelLinear")
-    CustomOp.register_oot(_decorated_op_cls=AscendMergedColumnParallelLinear,
-                          name="MergedColumnParallelLinear")
-    CustomOp.register_oot(_decorated_op_cls=AscendQKVParallelLinear,
-                          name="QKVParallelLinear")
-    CustomOp.register_oot(
-        _decorated_op_cls=AscendDeepseekScalingRotaryEmbedding,
-        name="DeepseekScalingRotaryEmbedding")
-    CustomOp.register_oot(_decorated_op_cls=AscendVocabParallelEmbedding,
-                          name="VocabParallelEmbedding")
-    CustomOp.register_oot(_decorated_op_cls=AscendParallelLMHead,
-                          name="ParallelLMHead")
-    CustomOp.register_oot(_decorated_op_cls=AscendLogitsProcessor,
-                          name="LogitsProcessor")
-
-    from vllm_ascend.ops.layernorm import AscendRMSNorm
-    CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
-
-    from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
-    CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
-
-    from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
-    CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention,
-                          name="MultiHeadLatentAttention")
+    for name, op_cls in REGISTERED_ASCEND_OPS.items():
+        CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
 
 
 # NOTE: Keep this at last to ensure all custom actions are registered
 _ASCEND_CUSTOMOP_IS_REIGISTERED = True
0 commit comments