@@ -64,15 +64,15 @@ __device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
   }
 }
 
-template <typename OutT, int VecSize>
+template <typename OutT, bool Pow2Scales, int VecSize>
 __device__ void BlockStoreScale(float* scale,
                                 __nv_bfloat16 amax[4],
                                 float scale_inv[4],
                                 size_t K) {
   float scale_out[4];
   for (int i = 0; i < 4; i++) {
-    float amax_fp32 = static_cast<float>(amax[i]);
-    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, true>(amax_fp32, 0.0f);
+    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, Pow2Scales>(
+        static_cast<float>(amax[i]), 0.0f);
     scale_out[i] = __frcp_rn(scale_inv[i]);
   }
   if (threadIdx.y == 0) {
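Context for the new `Pow2Scales` flag: the old code hard-coded `true` as the third template argument of `ComputeScale`; the patch threads the choice through from the caller. The sketch below is only an assumption about what a pow-2 scale path typically does, not the real `ComputeScale` (which lives elsewhere in this file's includes); `kFP8Max` and the rounding direction are illustrative.

// Hypothetical sketch, NOT the real ComputeScale: a Pow2Scales-style flag
// usually snaps the scale to a power of two, which is exact to invert and
// cheap to store. kFP8Max (E4M3 max magnitude) is an assumed constant.
template <bool Pow2Scales>
__device__ __forceinline__ float ComputeScaleSketch(float amax, float eps) {
  constexpr float kFP8Max = 448.0f;              // assumed E4M3 limit
  float scale_inv = fmaxf(amax, eps) / kFP8Max;  // plain per-block scale
  if (Pow2Scales) {
    // Round up to the next power of two so no element overflows after
    // quantization; exp2f/ceilf/log2f keep the result an exact pow-2.
    scale_inv = exp2f(ceilf(log2f(scale_inv)));
  }
  return scale_inv;  // the kernel then takes __frcp_rn() of this value
}

Power-of-two scales trade a little precision in the scale itself for bit-exact invertibility, which matters when `scale` and `scale_inv` must round-trip without drift.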
@@ -129,7 +129,7 @@ __device__ void BlockStoreOut(OutT* out,
   }
 }
 
-template <typename OutT, int VecSize>
+template <typename OutT, bool Pow2Scales, int VecSize>
 __global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
     const phi::bfloat16* __restrict__ X,
     OutT* __restrict__ out,
@@ -149,7 +149,7 @@ __global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
 
   // Compute scale and scale_inv, then store scale back
   float scale_inv[4];
-  BlockStoreScale<OutT, VecSize>(scale, amax, scale_inv, K);
+  BlockStoreScale<OutT, Pow2Scales, VecSize>(scale, amax, scale_inv, K);
 
   // Scale X and save into shared memory with transposed layout
   for (int i = 0; i < 4; i++) {
@@ -187,7 +187,9 @@ __global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
  * 2) K <= 65535 * 128
  */
 std::vector<paddle::Tensor> fused_transpose_split_quant(
-    const paddle::Tensor& X, const std::vector<int64_t>& tokens_per_expert) {
+    const paddle::Tensor& X,
+    const std::vector<int64_t>& tokens_per_expert,
+    bool pow_2_scales) {
   PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);
 
   std::vector<int64_t> shape = X.shape();
@@ -242,21 +244,29 @@ std::vector<paddle::Tensor> fused_transpose_split_quant(
   dim3 grid(M / 128, (K + 127) / 128);
   dim3 block(32, 32);
 
-#define LAUNCH_KERNEL(VEC_SIZE)                                  \
-  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, VEC_SIZE>   \
-      <<<grid, block>>>(X.data<phi::bfloat16>(),                 \
-                        out.data<phi::float8_e4m3fn>(),          \
-                        scale.data<float>(),                     \
-                        tokens_per_expert_gpu.data<int64_t>(),   \
-                        tokens_per_expert.size(),                \
+#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                                \
+  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE> \
+      <<<grid, block>>>(X.data<phi::bfloat16>(),                             \
+                        out.data<phi::float8_e4m3fn>(),                      \
+                        scale.data<float>(),                                 \
+                        tokens_per_expert_gpu.data<int64_t>(),               \
+                        tokens_per_expert.size(),                            \
                         K);
+#define LAUNCH_KERNEL_POW_2_SCALES(VEC_SIZE) \
+  if (pow_2_scales) {                        \
+    LAUNCH_KERNEL(true, VEC_SIZE);           \
+  } else {                                   \
+    LAUNCH_KERNEL(false, VEC_SIZE);          \
+  }
+
   if (K % 4 == 0) {
-    LAUNCH_KERNEL(4);
+    LAUNCH_KERNEL_POW_2_SCALES(4);
   } else if (K % 2 == 0) {
-    LAUNCH_KERNEL(2);
+    LAUNCH_KERNEL_POW_2_SCALES(2);
   } else {
-    LAUNCH_KERNEL(1);
+    LAUNCH_KERNEL_POW_2_SCALES(1);
   }
+#undef LAUNCH_KERNEL_POW_2_SCALES
 #undef LAUNCH_KERNEL
 
   return {out, scale};
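The second macro exists purely to fold the runtime `pow_2_scales` flag into the compile-time `POW_2_SCALES` template argument, so the branch is resolved once on the host rather than per-thread inside the kernel. A macro-free sketch of the same dispatch pattern (illustrative, not what the patch uses):

#include <type_traits>

// Generic runtime-bool -> compile-time-bool dispatch (sketch). The callee
// receives std::true_type or std::false_type, so decltype(tag)::value can
// be used as a template argument.
template <typename F>
void DispatchBool(bool value, F&& f) {
  if (value) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

// Hypothetical usage, mirroring LAUNCH_KERNEL_POW_2_SCALES(VEC_SIZE):
//   DispatchBool(pow_2_scales, [&](auto tag) {
//     FusedTransposeSplitQuantKernel<phi::float8_e4m3fn,
//                                    decltype(tag)::value, VEC_SIZE>
//         <<<grid, block>>>(/* same arguments as LAUNCH_KERNEL */);
//   });

Either route compiles two kernel instantiations per vector size; the macro version keeps the launch argument list written once, and the trailing `#undef`s keep both macros file-local.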
@@ -265,5 +275,5 @@ std::vector<paddle::Tensor> fused_transpose_split_quant(
 PD_BUILD_OP(fused_transpose_split_quant)
     .Inputs({"X"})
     .Outputs({"output", "scale"})
-    .Attrs({"tokens_per_expert: std::vector<int64_t>"})
+    .Attrs({"tokens_per_expert: std::vector<int64_t>", "pow_2_scales: bool"})
     .SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
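One thing the registration implies (my reading of Paddle's custom-op convention, worth verifying against the docs): `PD_KERNEL` forwards the declared inputs first, then the attributes in the order `Attrs` lists them, which is why the host function above gained `bool pow_2_scales` as its last parameter:

// Sketch of the attr-to-parameter correspondence (names from this diff):
// PD_KERNEL passes Inputs first, then Attrs in declaration order.
std::vector<paddle::Tensor> fused_transpose_split_quant(
    const paddle::Tensor& X,                        // Inputs: "X"
    const std::vector<int64_t>& tokens_per_expert,  // Attr 1
    bool pow_2_scales);                             // Attr 2 (added here)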