
Commit 07be865

Adding fused_swiglu_probs_bwd op (#10604)
* Add fused swiglu_probs_bwd op
* Add o2_s as output
* Fix 3-D tensor input and add vectorization optimizations
* Fix tests of vec4
* Optimize reduce performance
* Delete timeline
* Update setup_fp8.py: fix arch
* Fix multi-dimension issue
1 parent 6c206e1 commit 07be865

File tree: 3 files changed, +408 -0 lines

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_swiglu_probs_bwd.cu

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
// swiglu_probs_grad_op.cu
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <vector>

#include "paddle/extension.h"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using BFloat16 = __nv_bfloat16;
#else
// Fallback for host code and pre-sm80 device code: bf16 is stored as the
// upper 16 bits of an IEEE-754 float and converted by bit shifts.
struct BFloat16 {
  uint16_t x;

  __host__ __device__ BFloat16() : x(0) {}

  __host__ __device__ BFloat16(float val) {
    uint32_t* val_bits = reinterpret_cast<uint32_t*>(&val);
    x = static_cast<uint16_t>(*val_bits >> 16);
  }

  __host__ __device__ operator float() const {
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    return *reinterpret_cast<float*>(&bits);
  }
};
#endif
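
// Illustrative note on the fallback (an addition for exposition, not part
// of the kernels): truncating the low 16 bits never rounds up, so the
// round-trip error is at most one bf16 ulp. For example, on the host:
//
//   float in = 3.14159f;
//   float out = static_cast<float>(BFloat16(in));
//   // |out - in| <= |in| * 2^-7, since bf16 keeps an 8-bit significand.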

// One thread block per row: computes do1, o2_s = o2 * prob, and the
// per-row probability gradient sum(do2_s * o2).
template <int thread_per_block>
__global__ void SwigluProbsGradKernel(
    const BFloat16* o1,           // [seq_len*topk, moe_intermediate_size*2]
    const BFloat16* do2_s,        // [seq_len*topk, moe_intermediate_size]
    const float* unzipped_probs,  // [seq_len*topk, 1]
    BFloat16* do1,                // [seq_len*topk, moe_intermediate_size*2]
    float* probs_grad,            // [seq_len*topk, 1]
    BFloat16* o2_s,               // [seq_len*topk, moe_intermediate_size]
    int moe_intermediate_size) {
  const int row_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
  const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
  BFloat16* do1_row = do1 + row_idx * moe_intermediate_size * 2;
  BFloat16* o2s_row = o2_s + row_idx * moe_intermediate_size;

  float prob = unzipped_probs[row_idx];

  __shared__ float sum_buffer[thread_per_block];

  float local_probs_grad = 0.0f;

  for (int i = tid; i < moe_intermediate_size; i += blockDim.x) {
    float lhs = static_cast<float>(o1_row[i]);
    float rhs = static_cast<float>(o1_row[i + moe_intermediate_size]);

    float sig = 1.0f / (1.0f + expf(-lhs));
    float tmp = sig * lhs;     // silu(lhs)
    float o2_val = tmp * rhs;  // swiglu output

    float do2_s_val = static_cast<float>(do2_s_row[i]);
    float do2_val = do2_s_val * prob;

    float x0_grad = do2_val * rhs * sig * (1.0f + lhs - tmp);
    float x1_grad = do2_val * tmp;

    do1_row[i] = BFloat16(x0_grad);
    do1_row[i + moe_intermediate_size] = BFloat16(x1_grad);
    o2s_row[i] = BFloat16(o2_val * prob);

    local_probs_grad += do2_s_val * o2_val;
  }

  // Block-level tree reduction of the per-thread partial sums.
  sum_buffer[tid] = local_probs_grad;
  __syncthreads();

  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      sum_buffer[tid] += sum_buffer[tid + stride];
    }
    __syncthreads();
  }

  if (tid == 0) {
    probs_grad[row_idx] = sum_buffer[0];
  }
}
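
// Derivation of the gradient formulas above (exposition only):
//   o2  = silu(x0) * x1,  silu(x0) = x0 * sig(x0),  sig(x) = 1 / (1 + exp(-x))
//   d silu(x0) / d x0 = sig + x0 * sig * (1 - sig)
//                     = sig * (1 + x0 - x0 * sig)
//                     = sig * (1 + lhs - tmp)        // tmp = sig * lhs
//   => x0_grad = do2 * rhs * sig * (1 + lhs - tmp)
//      x1_grad = do2 * silu(x0) = do2 * tmp
// The probability enters twice: do2 = do2_s * prob (chain rule through
// o2_s = o2 * prob), and d o2_s / d prob = o2, which is why probs_grad
// accumulates do2_s * o2 per element.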

typedef struct __align__(8) {
  __nv_bfloat16 x;
  __nv_bfloat16 y;
  __nv_bfloat16 z;
  __nv_bfloat16 w;
} bfloat16x4_t;

// Fused swiglu on a 4-wide vector (not referenced by the kernels in this
// file).
__device__ __forceinline__ float4 fast_swiglu_vec4(const bfloat16x4_t& lhs,
                                                   const bfloat16x4_t& rhs) {
  const float x_f_x = __bfloat162float(lhs.x);
  const float x_f_y = __bfloat162float(lhs.y);
  const float x_f_z = __bfloat162float(lhs.z);
  const float x_f_w = __bfloat162float(lhs.w);

  const float y_f_x = __bfloat162float(rhs.x);
  const float y_f_y = __bfloat162float(rhs.y);
  const float y_f_z = __bfloat162float(rhs.z);
  const float y_f_w = __bfloat162float(rhs.w);

  const float silu_x = x_f_x * __frcp_rn(1.0f + __expf(-x_f_x));
  const float silu_y = x_f_y * __frcp_rn(1.0f + __expf(-x_f_y));
  const float silu_z = x_f_z * __frcp_rn(1.0f + __expf(-x_f_z));
  const float silu_w = x_f_w * __frcp_rn(1.0f + __expf(-x_f_w));

  return {silu_x * y_f_x, silu_y * y_f_y, silu_z * y_f_z, silu_w * y_f_w};
}

// Elementwise float4 helpers.
__device__ __forceinline__ float4 f4_prod(const float4& x_f,
                                          const float4& y_f) {
  return {x_f.x * y_f.x, x_f.y * y_f.y, x_f.z * y_f.z, x_f.w * y_f.w};
}
__device__ __forceinline__ float4 f4_prod(const float4& x_f, const float& y_f) {
  return {x_f.x * y_f, x_f.y * y_f, x_f.z * y_f, x_f.w * y_f};
}
__device__ __forceinline__ float4 f4_add(const float4& x_f, const float& y_f) {
  return {x_f.x + y_f, x_f.y + y_f, x_f.z + y_f, x_f.w + y_f};
}
__device__ __forceinline__ float4 f4_add(const float4& x_f, const float4& y_f) {
  return {x_f.x + y_f.x, x_f.y + y_f.y, x_f.z + y_f.z, x_f.w + y_f.w};
}
__device__ __forceinline__ float4 f4_sub(const float4& x_f, const float4& y_f) {
  return {x_f.x - y_f.x, x_f.y - y_f.y, x_f.z - y_f.z, x_f.w - y_f.w};
}
__device__ __forceinline__ float4 fast_sig_vec4(const float4& x_vec4) {
  const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
  const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
  const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
  const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
  return {sig_x, sig_y, sig_z, sig_w};
}
__device__ __forceinline__ float4
load_and_cast_float4(const bfloat16x4_t* x_vec4_ptr) {
  bfloat16x4_t x_vec4 = *x_vec4_ptr;
  return {
      static_cast<float>(x_vec4.x),
      static_cast<float>(x_vec4.y),
      static_cast<float>(x_vec4.z),
      static_cast<float>(x_vec4.w),
  };
}
__device__ __forceinline__ void cast_and_store_bf16x4(bfloat16x4_t* dst_ptr,
                                                      const float4& x_vec4) {
  *dst_ptr = {static_cast<__nv_bfloat16>(x_vec4.x),
              static_cast<__nv_bfloat16>(x_vec4.y),
              static_cast<__nv_bfloat16>(x_vec4.z),
              static_cast<__nv_bfloat16>(x_vec4.w)};
}
// Multiply-reduce: dot product of two float4 values.
__device__ __forceinline__ float mreduce_f4(const float4& x_f4,
                                            const float4& y_f4) {
  float x_m = x_f4.x * y_f4.x;
  float y_m = x_f4.y * y_f4.y;
  float z_m = x_f4.z * y_f4.z;
  float w_m = x_f4.w * y_f4.w;
  return x_m + y_m + z_m + w_m;
}
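
// Sanity checks for the vectorized path (an explicit statement, added for
// exposition, of what the reinterpret_casts below assume): four bf16 values
// pack into 8 bytes, and __align__(8) allows single 8-byte vector loads and
// stores through the casted pointers.
static_assert(sizeof(bfloat16x4_t) == 8, "4 x bf16 must pack into 8 bytes");
static_assert(alignof(bfloat16x4_t) == 8, "alignment enables vector ld/st");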

// Vectorized variant: each thread processes 4 bf16 elements per step via
// 8-byte loads/stores; requires moe_intermediate_size % 4 == 0.
template <int thread_per_block>
__global__ void SwigluProbsGradKernelVec4(
    const BFloat16* o1,           // [seq_len*topk, moe_intermediate_size*2]
    const BFloat16* do2_s,        // [seq_len*topk, moe_intermediate_size]
    const float* unzipped_probs,  // [seq_len*topk, 1]
    BFloat16* do1,                // [seq_len*topk, moe_intermediate_size*2]
    float* probs_grad,            // [seq_len*topk, 1]
    BFloat16* o2_s,               // [seq_len*topk, moe_intermediate_size]
    int moe_intermediate_size) {
  constexpr int numel_per_thread = 4;
  constexpr int k_warp_size = 32;
  const int row_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
  const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
  const bfloat16x4_t* o1_row_left_half_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(o1_row);
  const bfloat16x4_t* do2_s_row_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(do2_s_row);
  const bfloat16x4_t* o1_row_right_half_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(o1_row + moe_intermediate_size);
  BFloat16* do1_row = do1 + row_idx * moe_intermediate_size * 2;
  BFloat16* o2s_row = o2_s + row_idx * moe_intermediate_size;
  bfloat16x4_t* do1_row_vec4 = reinterpret_cast<bfloat16x4_t*>(do1_row);
  bfloat16x4_t* o2s_row_vec4 = reinterpret_cast<bfloat16x4_t*>(o2s_row);

  float prob = unzipped_probs[row_idx];
  __shared__ float sum_buffer[thread_per_block];

  float local_probs_grad = 0.0f;

  const int vec_numel = moe_intermediate_size / numel_per_thread;
  for (int i = tid; i < vec_numel; i += blockDim.x) {
    float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
    float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
    float4 do2_s_val_vec4 = load_and_cast_float4(do2_s_row_vec4 + i);
    float4 sig_vec4 = fast_sig_vec4(lhs_vec4);
    float4 tmp_vec4 = f4_prod(sig_vec4, lhs_vec4);
    float4 o2_val_vec4 = f4_prod(tmp_vec4, rhs_vec4);
    float4 o2s_val_vec4 = f4_prod(o2_val_vec4, prob);
    float4 do2_val_vec4 = f4_prod(do2_s_val_vec4, prob);
    // x0_grad = do2 * rhs * sig * (1 + lhs - tmp), as in the scalar kernel.
    float4 x0_grad_vec4 = f4_prod(
        do2_val_vec4,
        f4_prod(rhs_vec4,
                f4_prod(sig_vec4, (f4_sub(f4_add(lhs_vec4, 1.0f), tmp_vec4)))));
    float4 x1_grad_vec4 = f4_prod(do2_val_vec4, tmp_vec4);
    cast_and_store_bf16x4(do1_row_vec4 + i, x0_grad_vec4);
    cast_and_store_bf16x4(do1_row_vec4 + i + vec_numel, x1_grad_vec4);
    cast_and_store_bf16x4(o2s_row_vec4 + i, o2s_val_vec4);
    local_probs_grad += mreduce_f4(do2_s_val_vec4, o2_val_vec4);
  }

  sum_buffer[tid] = local_probs_grad;
  __syncthreads();

  // Shared-memory tree reduction down to one warp, then a warp shuffle
  // reduction. The compile-time thread_per_block bound (equal to blockDim.x
  // at every launch site) lets #pragma unroll take effect.
#pragma unroll
  for (int stride = thread_per_block / 2; stride >= k_warp_size;
       stride >>= 1) {
    if (tid < stride) {
      sum_buffer[tid] += sum_buffer[tid + stride];
    }
    __syncthreads();
  }

  if (tid < k_warp_size) {
    local_probs_grad = sum_buffer[tid];
#pragma unroll
    for (int offset = k_warp_size / 2; offset > 0; offset >>= 1) {
      local_probs_grad +=
          __shfl_down_sync(0xFFFFFFFF, local_probs_grad, offset);
    }
  }

  if (tid == 0) {
    probs_grad[row_idx] = local_probs_grad;
  }
}
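
// Reduction trace for thread_per_block = 256 (the launch configuration
// below): shared-memory strides 128, 64, 32 leave 32 partial sums; warp
// shuffles with offsets 16, 8, 4, 2, 1 then put the row total in lane 0.
// Only the first warp reaches the shuffle stage, so the full 0xFFFFFFFF
// mask is valid for __shfl_down_sync.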

std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(
    const paddle::Tensor& o1,
    const paddle::Tensor& do2_s,
    const paddle::Tensor& unzipped_probs) {
  // Flatten all leading dimensions, so 2-D and 3-D inputs both work:
  // rows = seq_len * topk.
  auto o1_dims = o1.dims();
  int o1_outer_dim = 1;
  for (int i = 0; i < o1_dims.size() - 1; i++) {
    o1_outer_dim *= o1_dims[i];
  }

  const int moe_intermediate_size_2 = o1_dims[o1_dims.size() - 1];
  const int moe_intermediate_size = moe_intermediate_size_2 / 2;

  auto do1 = paddle::empty_like(o1);
  auto probs_grad = paddle::empty(
      {o1_outer_dim}, paddle::DataType::FLOAT32, o1.place());
  auto o2_s = paddle::empty_like(do2_s);

  const BFloat16* o1_ptr =
      reinterpret_cast<const BFloat16*>(o1.data<phi::bfloat16>());
  const BFloat16* do2_s_ptr =
      reinterpret_cast<const BFloat16*>(do2_s.data<phi::bfloat16>());
  const float* unzipped_probs_ptr = unzipped_probs.data<float>();
  BFloat16* do1_ptr = reinterpret_cast<BFloat16*>(do1.data<phi::bfloat16>());
  float* probs_grad_ptr = probs_grad.data<float>();
  BFloat16* o2_s_ptr = reinterpret_cast<BFloat16*>(o2_s.data<phi::bfloat16>());

  constexpr int block_size = 256;
  if (moe_intermediate_size % 4 != 0) {
    SwigluProbsGradKernel<block_size>
        <<<o1_outer_dim, block_size, 0, o1.stream()>>>(o1_ptr,
                                                       do2_s_ptr,
                                                       unzipped_probs_ptr,
                                                       do1_ptr,
                                                       probs_grad_ptr,
                                                       o2_s_ptr,
                                                       moe_intermediate_size);
  } else {
    SwigluProbsGradKernelVec4<block_size>
        <<<o1_outer_dim, block_size, 0, o1.stream()>>>(o1_ptr,
                                                       do2_s_ptr,
                                                       unzipped_probs_ptr,
                                                       do1_ptr,
                                                       probs_grad_ptr,
                                                       o2_s_ptr,
                                                       moe_intermediate_size);
  }

  return {do1, probs_grad, o2_s};
}
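
// Dispatch note: one thread block per row (grid = seq_len * topk) with 256
// threads each. The vectorized kernel covers every moe_intermediate_size
// divisible by 4; the scalar kernel is the fallback for the remaining sizes.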

PD_BUILD_OP(fused_swiglu_probs_bwd)
    .Inputs({"o1", "do2_s", "unzipped_probs"})
    .Outputs({"do1", "probs_grad", "o2_s"})
    .SetKernelFn(PD_KERNEL(SwigluProbsGradCUDABackward));

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def setup_fused_quant_ops():
         "fused_quanted_ops/fused_act_quant.cu",
         "fused_quanted_ops/fused_act_dequant.cu",
         "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
+        "fused_quanted_ops/fused_swiglu_probs_bwd.cu",
         "fused_quanted_ops/fused_spaq.cu",
         "fused_quanted_ops/fused_stack_transpose_quant.cu",
     ],
