@@ -490,7 +490,6 @@ def __init__(self, **kwargs) -> None:
         """
         self.top_k = kwargs.get("top_k", 0)
         self.num_experts = kwargs.get("num_experts", 0)
-        self.with_quant = kwargs.get("with_quant", False)
 
     @property
     def ep_group(self):
@@ -518,7 +517,8 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
         raise NotImplementedError("Dispatch function not implemented.")
 
     @abstractmethod
@@ -555,6 +555,7 @@ def __init__(self, **kwargs):
         self.topk_weights = None
         self.shared_experts = None
         self.mc2_mask = None
+        self.with_quant = False
 
     def get_dispatch_mc2_kwargs(
             self,
@@ -615,7 +616,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.expert_map = expert_map
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights
@@ -738,6 +741,7 @@ def __init__(self, **kwargs):
         self.expert_map = None
         self.topk_weights = None
         self.topk_ids = None
+        self.with_quant = False
 
     def token_dispatch(self,
                        hidden_states: torch.Tensor,
@@ -751,7 +755,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.original_shape = hidden_states.shape
 
         num_tokens = hidden_states.shape[:-1].numel()
@@ -922,7 +928,8 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
         self.apply_router_weight_on_input = apply_router_weight_on_input
         if self.apply_router_weight_on_input:
             assert (topk_weights.dim() == 2
@@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        self.with_quant = False
         self.num_local_experts = kwargs.get("num_local_experts", 0)
         self.num_global_redundant_experts = kwargs.get(
             "num_global_redundant_experts", 0)
@@ -1032,7 +1040,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
         self.topk_weights = topk_weights
         assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
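
Note: a minimal sketch of the calling pattern these hunks switch to, assuming only what the diff above shows. with_quant is no longer read from the constructor kwargs; it defaults to False on the dispatcher instance and is supplied on each token_dispatch() call, which then records it. The class names and the simplified signature below are hypothetical stand-ins, not code from this PR.

from abc import ABC, abstractmethod
from typing import Any


class DispatcherSketch(ABC):

    def __init__(self, **kwargs: Any) -> None:
        self.top_k = kwargs.get("top_k", 0)
        # No with_quant constructor kwarg any more; it is set per dispatch call.
        self.with_quant = False

    @abstractmethod
    def token_dispatch(self, hidden_states: Any, with_quant: bool = False):
        raise NotImplementedError("Dispatch function not implemented.")


class All2AllDispatcherSketch(DispatcherSketch):

    def token_dispatch(self, hidden_states: Any, with_quant: bool = False):
        # Record the flag for this call only, so one dispatcher instance can
        # alternate between quantized and unquantized dispatch.
        self.with_quant = with_quant
        return hidden_states


dispatcher = All2AllDispatcherSketch(top_k=8)
dispatcher.token_dispatch([0.0], with_quant=True)   # quantized path for this call
dispatcher.token_dispatch([0.0], with_quant=False)  # same instance, plain path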