@@ -172,7 +172,7 @@ __global__ void tokens_weighted_zip_kernel(
172
172
// 手动类型提升
173
173
float2 token_vec =
174
174
__bfloat1622float2 (*reinterpret_cast <const __nv_bfloat162 *>(
175
- &unzipped_tokens[fetch_row_index * token_length + x_offset]));
175
+ &unzipped_tokens[( int64_t ) fetch_row_index * ( int64_t ) token_length + x_offset]));
176
176
float prob = fetch_row >= 0
177
177
? __bfloat162float (local_expert_problist[expert])
178
178
: 0 .0f ;
@@ -193,7 +193,7 @@ __global__ void tokens_weighted_zip_kernel(
193
193
int fetch_row = local_row_fetchlist[expert];
194
194
int fetch_row_index = fetch_row >= 0 ? fetch_row : 0 ;
195
195
float token_val = __bfloat162float (
196
- unzipped_tokens[fetch_row_index * token_length + i]);
196
+ unzipped_tokens[( int64_t ) fetch_row_index * ( int64_t ) token_length + i]);
197
197
float prob = fetch_row >= 0
198
198
? __bfloat162float (local_expert_problist[expert])
199
199
: 0 .0f ;
@@ -210,13 +210,13 @@ __global__ void tokens_weighted_zip_kernel(
210
210
x_offset += thread_stride) {
211
211
__nv_bfloat162 sum = {0 , 0 };
212
212
__nv_bfloat162 *out_ptr = reinterpret_cast <__nv_bfloat162 *>(
213
- &weighted_zipped_tokens[this_row * token_length + x_offset]);
213
+ &weighted_zipped_tokens[( int64_t ) this_row * ( int64_t ) token_length + x_offset]);
214
214
#pragma unroll
215
215
for (int expert = 0 ; expert < num_experts; ++expert) {
216
216
const int fetch_row = local_row_fetchlist[expert];
217
217
const int fetch_row_index = fetch_row >= 0 ? fetch_row : 0 ;
218
218
__nv_bfloat162 token_vec = *reinterpret_cast <const __nv_bfloat162 *>(
219
- &unzipped_tokens[fetch_row_index * token_length + x_offset]);
219
+ &unzipped_tokens[( int64_t ) fetch_row_index * ( int64_t ) token_length + x_offset]);
220
220
__nv_bfloat16 prob =
221
221
fetch_row >= 0 ? local_expert_problist[expert] : (__nv_bfloat16)0 ;
222
222
__nv_bfloat162 prob_vec = {prob, prob};
@@ -234,7 +234,7 @@ __global__ void tokens_weighted_zip_kernel(
234
234
int fetch_row = local_row_fetchlist[expert];
235
235
int fetch_row_index = fetch_row >= 0 ? fetch_row : 0 ;
236
236
__nv_bfloat16 token_val =
237
- unzipped_tokens[fetch_row_index * token_length + i];
237
+ unzipped_tokens[( int64_t ) fetch_row_index * ( int64_t ) token_length + i];
238
238
__nv_bfloat16 prob =
239
239
fetch_row >= 0 ? local_expert_problist[expert] : (__nv_bfloat16)0 ;
240
240
sum += prob * token_val;
@@ -876,4 +876,4 @@ PD_BUILD_OP(tokens_weighted_zip)
876
876
877
877
#undef DISPATCH_CASE
878
878
#undef DISPATCH_TOKEN_TYPE
879
- #undef DISPATCH_PROB_TYPE
879
+ #undef DISPATCH_PROB_TYPE
0 commit comments