Skip to content

Commit f705b6b

Browse files
author
Zhaowu Pan
authored
Add int64_t index type for possible overflow position. (#10663)
1 parent 4bfc44d commit f705b6b

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_stable_unzip.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,12 @@ __global__ void tokens_unzip_stable_kernel(
138138
shared_expert_probmap[internal_row][expert];
139139
}
140140
if constexpr (has_scale) {
141-
vectorized_memcpy(&XScale[row * scale_length],
142-
&XScale_unzipped[unzipped_row_idx * scale_length],
141+
vectorized_memcpy(&XScale[(int64_t)row * (int64_t)scale_length],
142+
&XScale_unzipped[(int64_t)unzipped_row_idx * (int64_t)scale_length],
143143
scale_length);
144144
}
145-
vectorized_memcpy(&X[row * token_length],
146-
&X_unzipped[unzipped_row_idx * token_length],
145+
vectorized_memcpy(&X[(int64_t)row * (int64_t)token_length],
146+
&X_unzipped[(int64_t)unzipped_row_idx * (int64_t)token_length],
147147
token_length);
148148
}
149149
}
@@ -367,4 +367,4 @@ PD_BUILD_OP(tokens_unzip_stable)
367367
.SetKernelFn(PD_KERNEL(tokens_unzip_stable));
368368

369369

370-
#undef CUMSUM_BLOCK_SIZE
370+
#undef CUMSUM_BLOCK_SIZE

slm/model_zoo/gpt-3/external_ops/token_dispatcher_utils/tokens_unzip_and_zip.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ __global__ void tokens_weighted_zip_kernel(
172172
// 手动类型提升
173173
float2 token_vec =
174174
__bfloat1622float2(*reinterpret_cast<const __nv_bfloat162 *>(
175-
&unzipped_tokens[fetch_row_index * token_length + x_offset]));
175+
&unzipped_tokens[(int64_t)fetch_row_index * (int64_t)token_length + x_offset]));
176176
float prob = fetch_row >= 0
177177
? __bfloat162float(local_expert_problist[expert])
178178
: 0.0f;
@@ -193,7 +193,7 @@ __global__ void tokens_weighted_zip_kernel(
193193
int fetch_row = local_row_fetchlist[expert];
194194
int fetch_row_index = fetch_row >= 0 ? fetch_row : 0;
195195
float token_val = __bfloat162float(
196-
unzipped_tokens[fetch_row_index * token_length + i]);
196+
unzipped_tokens[(int64_t)fetch_row_index * (int64_t)token_length + i]);
197197
float prob = fetch_row >= 0
198198
? __bfloat162float(local_expert_problist[expert])
199199
: 0.0f;
@@ -210,13 +210,13 @@ __global__ void tokens_weighted_zip_kernel(
210210
x_offset += thread_stride) {
211211
__nv_bfloat162 sum = {0, 0};
212212
__nv_bfloat162 *out_ptr = reinterpret_cast<__nv_bfloat162 *>(
213-
&weighted_zipped_tokens[this_row * token_length + x_offset]);
213+
&weighted_zipped_tokens[(int64_t)this_row * (int64_t)token_length + x_offset]);
214214
#pragma unroll
215215
for (int expert = 0; expert < num_experts; ++expert) {
216216
const int fetch_row = local_row_fetchlist[expert];
217217
const int fetch_row_index = fetch_row >= 0 ? fetch_row : 0;
218218
__nv_bfloat162 token_vec = *reinterpret_cast<const __nv_bfloat162 *>(
219-
&unzipped_tokens[fetch_row_index * token_length + x_offset]);
219+
&unzipped_tokens[(int64_t)fetch_row_index * (int64_t)token_length + x_offset]);
220220
__nv_bfloat16 prob =
221221
fetch_row >= 0 ? local_expert_problist[expert] : (__nv_bfloat16)0;
222222
__nv_bfloat162 prob_vec = {prob, prob};
@@ -234,7 +234,7 @@ __global__ void tokens_weighted_zip_kernel(
234234
int fetch_row = local_row_fetchlist[expert];
235235
int fetch_row_index = fetch_row >= 0 ? fetch_row : 0;
236236
__nv_bfloat16 token_val =
237-
unzipped_tokens[fetch_row_index * token_length + i];
237+
unzipped_tokens[(int64_t)fetch_row_index * (int64_t)token_length + i];
238238
__nv_bfloat16 prob =
239239
fetch_row >= 0 ? local_expert_problist[expert] : (__nv_bfloat16)0;
240240
sum += prob * token_val;
@@ -876,4 +876,4 @@ PD_BUILD_OP(tokens_weighted_zip)
876876

877877
#undef DISPATCH_CASE
878878
#undef DISPATCH_TOKEN_TYPE
879-
#undef DISPATCH_PROB_TYPE
879+
#undef DISPATCH_PROB_TYPE

0 commit comments

Comments
 (0)