diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 21615f3c7f..abebe0b4c5 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -365,29 +365,6 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output
 
 
-def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
-    num_tokens, _ = hidden_states.shape
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-        hidden_states,
-        row_idx=row_idx,
-        expert_idx=topk_ids,
-        active_num=num_tokens)
-
-    expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute(
-        1, 0).contiguous().view(-1))
-    global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                          minlength=global_num_experts)
-    global_expert_tokens = global_expert_tokens.to(torch.int32)
-    quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states)
-    return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales
-
-
 # currently expert parallelism implemented with all2all
 # is under-optimized.
 def fused_experts_with_all2all(
@@ -417,22 +394,19 @@ def fused_experts_with_all2all(
 
     if expert_map is not None:
         global_num_experts = len(expert_map) + global_redundant_expert_num
-        if hasattr(torch_npu, "npu_moe_init_routing_quant"):
-            quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
-                hidden_states,
-                expert_idx=topk_ids.to(torch.int32),
-                active_num=0,
-                expert_capacity=0,
-                expert_num=global_num_experts,
-                drop_pad_mode=0,
-                expert_tokens_num_mode=2,
-                expert_tokens_before_capacity_flag=False,
-                quant_mode=1,
-            )
-        else:
-            quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
-                hidden_states, top_k, topk_ids, global_num_experts)
-
+        active_num = top_k * num_tokens
+        active_expert_range = [0, global_num_experts]
+        quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = torch_npu.npu_moe_init_routing_v2(
+            hidden_states,
+            expert_idx=topk_ids.to(torch.int32),
+            active_num=active_num,
+            expert_capacity=0,
+            expert_num=global_num_experts,
+            drop_pad_mode=0,
+            expert_tokens_num_type=1,
+            expert_tokens_num_flag=True,
+            quant_mode=1,
+            active_expert_range=active_expert_range)
         gather_sizes = global_expert_tokens.new_empty(
             global_expert_tokens.shape[0])
         dist.all_to_all_single(gather_sizes, global_expert_tokens)
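
For reference, a minimal standalone sketch (not part of the patch) of the fused routing-plus-quantization path this change switches to. The keyword arguments mirror the npu_moe_init_routing_v2 call introduced above; the Ascend NPU environment, torch_npu build, tensor shapes, dtypes, and expert count are assumptions for illustration only.

# Minimal sketch, not part of the patch: exercises the fused
# npu_moe_init_routing_v2 path added above. Assumes an Ascend NPU and a
# torch_npu build that provides this operator; shapes/dtypes are illustrative.
import torch
import torch_npu

num_tokens, hidden_size, top_k, global_num_experts = 8, 128, 2, 16
hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.bfloat16).npu()
topk_ids = torch.randint(0, global_num_experts, (num_tokens, top_k)).npu()

# One fused op replaces the old npu_moe_init_routing + npu_dynamic_quant
# combination (and the npu_moe_init_routing_quant fallback): it expands
# tokens per expert, returns per-expert token counts (consumed as
# all_to_all gather sizes in the patched code), and emits int8 tokens
# with per-token dynamic scales.
quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = \
    torch_npu.npu_moe_init_routing_v2(
        hidden_states,
        expert_idx=topk_ids.to(torch.int32),
        active_num=top_k * num_tokens,
        expert_capacity=0,
        expert_num=global_num_experts,
        drop_pad_mode=0,
        expert_tokens_num_type=1,
        expert_tokens_num_flag=True,
        quant_mode=1,
        active_expert_range=[0, global_num_experts])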