@@ -131,10 +131,10 @@ __device__ __forceinline__ float4 f4_sub(const float4& x_f, const float4& y_f) {
   return {x_f.x - y_f.x, x_f.y - y_f.y, x_f.z - y_f.z, x_f.w - y_f.w};
 }
 __device__ __forceinline__ float4 fast_sig_vec4(const float4& x_vec4) {
-  const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
-  const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
-  const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
-  const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
+  const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
+  const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
+  const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
+  const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
   return {sig_x, sig_y, sig_z, sig_w};
 }
 __device__ __forceinline__ float4
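For reference, each lane of fast_sig_vec4 evaluates the logistic sigmoid sig(x) = 1 / (1 + e^(-x)) using the fast-math intrinsics __expf (approximate exponential) and __frcp_rn (round-to-nearest reciprocal). A scalar sketch of the per-lane computation, mirroring the scalar reference that a later hunk deletes; the helper name fast_sig is illustrative, not part of the file:

__device__ __forceinline__ float fast_sig(float x) {
  // sigmoid(x) = 1 / (1 + exp(-x)), computed with fast approximate intrinsics
  return __frcp_rn(1.0f + __expf(-x));
}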
@@ -173,6 +173,7 @@ __global__ void SwigluProbsGradKernelVec4(
     BFloat16* o2_s,  // [seq_len*topk, moe_intermediate_size]
     int moe_intermediate_size) {
   constexpr int numel_per_thread = 4;
+  constexpr int k_warp_size = 32;
   const int row_idx = blockIdx.x;
   const int tid = threadIdx.x;
 
@@ -199,26 +200,11 @@ __global__ void SwigluProbsGradKernelVec4(
     float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
     float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
     float4 do2_s_val_vec4 = load_and_cast_float4(do2_s_row_vec4 + i);
-    // ------------ developing ----------------
-    /*
-    float sig = 1.0f / (1.0f + expf(-lhs));
-    float tmp = sig * lhs;
-    float o2_val = tmp * rhs;
-    float do2_val = do2_s_val * prob;
-    */
     float4 sig_vec4 = fast_sig_vec4(lhs_vec4);
     float4 tmp_vec4 = f4_prod(sig_vec4, lhs_vec4);
     float4 o2_val_vec4 = f4_prod(tmp_vec4, rhs_vec4);
     float4 o2s_val_vec4 = f4_prod(o2_val_vec4, prob);
     float4 do2_val_vec4 = f4_prod(do2_s_val_vec4, prob);
-    /*
-    float x0_grad = do2_val * rhs * sig * (1.0f + lhs - tmp);
-    float x1_grad = do2_val * tmp;
-    do1_row[i] = BFloat16(x0_grad);
-    do1_row[i + moe_intermediate_size] = BFloat16(x1_grad);
-    o2s_row[i] = BFloat16(o2_val * prob);
-    local_probs_grad += do2_s_val * o2_val;
-    */
     float4 x0_grad_vec4 = f4_prod(
         do2_val_vec4,
         f4_prod(rhs_vec4,
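The expression being built in x0_grad_vec4 applies the derivative of silu(x) = x * sig(x), matching the scalar reference deleted above (x0_grad = do2_val * rhs * sig * (1.0f + lhs - tmp)). Written out:

  silu'(x) = sig(x) + x * sig(x) * (1 - sig(x))
           = sig(x) * (1 + x - x * sig(x))
           = sig * (1.0f + lhs - tmp)      // with x = lhs and tmp = sig * lhs

The x1 gradient needs no such factor: the output is linear in the right half, so x1_grad is simply do2_val * tmp.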
@@ -233,17 +219,26 @@ __global__ void SwigluProbsGradKernelVec4(
   sum_buffer[tid] = local_probs_grad;
   __syncthreads();
 
-  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
+  #pragma unroll
+  for (int stride = blockDim.x / 2; stride >= k_warp_size; stride >>= 1) {
     if (tid < stride) {
       sum_buffer[tid] += sum_buffer[tid + stride];
     }
     __syncthreads();
   }
 
+  if (tid < k_warp_size) {
+    local_probs_grad = sum_buffer[tid];
+    #pragma unroll
+    for (int offset = k_warp_size / 2; offset > 0; offset >>= 1) {
+      local_probs_grad +=
+          __shfl_down_sync(0xFFFFFFFF, local_probs_grad, offset);
+    }
+  }
+
   if (tid == 0) {
-    probs_grad[row_idx] = sum_buffer[0];
+    probs_grad[row_idx] = local_probs_grad;
   }
-  // ------------ developing ----------------
 }
 
 std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(
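The rewritten epilogue replaces the all-the-way-down shared-memory tree reduction with a hybrid scheme: a tree reduction until one warp's worth of partial sums remains, then a register-level warp reduction via __shfl_down_sync, which needs no further __syncthreads(). Note that the tree loop's trip count depends on blockDim.x, so #pragma unroll is only a hint there, while the warp loop has compile-time bounds and unrolls fully. A self-contained sketch of the same pattern, assuming blockDim.x is a power of two and a multiple of 32 (the name block_reduce_sum is illustrative, not from the file):

__device__ float block_reduce_sum(float val, float* smem) {
  const int tid = threadIdx.x;
  // Phase 1: shared-memory tree reduction until 32 partial sums remain.
  smem[tid] = val;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride >= 32; stride >>= 1) {
    if (tid < stride) {
      smem[tid] += smem[tid + stride];
    }
    __syncthreads();
  }
  // Phase 2: the first warp reduces its 32 partials in registers.
  if (tid < 32) {
    val = smem[tid];
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
  }
  return val;  // the full sum is valid only in thread 0
}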