Commit 13d4924

Add pow_2_scales param
1 parent 452bb11 commit 13d4924

File tree

2 files changed: +67 -44 lines changed


slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu

Lines changed: 27 additions & 17 deletions
@@ -64,15 +64,15 @@ __device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
   }
 }
 
-template <typename OutT, int VecSize>
+template <typename OutT, bool Pow2Scales, int VecSize>
 __device__ void BlockStoreScale(float* scale,
                                 __nv_bfloat16 amax[4],
                                 float scale_inv[4],
                                 size_t K) {
   float scale_out[4];
   for (int i = 0; i < 4; i++) {
-    float amax_fp32 = static_cast<float>(amax[i]);
-    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, true>(amax_fp32, 0.0f);
+    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, Pow2Scales>(
+        static_cast<float>(amax[i]), 0.0f);
     scale_out[i] = __frcp_rn(scale_inv[i]);
   }
   if (threadIdx.y == 0) {
@@ -129,7 +129,7 @@ __device__ void BlockStoreOut(OutT* out,
   }
 }
 
-template <typename OutT, int VecSize>
+template <typename OutT, bool Pow2Scales, int VecSize>
 __global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
     const phi::bfloat16* __restrict__ X,
     OutT* __restrict__ out,
@@ -149,7 +149,7 @@ __global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
 
   // Compute scale and scale_inv, then store scale back
   float scale_inv[4];
-  BlockStoreScale<OutT, VecSize>(scale, amax, scale_inv, K);
+  BlockStoreScale<OutT, Pow2Scales, VecSize>(scale, amax, scale_inv, K);
 
   // Scale X and save into shared memory with transposed layout
   for (int i = 0; i < 4; i++) {
@@ -187,7 +187,9 @@ __global__ void __launch_bounds__(1024, 2) FusedTransposeSplitQuantKernel(
  * 2) K <= 65535 * 128
  */
 std::vector<paddle::Tensor> fused_transpose_split_quant(
-    const paddle::Tensor& X, const std::vector<int64_t>& tokens_per_expert) {
+    const paddle::Tensor& X,
+    const std::vector<int64_t>& tokens_per_expert,
+    bool pow_2_scales) {
   PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);
 
   std::vector<int64_t> shape = X.shape();
@@ -242,21 +244,29 @@ std::vector<paddle::Tensor> fused_transpose_split_quant(
   dim3 grid(M / 128, (K + 127) / 128);
   dim3 block(32, 32);
 
-#define LAUNCH_KERNEL(VEC_SIZE)                                     \
-  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, VEC_SIZE>      \
-      <<<grid, block>>>(X.data<phi::bfloat16>(),                    \
-                        out.data<phi::float8_e4m3fn>(),             \
-                        scale.data<float>(),                        \
-                        tokens_per_expert_gpu.data<int64_t>(),      \
-                        tokens_per_expert.size(),                   \
+#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                                 \
+  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE>  \
+      <<<grid, block>>>(X.data<phi::bfloat16>(),                              \
+                        out.data<phi::float8_e4m3fn>(),                       \
+                        scale.data<float>(),                                  \
+                        tokens_per_expert_gpu.data<int64_t>(),                \
+                        tokens_per_expert.size(),                             \
                         K);
+#define LAUNCH_KERNEL_POW_2_SCALES(VEC_SIZE) \
+  if (pow_2_scales) {                        \
+    LAUNCH_KERNEL(true, VEC_SIZE);           \
+  } else {                                   \
+    LAUNCH_KERNEL(false, VEC_SIZE);          \
+  }
+
   if (K % 4 == 0) {
-    LAUNCH_KERNEL(4);
+    LAUNCH_KERNEL_POW_2_SCALES(4);
   } else if (K % 2 == 0) {
-    LAUNCH_KERNEL(2);
+    LAUNCH_KERNEL_POW_2_SCALES(2);
   } else {
-    LAUNCH_KERNEL(1);
+    LAUNCH_KERNEL_POW_2_SCALES(1);
  }
+#undef LAUNCH_KERNEL_POW_2_SCALES
 #undef LAUNCH_KERNEL
 
   return {out, scale};
@@ -265,5 +275,5 @@ std::vector<paddle::Tensor> fused_transpose_split_quant(
 PD_BUILD_OP(fused_transpose_split_quant)
     .Inputs({"X"})
     .Outputs({"output", "scale"})
-    .Attrs({"tokens_per_expert: std::vector<int64_t>"})
+    .Attrs({"tokens_per_expert: std::vector<int64_t>", "pow_2_scales: bool"})
     .SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
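
Note on the new flag: with Pow2Scales enabled, ComputeScale presumably snaps each 128-row block's quantization scale to a power of two instead of using the raw amax-derived value. ComputeScale's body is not part of this diff, so the Python sketch below only illustrates that presumed behavior; the FP8 E4M3 max of 448 and the round-down choice are assumptions, not the kernel's actual code.

import math

FP8_E4M3_MAX = 448.0  # assumed max finite magnitude of float8_e4m3fn

def compute_scale_inv(amax: float, pow_2_scales: bool) -> float:
    # Multiplier applied to x before the FP8 cast; the kernel stores its
    # reciprocal (scale_out = __frcp_rn(scale_inv)) as the dequant scale.
    if amax == 0.0:
        return 1.0
    scale_inv = FP8_E4M3_MAX / amax
    if pow_2_scales:
        # Assumed: round down to a power of two, so quantized values cannot
        # overflow and the stored dequant scale is itself a power of two.
        scale_inv = 2.0 ** math.floor(math.log2(scale_inv))
    return scale_inv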

tests/ops/test_fused_transpose_split_quant.py

Lines changed: 40 additions & 27 deletions
@@ -12,37 +12,50 @@ def restore_transpose_split_quant(out, scale):
     return out * scale
 
 
-def run():
-    tokens_per_expert = [24*128, 50*128, 1*128, 128*128, 13*128]
+def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales):
+    print(tokens_per_expert, seq_len)
+
+    x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')
+    x = paddle.clip(x, min=-50, max=50)
+
+    out_raw, scale_raw = FQO.fused_transpose_split_quant(
+        x, tokens_per_expert, pow_2_scales
+    )
+
+    out, scale = [], []
+    token_offset = 0
+    for tokens in tokens_per_expert:
+        out_offset = seq_len * token_offset
+        out_size = seq_len * tokens
+        out.append(
+            out_raw[out_offset : out_offset + out_size]
+            .reshape([seq_len, tokens])
+        )
+        scale.append(
+            scale_raw[token_offset // 128 : (token_offset + tokens) // 128]
+        )
+        token_offset += tokens
 
-    for seq_len in [1, 127, 2562, 4001, 7168]:
-        print(tokens_per_expert, seq_len)
+    x_restore = restore_transpose_split_quant(out, scale)
+    x_cast = x.astype('float32')
 
-        x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')
-        x = paddle.clip(x, min=-50, max=50)
+    np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3)
 
-        out_raw, scale_raw = FQO.fused_transpose_split_quant(
-            x, tokens_per_expert
-        )
 
-        out, scale = [], []
-        token_offset = 0
-        for tokens in tokens_per_expert:
-            out_offset = seq_len * token_offset
-            out_size = seq_len * tokens
-            out.append(
-                out_raw[out_offset : out_offset + out_size]
-                .reshape([seq_len, tokens])
-            )
-            scale.append(
-                scale_raw[token_offset // 128 : (token_offset + tokens) // 128]
-            )
-            token_offset += tokens
-
-        x_restore = restore_transpose_split_quant(out, scale)
-        x_cast = x.astype('float32')
-
-        np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3)
+def run():
+    test_fused_transpose_split_quant([128], 1, False)
+    test_fused_transpose_split_quant([3*128, 4*128, 5*128], 127, True)
+    test_fused_transpose_split_quant(
+        [24*128, 128, 50*128, 16*128], 2162, False
+    )
+    test_fused_transpose_split_quant(
+        [7*128, 29*128, 3*128, 128*128, 13*128], 4000, True
+    )
+    test_fused_transpose_split_quant(
+        [18*128, 5*128, 24*128, 1*128, 6*128, 14*128, 27*128, 7*128],
+        7168,
+        False
+    )
 
 
 if __name__ == '__main__':
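
For reference, a minimal call against the updated op signature; the module import path is an assumption (the test only shows the FQO alias, and the extension lives under external_ops/fused_quanted_ops), so adjust it to your build.

import paddle
import fused_quanted_ops as FQO  # assumed import name, mirroring the test's FQO alias

tokens_per_expert = [2 * 128, 3 * 128]  # per-expert token counts, multiples of 128
x = paddle.randn([sum(tokens_per_expert), 256], dtype='bfloat16')

# The new trailing bool maps to the "pow_2_scales: bool" attribute registered
# via PD_BUILD_OP; True requests power-of-two scales.
out, scale = FQO.fused_transpose_split_quant(x, tokens_per_expert, True)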
