@@ -77,7 +77,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
7777 auto device = new_topk_idx.device ();
7878 auto num_tokens_per_expert = at::zeros ({num_experts}, at::dtype (at::kInt ).device (device));
7979 auto num_tokens_per_rank = at::zeros ({num_ranks}, at::dtype (at::kInt ).device (device));
80- auto is_token_in_rank = at::zeros ({num_tokens, num_ranks}, at::dtype (at::kInt ).device (device));
80+ auto is_token_in_rank = torch::empty ({num_tokens, num_ranks}, at::dtype (at::kInt ).device (device));
8181
8282 EXEC_NPU_CMD (aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, num_tokens_per_rank,
8383 num_tokens_per_expert, is_token_in_rank);
@@ -183,11 +183,11 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
183183
184184 int send_per_group = 3 ; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)
185185
186- auto send_data = at::zeros ({num_experts * send_per_group}, at::dtype (at::kInt ).device (x.device ()));
186+ auto send_data = torch::empty ({num_experts * send_per_group}, at::dtype (at::kInt ).device (x.device ()));
187187 int64_t send_count = send_per_group * num_local_experts * num_ranks;
188188
189- auto send_data_offset = at::zeros ({num_experts}, at::dtype (at::kInt ).device (x.device ()));
190- at::Tensor recv_data = at::zeros ({num_experts * send_per_group}, at::dtype (at::kInt ).device (x.device ()));
189+ auto send_data_offset = torch::empty ({num_experts}, at::dtype (at::kInt ).device (x.device ()));
190+ at::Tensor recv_data = torch::empty ({num_experts * send_per_group}, at::dtype (at::kInt ).device (x.device ()));
191191
192192 // get ep name
193193 char hcom_ep_name[HCOMM_NAME_LEN];
@@ -209,7 +209,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
209209
210210 auto options_cpu = torch::TensorOptions ().dtype (torch::kInt32 ).device (torch::kCPU );
211211 std::vector<int32_t > local_expert_acc (num_experts, 0 );
212- auto send_token_idx_cpu = at::zeros ({num_tokens, num_topk}, options_cpu);
212+ auto send_token_idx_cpu = torch::empty ({num_tokens, num_topk}, options_cpu);
213213 auto send_token_idx_ptr = send_token_idx_cpu.data_ptr <int >();
214214
215215 auto topk_idx_cpu = new_topk_idx.to (at::kCPU );
@@ -227,8 +227,8 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
227227
228228 EP_HOST_ASSERT (recv_data.dim () == 1 and recv_data.is_contiguous ());
229229 EP_HOST_ASSERT (recv_data.size (0 ) % num_experts == 0 );
230- at::Tensor recv_offset_cpu = at::zeros ({num_experts}, options_cpu);
231- at::Tensor recv_count_cpu = at::zeros ({num_experts}, options_cpu);
230+ at::Tensor recv_offset_cpu = torch::empty ({num_experts}, options_cpu);
231+ at::Tensor recv_count_cpu = torch::empty ({num_experts}, options_cpu);
232232 auto recv_data_cpu = recv_data.to (at::kCPU );
233233 auto recv_data_ptr = recv_data_cpu.data_ptr <int >();
234234 auto recv_count_ptr = recv_count_cpu.data_ptr <int >();
@@ -269,10 +269,10 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
269269 auto recv_count = recv_count_cpu.to (x.device ());
270270
271271 int num_recv_tokens = (total_recv_tokens == 0 ) ? 1 : total_recv_tokens;
272- auto expandx_out = use_quant ? at::zeros ({num_recv_tokens, hidden}, at::dtype (at::kChar ).device (x.device ()))
273- : at::zeros ({num_recv_tokens, hidden}, x.options ());
274- auto dynamic_scales_out = at::zeros ({num_recv_tokens}, at::dtype (at::kFloat ).device (x.device ()));
275- auto expand_idx_out = at::zeros ({num_recv_tokens * 3 }, at::dtype (at::kInt ).device (x.device ()));
272+ auto expandx_out = use_quant ? torch::empty ({num_recv_tokens, hidden}, at::dtype (at::kChar ).device (x.device ()))
273+ : torch::empty ({num_recv_tokens, hidden}, x.options ());
274+ auto dynamic_scales_out = torch::empty ({num_recv_tokens}, at::dtype (at::kFloat ).device (x.device ()));
275+ auto expand_idx_out = torch::empty ({num_recv_tokens * 3 }, at::dtype (at::kInt ).device (x.device ()));
276276
277277 EXEC_NPU_CMD (aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx, recv_offset,
278278 recv_count, hcom_ep_name,
0 commit comments