@@ -54,7 +54,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         """Pre-process before MLP.
 
@@ -65,6 +65,7 @@ def permute(
             expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                 Mapping from global expert IDs to local expert IDs.
             num_experts (int): Number of local experts (experts on this device).
+            apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).
 
         Returns:
-            tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
+            tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: Return a tuple containing:
@@ -73,6 +74,8 @@ def permute(
                 hidden_states based on topk_ids.
             - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                 Number of tokens assigned to each expert.
+            - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
+                Dynamic scale for each expert, used for quantization.
             - group_list_type (int): Type of group list, 0 for `cumsum`
                 and 1 for `count`. This is mainly for `npu_grouped_matmul`
                 to determine how to handle the output.
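
The `cumsum`/`count` values named in this docstring are two encodings of the same routing information. A minimal sketch of the difference, assuming plain PyTorch tensors (the real values come from the NPU dispatch kernels, and per the docstring it is `npu_grouped_matmul` that interprets the flag):

```python
import torch

# Hypothetical routing: 6 tokens across 4 local experts.
topk_ids = torch.tensor([0, 2, 2, 1, 0, 3])

# group_list_type == 1 ("count"): number of tokens assigned to each expert.
counts = torch.bincount(topk_ids, minlength=4)  # tensor([2, 1, 2, 1])

# group_list_type == 0 ("cumsum"): running end offsets per expert,
# i.e. the prefix sum of the counts.
offsets = torch.cumsum(counts, dim=0)           # tensor([2, 3, 5, 6])
```
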
@@ -160,7 +163,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
@@ -221,7 +224,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
@@ -378,7 +381,7 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
@@ -392,7 +395,7 @@ def permute(
             "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
-            "quant_mode": 2 if use_a8 else 0,
+            "quant_mode": 2 if apply_a8_quantization else 0,
             "group_ep": self.mc2_comm_name,
             "ep_world_size": self.moe_config.ep_size,
             "ep_rank_id": self.moe_config.ep_rank,
@@ -536,13 +539,15 @@ def permute(
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-        use_a8: bool,
+        apply_a8_quantization: bool,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
-        results = self.token_dispatcher.token_dispatch(hidden_states,
-                                                       topk_weights,
-                                                       topk_ids,
-                                                       None,
-                                                       log2phy=None)
+        results = self.token_dispatcher.token_dispatch(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            None,
+            log2phy=None,
+            with_quant=apply_a8_quantization)
         return results["hidden_states"], results["group_list"], results[
             "dynamic_scale"], results["group_list_type"]
 
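
With this change, all `permute` variants share one four-element return contract, so callers can unpack them uniformly. A hypothetical usage sketch (the `dispatcher` object and the leading positional parameters are assumptions; only the trailing parameters appear in this diff):

```python
permuted_states, group_list, dynamic_scale, group_list_type = dispatcher.permute(
    hidden_states,
    topk_ids,
    topk_weights,
    expert_map,
    num_experts,
    apply_a8_quantization=True,  # W4A8/W8A8 path
)
# On the unquantized path (apply_a8_quantization=False), dynamic_scale is None.
```
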