
Commit 0095350

Add fused_transpose_split_quant kernel
1 parent 6c206e1 commit 0095350

3 files changed: +354 -0 lines changed

slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu

Lines changed: 301 additions & 0 deletions

@@ -0,0 +1,301 @@
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
  T val[VecSize];
  __host__ __device__ inline T& operator[](size_t i) { return val[i]; }
  __host__ __device__ inline const T& operator[](size_t i) const {
    return val[i];
  }
};

// Load a 128x128 tile of X (rows on the M axis, columns on the K axis) into
// registers; each thread of the 32x32 block holds a 4x4 fragment.
template <int VecSize>
__device__ void BlockLoad(const phi::bfloat16* X,
                          __nv_bfloat16 input[4][4],
                          size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t off_m = blockIdx.x * 128 + threadIdx.y + i * 32;
    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
    size_t offset = off_m * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using LoadT = VecType<__nv_bfloat16, VecSize>;
        LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
        for (int k = 0; k < VecSize; k++) {
          input[i][j + k] = data[k];
        }
      }
    }
  }
}

// Compute the absolute maximum of each column of the tile over its 128 rows.
__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
                               __nv_bfloat16 amax[4],
                               __nv_bfloat16* shm) {
  // Reduce [(4), 32, 32, 4] => [32, 32, 4]
  __nv_bfloat16 warp_max[4];
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j++) {
      __nv_bfloat16 t = __habs(input[i][j]);
      warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t);
    }
  }

  // Reduce [(32), 32, 4] => [32, 4]
  for (int i = 0; i < 4; i++) {
    shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
  }
  __syncthreads();
  for (int offset = 16; offset > 0; offset /= 2) {
    if (threadIdx.y < offset) {
      for (int i = 0; i < 4; i++) {
        shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
            __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
                   shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
      }
    }
    __syncthreads();
  }

  for (int i = 0; i < 4; i++) {
    amax[i] = shm[i * 32 + threadIdx.x];
  }
  __syncthreads();
}

// Compute the per-column quantization multiplier (scale_inv) from amax and
// store its reciprocal (the dequantization scale) into the per-expert scale
// tensor of shape [M_i/128, K].
template <typename OutT, bool Pow2Scales, int VecSize>
__device__ void BlockStoreScale(float* scale,
                                size_t off_m,
                                __nv_bfloat16 amax[4],
                                float scale_inv[4],
                                size_t K) {
  float scale_out[4];
  for (int i = 0; i < 4; i++) {
    scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, Pow2Scales>(
        static_cast<float>(amax[i]), 0.0f);
    scale_out[i] = __frcp_rn(scale_inv[i]);
  }
  if (threadIdx.y == 0) {
    size_t idx_m = blockIdx.x - off_m / 128;
    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
    size_t offset = idx_m * K + off_k;

    for (size_t j = 0; j < 4; j += VecSize) {
      if (off_k + j * 32 < K) {
        size_t idx = offset + j * 32;
        using StoreT = VecType<float, VecSize>;
        StoreT data;
        for (int k = 0; k < VecSize; k++) {
          data[k] = scale_out[j + k];
        }
        *reinterpret_cast<StoreT*>(scale + idx) = data;
      }
    }
  }
}

// Write the quantized, transposed 128x128 tile from shared memory to the
// per-expert output laid out as [K, M_i].
template <typename OutT, int VecSize>
__device__ void BlockStoreOut(OutT* out,
                              size_t off_m,
                              size_t cur_tokens,
                              const OutT shm[128][129],
                              size_t K) {
  for (size_t i = 0; i < 4; i++) {
    size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 32;
    size_t idx = idx_k * cur_tokens + (idx_m - off_m);

    if (idx_k < K) {
      using StoreT = VecType<OutT, VecSize>;
      StoreT data;
      for (int j = 0; j < VecSize; j++) {
        data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
      }
      *reinterpret_cast<StoreT*>(out + idx) = data;
    }
  }
}

// Find which expert the rows of this block belong to, and the row offset at
// which that expert's tokens start.
__device__ std::pair<size_t, size_t> GetExpertIdx(int64_t* tokens_per_expert,
                                                  size_t num_experts) {
  __shared__ size_t expert_idx_, off_m_;

  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t idx_m = blockIdx.x * 128;
    size_t off_m = 0, next_off_m = 0;
    size_t expert_idx;
    for (expert_idx = 0; expert_idx < num_experts; expert_idx++) {
      next_off_m += tokens_per_expert[expert_idx];
      if (idx_m >= off_m && idx_m < next_off_m) {
        break;
      }
      off_m = next_off_m;
    }
    expert_idx_ = expert_idx;
    off_m_ = off_m;
  }
  __syncthreads();

  return {expert_idx_, off_m_};
}

template <typename OutT, bool Pow2Scales, int VecSize>
__global__ void __launch_bounds__(1024)
    FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ X,
                                   int64_t* __restrict__ meta,
                                   size_t num_experts,
                                   size_t K) {
  __shared__ OutT shm[128][129];
  int64_t* tokens_per_expert = meta;
  OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
  float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);

  // Get expert_idx and offset at the M dim of the current block
  auto expert_info = GetExpertIdx(tokens_per_expert, num_experts);
  size_t expert_idx = expert_info.first;
  size_t off_m = expert_info.second;

  // Load 128x128 elements from X
  __nv_bfloat16 input[4][4];
  BlockLoad<VecSize>(X, input, K);

  // Find the maximum of each 128 elements on the M axis
  __nv_bfloat16 amax[4];
  BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm));

  // Compute scale and scale_inv, then store scale back
  float scale_inv[4];
  BlockStoreScale<OutT, Pow2Scales, VecSize>(
      scale_ptrs[expert_idx], off_m, amax, scale_inv, K);

  // Scale X and save into shared memory with transposed layout
  for (int i = 0; i < 4; i++) {
    for (int j = 0; j < 4; j += VecSize) {
      for (int k = 0; k < VecSize; k++) {
        float input_fp32 = static_cast<float>(input[i][j + k]);
        float output_scaled = input_fp32 * scale_inv[j + k];
        shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
            static_cast<OutT>(output_scaled);
      }
    }
  }
  __syncthreads();

  // Store 128x128 elements back
  // Note: out is always 4x vectorizable.
  BlockStoreOut<OutT, 4>(
      out_ptrs[expert_idx], off_m, tokens_per_expert[expert_idx], shm, K);
}

/**
 * Quantize on dim[0] of X, transpose dim[0] and dim[1] of X, then
 * split the result into out and scale.
 *
 * Inputs:
 *   X     : [SUM(M_1...M_N), K], bfloat16
 *
 * Outputs:
 *   out   : {[K, M_1], [K, M_2], ..., [K, M_N]}, float8_e4m3fn
 *   scale : {[M_1/128, K], [M_2/128, K], ..., [M_N/128, K]}, float
 *
 * Attrs:
 *   pow_2_scales
 *         : bool that indicates whether to use power-of-2 scaling
 *
 * Requirements:
 *   1) M_i % 128 == 0 for M_i in [M_1, M_2, ..., M_N]
 *   2) K <= 65535 * 128
 */
void fused_transpose_split_quant(const paddle::Tensor& X,
                                 std::vector<paddle::Tensor>& outs,
                                 std::vector<paddle::Tensor>& scales,
                                 bool pow_2_scales) {
  // Check X
  PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

  std::vector<int64_t> shape = X.shape();
  PD_CHECK(shape.size() == 2);
  const int64_t M = shape[0];
  const int64_t K = shape[1];

  // Check outs and scales
  const size_t num_experts = outs.size();
  PD_CHECK(scales.size() == num_experts);

  std::vector<int64_t> tokens_per_expert;
  int64_t sum_tokens = 0;
  for (size_t i = 0; i < num_experts; i++) {
    PD_CHECK(outs[i].dtype() == paddle::DataType::FLOAT8_E4M3FN);
    PD_CHECK(scales[i].dtype() == paddle::DataType::FLOAT32);

    std::vector<int64_t> out_shape = outs[i].shape();
    PD_CHECK(out_shape.size() == 2);
    PD_CHECK(out_shape[0] == K);
    PD_CHECK(out_shape[1] % 128 == 0);
    tokens_per_expert.push_back(out_shape[1]);
    sum_tokens += out_shape[1];

    std::vector<int64_t> scale_shape = scales[i].shape();
    PD_CHECK(scale_shape.size() == 2);
    PD_CHECK(scale_shape[0] == out_shape[1] / 128);
    PD_CHECK(scale_shape[1] == K);
  }

  PD_CHECK(sum_tokens == M,
           "sum of out[i].shape[1] must be equal to X.shape[0]");
  PD_CHECK(K <= 65535 * 128, "only supports K <= 65535 * 128");

  // Skip 0-size
  if (M == 0 || K == 0) {
    return;
  }

  // Copy meta (tokens_per_expert, out_ptrs, scale_ptrs) to device
  paddle::Tensor meta_cpu = paddle::empty(
      {static_cast<int64_t>(num_experts * 3)}, paddle::DataType::INT64);
  int64_t* meta_ptr = meta_cpu.data<int64_t>();
  for (size_t i = 0; i < num_experts; i++) {
    meta_ptr[i] = static_cast<int64_t>(tokens_per_expert[i]);
  }
  for (size_t i = 0; i < num_experts; i++) {
    meta_ptr[num_experts + i] =
        reinterpret_cast<int64_t>(outs[i].data<phi::float8_e4m3fn>());
  }
  for (size_t i = 0; i < num_experts; i++) {
    meta_ptr[num_experts * 2 + i] =
        reinterpret_cast<int64_t>(scales[i].data<float>());
  }
  paddle::Tensor meta_gpu = meta_cpu.copy_to(X.place(), /*blocking=*/false);

  // Launch kernel
  dim3 grid(M / 128, (K + 127) / 128);
  dim3 block(32, 32);

#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                                \
  FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE> \
      <<<grid, block>>>(                                                     \
          X.data<phi::bfloat16>(), meta_gpu.data<int64_t>(), num_experts, K);
#define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \
  if (pow_2_scales) {                   \
    LAUNCH_KERNEL(true, VEC_SIZE);      \
  } else {                              \
    LAUNCH_KERNEL(false, VEC_SIZE);     \
  }

  if (K % 4 == 0) {
    LAUNCH_KERNEL_PARTIAL(4);
  } else if (K % 2 == 0) {
    LAUNCH_KERNEL_PARTIAL(2);
  } else {
    LAUNCH_KERNEL_PARTIAL(1);
  }
#undef LAUNCH_KERNEL_PARTIAL
#undef LAUNCH_KERNEL
}

PD_BUILD_OP(fused_transpose_split_quant)
    .Inputs({"X", paddle::Vec("outs"), paddle::Vec("scales")})
    .Attrs({"pow_2_scales: bool"})
    .SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
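
For intuition, here is a rough Python sketch of the per-expert numerics the kernel implements. It is a sketch only: it assumes ComputeScale (defined in quant_utils.h, which is not part of this diff) produces the FP8 E4M3 quantization multiplier of roughly 448 / amax, rounded down to a power of two when pow_2_scales is set, and it glosses over the amax == 0 case with a clip.

import paddle


def reference_transpose_split_quant(x, tokens_per_expert, pow_2_scales):
    # Reference numerics only; values are kept in float32 where the kernel
    # casts to float8_e4m3fn. FP8_MAX = 448 is an assumption about the target
    # range used by ComputeScale.
    FP8_MAX = 448.0
    _, k = x.shape
    outs, scales = [], []
    offset = 0
    for tokens in tokens_per_expert:
        xi = x[offset:offset + tokens].astype('float32')  # [M_i, K]
        offset += tokens
        # Column-wise amax over each 128-row block along M.
        amax = xi.reshape([tokens // 128, 128, k]).abs().max(axis=1)
        # Multiplier applied to x (clip stands in for ComputeScale's
        # zero-amax handling, which may differ).
        quant = FP8_MAX * paddle.reciprocal(paddle.clip(amax, min=1e-12))
        if pow_2_scales:
            quant = paddle.pow(paddle.full_like(quant, 2.0),
                               paddle.floor(paddle.log2(quant)))
        q = xi * paddle.repeat_interleave(quant, repeats=128, axis=0)
        outs.append(paddle.transpose(q, [1, 0]))      # [K, M_i]
        scales.append(paddle.reciprocal(quant))       # [M_i/128, K]
    return outs, scales

The stored scale is the reciprocal of the multiplier applied to X, which is why the test below reconstructs X simply as out * scale.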

slm/model_zoo/gpt-3/external_ops/setup_fp8.py

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ def setup_fused_quant_ops():
         "fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
         "fused_quanted_ops/fused_spaq.cu",
         "fused_quanted_ops/fused_stack_transpose_quant.cu",
+        "fused_quanted_ops/fused_transpose_split_quant.cu",
     ],
     extra_compile_args={
         "cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
import FusedQuantOps as FQO
import numpy as np

import paddle


def restore_transpose_split_quant(out, scale):
    out = [t.astype('float32') for t in out]
    out = paddle.concat(out, axis=1).transpose([1, 0])
    scale = paddle.concat(scale, axis=0)
    scale = paddle.repeat_interleave(scale, repeats=128, axis=0)
    return out * scale


def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales):
    print(tokens_per_expert, seq_len, pow_2_scales)

    x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')
    x = paddle.clip(x, min=-50, max=50)

    out, scale = [], []
    for tokens in tokens_per_expert:
        out.append(paddle.empty([seq_len, tokens], dtype='float8_e4m3fn'))
        scale.append(paddle.empty([tokens // 128, seq_len], dtype='float32'))

    FQO.fused_transpose_split_quant(x, out, scale, pow_2_scales)

    x_restore = restore_transpose_split_quant(out, scale)
    x_cast = x.astype('float32')

    np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3)


def run():
    test_fused_transpose_split_quant([0, 0], 1024, False)
    test_fused_transpose_split_quant([128, 2*128], 0, True)
    test_fused_transpose_split_quant([128], 1, False)
    test_fused_transpose_split_quant([0, 128, 0, 2*128], 127, True)
    test_fused_transpose_split_quant([3*128, 4*128, 5*128], 233, False)
    test_fused_transpose_split_quant(
        [24*128, 128, 50*128, 16*128], 2162, True
    )
    test_fused_transpose_split_quant(
        [7*128, 29*128, 3*128, 128*128, 13*128], 4000, False
    )
    test_fused_transpose_split_quant(
        [18*128, 5*128, 24*128, 128, 6*128, 0, 27*128, 7*128], 7168, True
    )


if __name__ == '__main__':
    run()
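
The round-trip tolerance check above exercises overall accuracy. A complementary check, not part of this commit, could assert that the stored scales are exact powers of two when pow_2_scales=True; a sketch, assuming all block amax values are nonzero so every scale is finite and positive:

def check_pow_2_scales(scale):
    # Hypothetical extra check: every stored scale should be an exact power
    # of two when pow_2_scales=True. The exact rounding rule lives in
    # ComputeScale, which is not shown in this diff.
    for s in scale:
        log2_s = paddle.log2(s)
        np.testing.assert_allclose(log2_s, paddle.round(log2_s), rtol=0, atol=0)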
