Add fused_transpose_split_quant kernel #10657

Merged
301 changes: 301 additions & 0 deletions slm/model_zoo/gpt-3/external_ops/fused_quanted_ops/fused_transpose_split_quant.cu
@@ -0,0 +1,301 @@
#include "quant_utils.h"

template <typename T, int VecSize>
struct __align__(sizeof(T) * VecSize) VecType {
T val[VecSize];
__host__ __device__ inline T& operator[](size_t i) { return val[i]; }
__host__ __device__ inline const T& operator[](size_t i) const {
return val[i];
}
};

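// Load a 128x128 bf16 tile of X into registers: each thread of the 32x32 block
// holds a 4x4 sub-tile, with loads vectorized by VecSize along the K dim.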
template <int VecSize>
__device__ void BlockLoad(const phi::bfloat16* X,
__nv_bfloat16 input[4][4],
size_t K) {
for (size_t i = 0; i < 4; i++) {
size_t off_m = blockIdx.x * 128 + threadIdx.y + i * 32;
size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
size_t offset = off_m * K + off_k;

for (size_t j = 0; j < 4; j += VecSize) {
if (off_k + j * 32 < K) {
size_t idx = offset + j * 32;
using LoadT = VecType<__nv_bfloat16, VecSize>;
LoadT data = *reinterpret_cast<const LoadT*>(X + idx);
for (int k = 0; k < VecSize; k++) {
input[i][j + k] = data[k];
}
}
}
}
}

__device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
__nv_bfloat16 amax[4],
__nv_bfloat16* shm) {
// Reduce [(4), 32, 32, 4] => [32, 32, 4]
__nv_bfloat16 warp_max[4];
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
__nv_bfloat16 t = __habs(input[i][j]);
warp_max[j] = i == 0 ? t : __hmax(warp_max[j], t);
}
}

// Reduce [(32), 32, 4] => [32, 4]
for (int i = 0; i < 4; i++) {
shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
}
__syncthreads();
for (int offset = 16; offset > 0; offset /= 2) {
if (threadIdx.y < offset) {
for (int i = 0; i < 4; i++) {
shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
__hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
}
}
__syncthreads();
}

for (int i = 0; i < 4; i++) {
amax[i] = shm[i * 32 + threadIdx.x];
}
__syncthreads();
}

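// Compute per-column scales from the column-wise amax: scale_inv is the
// multiplier used to quantize, and its reciprocal is written to the expert's
// scale tensor for dequantization.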
template <typename OutT, bool Pow2Scales, int VecSize>
__device__ void BlockStoreScale(float* scale,
size_t off_m,
__nv_bfloat16 amax[4],
float scale_inv[4],
size_t K) {
float scale_out[4];
for (int i = 0; i < 4; i++) {
scale_inv[i] = ComputeScale<__nv_bfloat16, OutT, Pow2Scales>(
static_cast<float>(amax[i]), 0.0f);
scale_out[i] = __frcp_rn(scale_inv[i]);
}
if (threadIdx.y == 0) {
// off_m is a multiple of 128, so this is the block-row index within the expert's scale tensor
size_t idx_m = blockIdx.x - off_m / 128;
size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
size_t offset = idx_m * K + off_k;

for (size_t j = 0; j < 4; j += VecSize) {
if (off_k + j * 32 < K) {
size_t idx = offset + j * 32;
using StoreT = VecType<float, VecSize>;
StoreT data;
for (int k = 0; k < VecSize; k++) {
data[k] = scale_out[j + k];
}
*reinterpret_cast<StoreT*>(scale + idx) = data;
}
}
}
}

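// Write the quantized, transposed 128x128 tile from shared memory into the
// expert's output tensor, which is laid out as [K, M_i].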
template <typename OutT, int VecSize>
__device__ void BlockStoreOut(OutT* out,
size_t off_m,
size_t cur_tokens,
const OutT shm[128][129],
size_t K) {
for (size_t i = 0; i < 4; i++) {
size_t idx_m = blockIdx.x * 128 + threadIdx.x * 4;
size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 32;
size_t idx = idx_k * cur_tokens + (idx_m - off_m);

if (idx_k < K) {
using StoreT = VecType<OutT, VecSize>;
StoreT data;
for (int j = 0; j < VecSize; j++) {
data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
}
*reinterpret_cast<StoreT*>(out + idx) = data;
}
}
}

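// Walk the running sum of tokens_per_expert to find which expert owns the
// current 128-row block and the row offset (off_m) at which that expert starts.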
__device__ std::pair<size_t, size_t> GetExpertIdx(int64_t* tokens_per_expert,
size_t num_experts) {
__shared__ size_t expert_idx_, off_m_;

if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t idx_m = blockIdx.x * 128;
size_t off_m = 0, next_off_m = 0;
size_t expert_idx;
for (expert_idx = 0; expert_idx < num_experts; expert_idx++) {
next_off_m += tokens_per_expert[expert_idx];
if (idx_m >= off_m && idx_m < next_off_m) {
break;
}
off_m = next_off_m;
}
expert_idx_ = expert_idx;
off_m_ = off_m;
}
__syncthreads();

return {expert_idx_, off_m_};
}

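// Each 32x32-thread block handles one 128x128 tile of X: load the tile,
// compute column-wise amax, quantize, transpose through shared memory, and
// store into the owning expert's out/scale buffers.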
template <typename OutT, bool Pow2Scales, int VecSize>
__global__ void __launch_bounds__(1024)
FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ X,
int64_t* __restrict__ meta,
size_t num_experts,
size_t K) {
__shared__ OutT shm[128][129];
int64_t* tokens_per_expert = meta;
OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);

// Get the expert index and the offset along the M dim for the current block
auto expert_info = GetExpertIdx(tokens_per_expert, num_experts);
size_t expert_idx = expert_info.first;
size_t off_m = expert_info.second;

// Load 128x128 elements from X
__nv_bfloat16 input[4][4];
BlockLoad<VecSize>(X, input, K);

// Find the maximum absolute value over each group of 128 elements along the M axis
__nv_bfloat16 amax[4];
BlockColumnMax(input, amax, reinterpret_cast<__nv_bfloat16*>(shm));

// Compute scale and scale_inv, then store scale back
float scale_inv[4];
BlockStoreScale<OutT, Pow2Scales, VecSize>(
scale_ptrs[expert_idx], off_m, amax, scale_inv, K);

// Scale X and save into shared memory with transposed layout
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j += VecSize) {
for (int k = 0; k < VecSize; k++) {
float input_fp32 = static_cast<float>(input[i][j + k]);
float output_scaled = input_fp32 * scale_inv[j + k];
shm[threadIdx.x * VecSize + j * 32 + k][i * 32 + threadIdx.y] =
static_cast<OutT>(output_scaled);
}
}
}
__syncthreads();

// Store 128x128 elements back
// Note: the out store is always 4x vectorizable since every M_i is a multiple of 128.
BlockStoreOut<OutT, 4>(
out_ptrs[expert_idx], off_m, tokens_per_expert[expert_idx], shm, K);
}

/**
* Quantize X along dim[0] in groups of 128 elements, transpose dim[0] and
* dim[1] of X, then split the result by expert into out and scale.
*
* Inputs:
* X : [SUM(M_1...M_N), K], bfloat16
*
* Outputs:
* out : {[K, M_1], [K, M_2], ..., [K, M_N]}, float8_e4m3fn
* scale : {[M_1/128, K], [M_2/128, K], ..., [M_N/128, K]}, float
*
* Attrs:
* pow_2_scales
* : bool that indicates whether to use power-of-2 scaling
*
* Requirements:
* 1) M_i % 128 == 0 for M_i in [M_1, M_2, ..., M_N]
* 2) K <= 65535 * 128
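*
* Example (shapes chosen purely for illustration):
*   N = 2, M_1 = 128, M_2 = 256, K = 512
*   X     : [384, 512], bfloat16
*   out   : {[512, 128], [512, 256]}, float8_e4m3fn
*   scale : {[1, 512], [2, 512]}, float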
*/
void fused_transpose_split_quant(const paddle::Tensor& X,
std::vector<paddle::Tensor>& outs,
std::vector<paddle::Tensor>& scales,
bool pow_2_scales) {
// Check X
PD_CHECK(X.dtype() == paddle::DataType::BFLOAT16);

std::vector<int64_t> shape = X.shape();
PD_CHECK(shape.size() == 2);
const int64_t M = shape[0];
const int64_t K = shape[1];

// Check outs and scales
const size_t num_experts = outs.size();
PD_CHECK(scales.size() == num_experts);

std::vector<int64_t> tokens_per_expert;
int64_t sum_tokens = 0;
for (size_t i = 0; i < num_experts; i++) {
PD_CHECK(outs[i].dtype() == paddle::DataType::FLOAT8_E4M3FN);
PD_CHECK(scales[i].dtype() == paddle::DataType::FLOAT32);

std::vector<int64_t> out_shape = outs[i].shape();
PD_CHECK(out_shape.size() == 2);
PD_CHECK(out_shape[0] == K);
PD_CHECK(out_shape[1] % 128 == 0);
tokens_per_expert.push_back(out_shape[1]);
sum_tokens += out_shape[1];

std::vector<int64_t> scale_shape = scales[i].shape();
PD_CHECK(scale_shape.size() == 2);
PD_CHECK(scale_shape[0] == out_shape[1] / 128);
PD_CHECK(scale_shape[1] == K);
}

PD_CHECK(sum_tokens == M,
"sum of out[i].shape[1] must be equal to X.shape[0]");
PD_CHECK(K <= 65535 * 128, "only supports K <= 65535 * 128");

// Skip 0-size
if (M == 0 || K == 0) {
return;
}

// Copy meta (tokens_per_expert, out_ptrs, scale_ptrs) to device
paddle::Tensor meta_cpu = paddle::empty(
{static_cast<int64_t>(num_experts * 3)}, paddle::DataType::INT64);
int64_t* meta_ptr = meta_cpu.data<int64_t>();
for (size_t i = 0; i < num_experts; i++) {
meta_ptr[i] = static_cast<int64_t>(tokens_per_expert[i]);
}
for (size_t i = 0; i < num_experts; i++) {
meta_ptr[num_experts + i] =
reinterpret_cast<int64_t>(outs[i].data<phi::float8_e4m3fn>());
}
for (size_t i = 0; i < num_experts; i++) {
meta_ptr[num_experts * 2 + i] =
reinterpret_cast<int64_t>(scales[i].data<float>());
}
paddle::Tensor meta_gpu = meta_cpu.copy_to(X.place(), /*blocking=*/false);

// Launch kernel
dim3 grid(M / 128, (K + 127) / 128);
dim3 block(32, 32);

#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE) \
FusedTransposeSplitQuantKernel<phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE> \
<<<grid, block>>>( \
X.data<phi::bfloat16>(), meta_gpu.data<int64_t>(), num_experts, K);
#define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \
if (pow_2_scales) { \
LAUNCH_KERNEL(true, VEC_SIZE); \
} else { \
LAUNCH_KERNEL(false, VEC_SIZE); \
}

if (K % 4 == 0) {
LAUNCH_KERNEL_PARTIAL(4);
} else if (K % 2 == 0) {
LAUNCH_KERNEL_PARTIAL(2);
} else {
LAUNCH_KERNEL_PARTIAL(1);
}
#undef LAUNCH_KERNEL_PARTIAL
#undef LAUNCH_KERNEL
}

PD_BUILD_OP(fused_transpose_split_quant)
.Inputs({"X", paddle::Vec("outs"), paddle::Vec("scales")})
.Attrs({"pow_2_scales: bool"})
.SetKernelFn(PD_KERNEL(fused_transpose_split_quant));
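For reference, a minimal sketch of driving the op from Python, mirroring the unit test below (it assumes the extension has been built via setup_fp8.py and is importable as FusedQuantOps):

import paddle
import FusedQuantOps as FQO  # extension module built by setup_fp8.py

tokens_per_expert = [128, 256]  # every M_i must be a multiple of 128
K = 512
x = paddle.randn([sum(tokens_per_expert), K], dtype='bfloat16')

# Pre-allocate per-expert outputs: out_i is [K, M_i], scale_i is [M_i/128, K].
outs = [paddle.empty([K, m], dtype='float8_e4m3fn') for m in tokens_per_expert]
scales = [paddle.empty([m // 128, K], dtype='float32') for m in tokens_per_expert]

FQO.fused_transpose_split_quant(x, outs, scales, False)  # pow_2_scales=False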
1 change: 1 addition & 0 deletions slm/model_zoo/gpt-3/external_ops/setup_fp8.py
@@ -42,6 +42,7 @@ def setup_fused_quant_ops():
"fused_quanted_ops/fused_act_dequant_transpose_act_quant.cu",
"fused_quanted_ops/fused_spaq.cu",
"fused_quanted_ops/fused_stack_transpose_quant.cu",
"fused_quanted_ops/fused_transpose_split_quant.cu",
],
extra_compile_args={
"cxx": ["-O3", "-w", "-Wno-abi", "-fPIC", "-std=c++17"],
52 changes: 52 additions & 0 deletions tests/ops/test_fused_transpose_split_quant.py
@@ -0,0 +1,52 @@
import FusedQuantOps as FQO
import numpy as np

import paddle


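# Dequantize the per-expert outputs and undo the transpose/split so the result
# can be compared against the original bf16 input.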
def restore_transpose_split_quant(out, scale):
out = [t.astype('float32') for t in out]
out = paddle.concat(out, axis=1).transpose([1, 0])
scale = paddle.concat(scale, axis=0)
scale = paddle.repeat_interleave(scale, repeats=128, axis=0)
return out * scale


def test_fused_transpose_split_quant(tokens_per_expert, seq_len, pow_2_scales):
print(tokens_per_expert, seq_len, pow_2_scales)

x = paddle.randn([sum(tokens_per_expert), seq_len], dtype='bfloat16')
x = paddle.clip(x, min=-50, max=50)

out, scale = [], []
for tokens in tokens_per_expert:
out.append(paddle.empty([seq_len, tokens], dtype='float8_e4m3fn'))
scale.append(paddle.empty([tokens//128, seq_len], dtype='float32'))

FQO.fused_transpose_split_quant(x, out, scale, pow_2_scales)

x_restore = restore_transpose_split_quant(out, scale)
x_cast = x.astype('float32')

np.testing.assert_allclose(x_cast, x_restore, rtol=0.01, atol=0.3)


def run():
test_fused_transpose_split_quant([0, 0], 1024, False)
test_fused_transpose_split_quant([128, 2*128], 0, True)
test_fused_transpose_split_quant([128], 1, False)
test_fused_transpose_split_quant([0, 128, 0, 2*128], 127, True)
test_fused_transpose_split_quant([3*128, 4*128, 5*128], 233, False)
test_fused_transpose_split_quant(
[24*128, 128, 50*128, 16*128], 2162, True
)
test_fused_transpose_split_quant(
[7*128, 29*128, 3*128, 128*128, 13*128], 4000, False
)
test_fused_transpose_split_quant(
[18*128, 5*128, 24*128, 128, 6*128, 0, 27*128, 7*128], 7168, True
)


if __name__ == '__main__':
run()