// swiglu_probs_grad_op.cu
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

#include "paddle/extension.h"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using BFloat16 = __nv_bfloat16;
#else
// Fallback for compilation passes where __nv_bfloat16 arithmetic is not
// available: a bit-compatible 16-bit wrapper, so the kernel signatures stay
// the same and pointers can be reinterpreted as __nv_bfloat16 on the device.
struct BFloat16 {
  uint16_t x;

  __host__ __device__ BFloat16() : x(0) {}

  __host__ __device__ BFloat16(float val) {
    // Keep the upper 16 bits of the float (truncating bfloat16 conversion).
    uint32_t* val_bits = reinterpret_cast<uint32_t*>(&val);
    x = static_cast<uint16_t>(*val_bits >> 16);
  }

  __host__ __device__ operator float() const {
    uint32_t bits = static_cast<uint32_t>(x) << 16;
    return *reinterpret_cast<float*>(&bits);
  }
};
#endif
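// Backward of prob-weighted SwiGLU. With H = moe_intermediate_size,
// x = o1[:, :H], g = o1[:, H:], sig = sigmoid(x) and silu(x) = x * sig:
//   o2         = silu(x) * g
//   o2_s       = o2 * prob
//   do1[:, :H] = do2_s * prob * g * sig * (1 + x - x * sig)
//   do1[:, H:] = do2_s * prob * silu(x)
//   probs_grad = sum over the row of (do2_s * o2)
// One thread block processes one row; threads stride over H.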
template <int thread_per_block>
__global__ void SwigluProbsGradKernel(
    const BFloat16* o1,           // [seq_len*topk, moe_intermediate_size*2]
    const BFloat16* do2_s,        // [seq_len*topk, moe_intermediate_size]
    const float* unzipped_probs,  // [seq_len*topk, 1]
    BFloat16* do1,                // [seq_len*topk, moe_intermediate_size*2]
    float* probs_grad,            // [seq_len*topk, 1]
    BFloat16* o2_s,               // [seq_len*topk, moe_intermediate_size]
    int moe_intermediate_size) {
  const int row_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
  const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
  BFloat16* do1_row = do1 + row_idx * moe_intermediate_size * 2;
  BFloat16* o2s_row = o2_s + row_idx * moe_intermediate_size;

  float prob = unzipped_probs[row_idx];

  __shared__ float sum_buffer[thread_per_block];

  float local_probs_grad = 0.0f;

  for (int i = tid; i < moe_intermediate_size; i += blockDim.x) {
    float lhs = static_cast<float>(o1_row[i]);
    float rhs = static_cast<float>(o1_row[i + moe_intermediate_size]);

    float sig = 1.0f / (1.0f + expf(-lhs));
    float tmp = sig * lhs;
    float o2_val = tmp * rhs;

    float do2_s_val = static_cast<float>(do2_s_row[i]);
    float do2_val = do2_s_val * prob;

    float x0_grad = do2_val * rhs * sig * (1.0f + lhs - tmp);
    float x1_grad = do2_val * tmp;

    do1_row[i] = BFloat16(x0_grad);
    do1_row[i + moe_intermediate_size] = BFloat16(x1_grad);
    o2s_row[i] = BFloat16(o2_val * prob);

    local_probs_grad += do2_s_val * o2_val;
  }

  sum_buffer[tid] = local_probs_grad;
  __syncthreads();

  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      sum_buffer[tid] += sum_buffer[tid + stride];
    }
    __syncthreads();
  }

  if (tid == 0) {
    probs_grad[row_idx] = sum_buffer[0];
  }
}
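// Vectorized fast path: four bf16 values are moved as a single 8-byte
// bfloat16x4_t and the arithmetic is done lane-wise on float4.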
typedef struct __align__(8) {
  __nv_bfloat16 x;
  __nv_bfloat16 y;
  __nv_bfloat16 z;
  __nv_bfloat16 w;
} bfloat16x4_t;

__device__ __forceinline__ float4 fast_swiglu_vec4(const bfloat16x4_t& lhs,
                                                   const bfloat16x4_t& rhs) {
  const float x_f_x = __bfloat162float(lhs.x);
  const float x_f_y = __bfloat162float(lhs.y);
  const float x_f_z = __bfloat162float(lhs.z);
  const float x_f_w = __bfloat162float(lhs.w);

  const float y_f_x = __bfloat162float(rhs.x);
  const float y_f_y = __bfloat162float(rhs.y);
  const float y_f_z = __bfloat162float(rhs.z);
  const float y_f_w = __bfloat162float(rhs.w);

  const float silu_x = x_f_x * __frcp_rn(1.0f + __expf(-x_f_x));
  const float silu_y = x_f_y * __frcp_rn(1.0f + __expf(-x_f_y));
  const float silu_z = x_f_z * __frcp_rn(1.0f + __expf(-x_f_z));
  const float silu_w = x_f_w * __frcp_rn(1.0f + __expf(-x_f_w));

  return {silu_x * y_f_x, silu_y * y_f_y, silu_z * y_f_z, silu_w * y_f_w};
}

__device__ __forceinline__ float4 f4_prod(const float4& x_f,
                                          const float4& y_f) {
  return {x_f.x * y_f.x, x_f.y * y_f.y, x_f.z * y_f.z, x_f.w * y_f.w};
}
__device__ __forceinline__ float4 f4_prod(const float4& x_f, const float& y_f) {
  return {x_f.x * y_f, x_f.y * y_f, x_f.z * y_f, x_f.w * y_f};
}
__device__ __forceinline__ float4 f4_add(const float4& x_f, const float& y_f) {
  return {x_f.x + y_f, x_f.y + y_f, x_f.z + y_f, x_f.w + y_f};
}
__device__ __forceinline__ float4 f4_add(const float4& x_f, const float4& y_f) {
  return {x_f.x + y_f.x, x_f.y + y_f.y, x_f.z + y_f.z, x_f.w + y_f.w};
}
__device__ __forceinline__ float4 f4_sub(const float4& x_f, const float4& y_f) {
  return {x_f.x - y_f.x, x_f.y - y_f.y, x_f.z - y_f.z, x_f.w - y_f.w};
}
__device__ __forceinline__ float4 fast_sig_vec4(const float4& x_vec4) {
  const float sig_x = __frcp_rn(1.0f + __expf(-x_vec4.x));
  const float sig_y = __frcp_rn(1.0f + __expf(-x_vec4.y));
  const float sig_z = __frcp_rn(1.0f + __expf(-x_vec4.z));
  const float sig_w = __frcp_rn(1.0f + __expf(-x_vec4.w));
  return {sig_x, sig_y, sig_z, sig_w};
}
__device__ __forceinline__ float4
load_and_cast_float4(const bfloat16x4_t* x_vec4_ptr) {
  bfloat16x4_t x_vec4 = *x_vec4_ptr;
  return {
      static_cast<float>(x_vec4.x),
      static_cast<float>(x_vec4.y),
      static_cast<float>(x_vec4.z),
      static_cast<float>(x_vec4.w),
  };
}
__device__ __forceinline__ void cast_and_store_bf16x4(bfloat16x4_t* dst_ptr,
                                                      const float4& x_vec4) {
  *dst_ptr = {static_cast<__nv_bfloat16>(x_vec4.x),
              static_cast<__nv_bfloat16>(x_vec4.y),
              static_cast<__nv_bfloat16>(x_vec4.z),
              static_cast<__nv_bfloat16>(x_vec4.w)};
}
__device__ __forceinline__ float mreduce_f4(const float4& x_f4,
                                            const float4& y_f4) {
  float x_m = x_f4.x * y_f4.x;
  float y_m = x_f4.y * y_f4.y;
  float z_m = x_f4.z * y_f4.z;
  float w_m = x_f4.w * y_f4.w;
  return x_m + y_m + z_m + w_m;
}
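// Vectorized variant of SwigluProbsGradKernel: each thread handles four
// elements per iteration; the block reduction runs in shared memory down to
// one warp and finishes with warp shuffles.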
template <int thread_per_block>
__global__ void SwigluProbsGradKernelVec4(
    const BFloat16* o1,           // [seq_len*topk, moe_intermediate_size*2]
    const BFloat16* do2_s,        // [seq_len*topk, moe_intermediate_size]
    const float* unzipped_probs,  // [seq_len*topk, 1]
    BFloat16* do1,                // [seq_len*topk, moe_intermediate_size*2]
    float* probs_grad,            // [seq_len*topk, 1]
    BFloat16* o2_s,               // [seq_len*topk, moe_intermediate_size]
    int moe_intermediate_size) {
  constexpr int numel_per_thread = 4;
  constexpr int k_warp_size = 32;
  const int row_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const BFloat16* o1_row = o1 + row_idx * moe_intermediate_size * 2;
  const BFloat16* do2_s_row = do2_s + row_idx * moe_intermediate_size;
  const bfloat16x4_t* o1_row_left_half_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(o1_row);
  const bfloat16x4_t* do2_s_row_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(do2_s_row);
  const bfloat16x4_t* o1_row_right_half_vec4 =
      reinterpret_cast<const bfloat16x4_t*>(o1_row + moe_intermediate_size);
  BFloat16* do1_row = do1 + row_idx * moe_intermediate_size * 2;
  BFloat16* o2s_row = o2_s + row_idx * moe_intermediate_size;
  bfloat16x4_t* do1_row_vec4 = reinterpret_cast<bfloat16x4_t*>(do1_row);
  bfloat16x4_t* o2s_row_vec4 = reinterpret_cast<bfloat16x4_t*>(o2s_row);

  float prob = unzipped_probs[row_idx];
  __shared__ float sum_buffer[thread_per_block];

  float local_probs_grad = 0.0f;

  const int vec_numel = moe_intermediate_size / numel_per_thread;
  for (int i = tid; i < vec_numel; i += blockDim.x) {
    float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
    float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
    float4 do2_s_val_vec4 = load_and_cast_float4(do2_s_row_vec4 + i);
    float4 sig_vec4 = fast_sig_vec4(lhs_vec4);
    float4 tmp_vec4 = f4_prod(sig_vec4, lhs_vec4);
    float4 o2_val_vec4 = f4_prod(tmp_vec4, rhs_vec4);
    float4 o2s_val_vec4 = f4_prod(o2_val_vec4, prob);
    float4 do2_val_vec4 = f4_prod(do2_s_val_vec4, prob);
    float4 x0_grad_vec4 = f4_prod(
        do2_val_vec4,
        f4_prod(rhs_vec4,
                f4_prod(sig_vec4, f4_sub(f4_add(lhs_vec4, 1.0f), tmp_vec4))));
    float4 x1_grad_vec4 = f4_prod(do2_val_vec4, tmp_vec4);
    cast_and_store_bf16x4(do1_row_vec4 + i, x0_grad_vec4);
    cast_and_store_bf16x4(do1_row_vec4 + i + vec_numel, x1_grad_vec4);
    cast_and_store_bf16x4(o2s_row_vec4 + i, o2s_val_vec4);
    local_probs_grad += mreduce_f4(do2_s_val_vec4, o2_val_vec4);
  }

  sum_buffer[tid] = local_probs_grad;
  __syncthreads();

#pragma unroll
  for (int stride = blockDim.x / 2; stride >= k_warp_size; stride >>= 1) {
    if (tid < stride) {
      sum_buffer[tid] += sum_buffer[tid + stride];
    }
    __syncthreads();
  }

  if (tid < k_warp_size) {
    local_probs_grad = sum_buffer[tid];
#pragma unroll
    for (int offset = k_warp_size / 2; offset > 0; offset >>= 1) {
      local_probs_grad +=
          __shfl_down_sync(0xFFFFFFFF, local_probs_grad, offset);
    }
  }

  if (tid == 0) {
    probs_grad[row_idx] = local_probs_grad;
  }
}
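// Host-side wrapper: flattens the leading dims of o1 into one row index,
// allocates the outputs, and launches one block of 256 threads per row.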
std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(
    const paddle::Tensor& o1,
    const paddle::Tensor& do2_s,
    const paddle::Tensor& unzipped_probs) {
  auto o1_dims = o1.dims();
  int o1_outer_dim = 1;
  for (int i = 0; i < o1_dims.size() - 1; i++) {
    o1_outer_dim *= o1_dims[i];
  }

  const int moe_intermediate_size_2 = o1_dims[o1_dims.size() - 1];
  const int moe_intermediate_size = moe_intermediate_size_2 / 2;

  auto do1 = paddle::empty_like(o1);
  auto probs_grad =
      paddle::empty({o1_outer_dim}, paddle::DataType::FLOAT32, o1.place());
  auto o2_s = paddle::empty_like(do2_s);

  const BFloat16* o1_ptr =
      reinterpret_cast<const BFloat16*>(o1.data<phi::bfloat16>());
  const BFloat16* do2_s_ptr =
      reinterpret_cast<const BFloat16*>(do2_s.data<phi::bfloat16>());
  const float* unzipped_probs_ptr = unzipped_probs.data<float>();
  BFloat16* do1_ptr = reinterpret_cast<BFloat16*>(do1.data<phi::bfloat16>());
  float* probs_grad_ptr = probs_grad.data<float>();
  BFloat16* o2_s_ptr = reinterpret_cast<BFloat16*>(o2_s.data<phi::bfloat16>());

  constexpr int block_size = 256;
  // The vectorized kernel assumes the hidden size is a multiple of 4 (so each
  // row stays 8-byte aligned for bfloat16x4_t loads); otherwise fall back to
  // the scalar kernel.
  if (moe_intermediate_size % 4 != 0) {
    SwigluProbsGradKernel<block_size>
        <<<o1_outer_dim, block_size, 0, o1.stream()>>>(o1_ptr,
                                                       do2_s_ptr,
                                                       unzipped_probs_ptr,
                                                       do1_ptr,
                                                       probs_grad_ptr,
                                                       o2_s_ptr,
                                                       moe_intermediate_size);
  } else {
    SwigluProbsGradKernelVec4<block_size>
        <<<o1_outer_dim, block_size, 0, o1.stream()>>>(o1_ptr,
                                                       do2_s_ptr,
                                                       unzipped_probs_ptr,
                                                       do1_ptr,
                                                       probs_grad_ptr,
                                                       o2_s_ptr,
                                                       moe_intermediate_size);
  }

  return {do1, probs_grad, o2_s};
}
PD_BUILD_OP(fused_swiglu_probs_bwd)
    .Inputs({"o1", "do2_s", "unzipped_probs"})
    .Outputs({"do1", "probs_grad", "o2_s"})
    .SetKernelFn(PD_KERNEL(SwigluProbsGradCUDABackward));
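// Usage sketch (an assumption, not part of this file's build setup): after
// JIT-compiling the op with paddle.utils.cpp_extension.load, e.g.
//   ops = load(name="fused_swiglu_probs_bwd_ops",
//              sources=["swiglu_probs_grad_op.cu"]),
// the kernel is callable as
//   do1, probs_grad, o2_s = ops.fused_swiglu_probs_bwd(o1, do2_s, probs).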