 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
 
-from vllm_ascend.utils import vllm_version_is
 
-if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-    from vllm.model_executor.layers.fused_moe.layer import (
-        FusedMoEParallelConfig, MoEConfig)
-else:
-    MoEConfig = None
+
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoEParallelConfig, MoEConfig)
+
 
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -587,10 +585,9 @@ def select_experts(
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def __init__(self, moe: MoEConfig = None):
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            super().__init__()
-        else:
-            super().__init__(moe=moe)
+
+
+        super().__init__(moe=moe)
 
         vllm_config = get_current_vllm_config()
 
         ep_group = get_ep_group()
@@ -731,23 +728,16 @@ def __init__(
             params_dtype = torch.get_default_dtype()
 
         vllm_config = get_current_vllm_config()
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            self.ep_size = get_ep_group().world_size
-            self.tp_size = get_etp_group().world_size
-            self.dp_size = (dp_size if dp_size is not None else
-                            get_dp_group().world_size)
-            self.dp_rank = (0 if self.dp_size == 1 else
-                            get_dp_group().rank_in_group)
-        else:
-            self.moe_parallel_config: FusedMoEParallelConfig = (
-                FusedMoEParallelConfig.make(
-                    tp_size_=(tp_size if tp_size is not None else
-                              get_tensor_model_parallel_world_size()),
-                    dp_size_=(dp_size if dp_size is not None else
-                              get_dp_group().world_size),
-                    vllm_parallel_config=vllm_config.parallel_config))
 
-            self.moe_parallel_config.ep_size = get_ep_group().world_size
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))
+
+        self.moe_parallel_config.ep_size = get_ep_group().world_size
 
         self.top_k = top_k
         self.num_experts = num_experts
@@ -772,54 +762,40 @@ def __init__(
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = get_etp_group().rank_in_group
-                self.ep_rank = get_ep_group().rank_in_group
-            else:
-                self.moe_parallel_config.tp_rank = get_etp_group(
-                ).rank_in_group
-                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
+            self.moe_parallel_config.tp_rank = get_etp_group(
+            ).rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = self.tp_size * self.dp_rank
-                self.ep_rank = 0
-                self.tp_size = self.tp_size * self.dp_size
-                self.ep_size = 1
-            else:
-                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-                self.moe_parallel_config.ep_rank = 0
-                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-                self.moe_parallel_config.ep_size = 1
+
+            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
+            self.moe_parallel_config.ep_rank = 0
+            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
+            self.moe_parallel_config.ep_size = 1
 
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            if quant_config is None:
-                self.quant_method: Optional[QuantizeMethodBase] = (
-                    AscendUnquantizedFusedMoEMethod())
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
-        else:
-            moe = MoEConfig(
-                num_experts=self.global_num_experts,
-                experts_per_token=top_k,
-                hidden_dim=hidden_size,
-                num_local_experts=self.local_num_experts,
-                moe_parallel_config=self.moe_parallel_config,
-                # TODO (bnell): this needs to be fixed for quantized types.
-                in_dtype=params_dtype,
-            )
 
-            if quant_config is None:
-                self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
-            else:
-                self.quant_method = quant_config.get_quant_method(self, prefix)
+        moe = MoEConfig(
+            num_experts=self.global_num_experts,
+            experts_per_token=top_k,
+            hidden_dim=hidden_size,
+            num_local_experts=self.local_num_experts,
+            moe_parallel_config=self.moe_parallel_config,
+            # TODO (bnell): this needs to be fixed for quantized types.
+            in_dtype=params_dtype,
+        )
+
+        if quant_config is None:
+            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
 
         assert self.quant_method is not None
 
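Note: the unchanged context lines above call determine_expert_map(self.ep_size, get_ep_group().rank_in_group, self.global_num_experts) to split the global experts across expert-parallel ranks. The sketch below is a rough, self-contained illustration of that kind of partitioning, not vLLM's implementation (the real helper lives in vllm.model_executor.layers.fused_moe.layer); the toy_expert_map name and the even-split policy are assumptions made for the example.

# Toy illustration (not vLLM's code): each expert-parallel rank owns a
# contiguous slice of the global experts; the expert map sends a global
# expert id to a local id, or -1 when this rank does not own it.
import torch


def toy_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
    # Split as evenly as possible; the first `remainder` ranks take one extra.
    base = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    local_num = base + (1 if ep_rank < remainder else 0)
    start = ep_rank * base + min(ep_rank, remainder)

    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    expert_map[start:start + local_num] = torch.arange(local_num,
                                                       dtype=torch.int32)
    return local_num, expert_map


if __name__ == "__main__":
    # 10 global experts over 4 EP ranks -> ranks 0-1 own 3, ranks 2-3 own 2.
    for rank in range(4):
        local_num, expert_map = toy_expert_map(4, rank, 10)
        print(rank, local_num, expert_map.tolist())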