Skip to content

Commit 25152b4

Browse files
committed
Optimize reduce performance
1 parent 795a3b7 commit 25152b4

File tree

2 files changed

+17
-22
lines changed

2 files changed

+17
-22
lines changed

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_swiglu_probs_bwd.cu

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,10 @@ __device__ __forceinline__ float4 f4_sub(const float4& x_f, const float4& y_f) {
131131
return {x_f.x - y_f.x, x_f.y - y_f.y, x_f.z - y_f.z, x_f.w - y_f.w};
132132
}
133133
__device__ __forceinline__ float4 fast_sig_vec4(const float4& x_vec4) {
134-
const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
135-
const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
136-
const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
137-
const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
134+
const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
135+
const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
136+
const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
137+
const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
138138
return {sig_x, sig_y, sig_z, sig_w};
139139
}
140140
__device__ __forceinline__ float4
@@ -173,6 +173,7 @@ __global__ void SwigluProbsGradKernelVec4(
173173
BFloat16* o2_s, // [seq_len*topk, moe_intermediate_size]
174174
int moe_intermediate_size) {
175175
constexpr int numel_per_thread = 4;
176+
constexpr int k_warp_size = 32;
176177
const int row_idx = blockIdx.x;
177178
const int tid = threadIdx.x;
178179

@@ -199,26 +200,11 @@ __global__ void SwigluProbsGradKernelVec4(
199200
float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
200201
float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
201202
float4 do2_s_val_vec4 = load_and_cast_float4(do2_s_row_vec4 + i);
202-
// ------------ developing ----------------
203-
/*
204-
float sig = 1.0f / (1.0f + expf(-lhs));
205-
float tmp = sig * lhs;
206-
float o2_val = tmp * rhs;
207-
float do2_val = do2_s_val * prob;
208-
*/
209203
float4 sig_vec4 = fast_sig_vec4(lhs_vec4);
210204
float4 tmp_vec4 = f4_prod(sig_vec4, lhs_vec4);
211205
float4 o2_val_vec4 = f4_prod(tmp_vec4, rhs_vec4);
212206
float4 o2s_val_vec4 = f4_prod(o2_val_vec4, prob);
213207
float4 do2_val_vec4 = f4_prod(do2_s_val_vec4, prob);
214-
/*
215-
float x0_grad = do2_val * rhs * sig * (1.0f + lhs - tmp);
216-
float x1_grad = do2_val * tmp;
217-
do1_row[i] = BFloat16(x0_grad);
218-
do1_row[i + moe_intermediate_size] = BFloat16(x1_grad);
219-
o2s_row[i] = BFloat16(o2_val * prob);
220-
local_probs_grad += do2_s_val * o2_val;
221-
*/
222208
float4 x0_grad_vec4 = f4_prod(
223209
do2_val_vec4,
224210
f4_prod(rhs_vec4,
@@ -233,17 +219,26 @@ __global__ void SwigluProbsGradKernelVec4(
233219
sum_buffer[tid] = local_probs_grad;
234220
__syncthreads();
235221

236-
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
222+
#pragma unroll
223+
for (int stride = blockDim.x / 2; stride >= k_warp_size; stride >>= 1) {
237224
if (tid < stride) {
238225
sum_buffer[tid] += sum_buffer[tid + stride];
239226
}
240227
__syncthreads();
241228
}
242229

230+
if (tid < k_warp_size) {
231+
local_probs_grad = sum_buffer[tid];
232+
#pragma unroll
233+
for (int offset = k_warp_size / 2; offset > 0; offset >>= 1) {
234+
local_probs_grad +=
235+
__shfl_down_sync(0xFFFFFFFF, local_probs_grad, offset);
236+
}
237+
}
238+
243239
if (tid == 0) {
244-
probs_grad[row_idx] = sum_buffer[0];
240+
probs_grad[row_idx] = local_probs_grad;
245241
}
246-
// ------------ developing ----------------
247242
}
248243

249244
std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(

tests/ops/grad.nsys-rep

-9.93 KB
Binary file not shown.

0 commit comments

Comments
 (0)