@@ -77,7 +77,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
7777
7878 const int num_tokens = new_topk_idx.size (0 );
7979 const int num_topk = new_topk_idx.size (1 );
80- const int local_ranksize = A2_LOCAL_RANK_SIZE ;
80+ const int local_ranksize = LOCAL_RANK_SIZE ;
8181 auto server_num = num_ranks / local_ranksize;
8282
8383 auto device = new_topk_idx.device ();
@@ -87,15 +87,24 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
8787 const int notify_send_data_size =
8888 num_experts * EXPERT_DATA_SIZE + server_num + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk);
8989 /*
90- The output parameters are ordered as follows:
91- 1. the number of the tokens that every expert received from this NPU. size:[numExpert]
92- 2. The number of tokens received by each server from this NPU (deduplicated). size:[serverNum]
93- 3. The number of tokens sent from this NPU to each server (without deduplication). size:[MAX_BS, serverNum]
94- 4. The number of servers each token is sent to by this NPU. size:[MAX_BS]
95- 5. The order in which each token of this NPU is sent to various servers. size:[MAX_BS, serverNum]
96- 6. The order in which each token is sent to the expert. size:[MAX_BS, numTopk]
97- 7. The server offset of tokens received by each expert from this NPU. size:[numExpert, MAX_BS]
98- 8. The origin offset of the token received by each expert on the original NPU. size:[numExpert, MAX_BS]
90+ The notify send data is constructed by 8 parameters and
91+ the parameters are ordered as follows:
92+ 1. the number of the tokens that every expert received from this NPU.
93+ size:[numExpert]
94+ 2. The number of tokens received by each server from this NPU (deduplicated).
95+ size:[serverNum]
96+ 3. The number of tokens sent from this NPU to each server (without deduplication).
97+ size:[MAX_BS, serverNum]
98+ 4. The number of servers each token is sent to by this NPU.
99+ size:[MAX_BS]
100+ 5. The order in which each token of this NPU is sent to various servers.
101+ size:[MAX_BS, serverNum]
102+ 6. The order in which each token is sent to the expert.
103+ size:[MAX_BS, numTopk]
104+ 7. The server offset of tokens received by each expert from this NPU.
105+ size:[numExpert, MAX_BS]
106+ 8. The origin offset of the token received by each expert on the original NPU.
107+ size:[numExpert, MAX_BS]
99108 */
100109 auto notify_send_data = at::zeros ({notify_send_data_size}, at::dtype (at::kInt ).device (device));
101110 notify_send_data
0 commit comments