@@ -25,22 +25,22 @@ __global__ void FusedActDequant(const phi::float8_e4m3fn *__restrict__ Xin,
                                 const float *__restrict__ Xscale,
                                 phi::bfloat16 *__restrict__ out,
                                 const int64_t rows,
-                                const int64_t cols) {
+                                const int cols) {
   const int64_t this_row_idx = blockIdx.x;
   if (this_row_idx >= rows) return;
 
-  const int64_t Xscale_stride = (cols + 127) / 128;  // stride of the scale factors
+  const int Xscale_stride = (cols + 127) / 128;  // stride of the scale factors
 
   const int vector_size = 16;  // elements per vector: process 16 at a time
 
   // number of vectors per row
-  const int64_t num_vectors = cols / vector_size;
+  const int num_vectors = cols / vector_size;
   const int remaining_elements = cols % vector_size;
 
-  const int64_t tid = threadIdx.x;
+  const int tid = threadIdx.x;
 
-  for (int64_t vec_idx = tid; vec_idx < num_vectors; vec_idx += blockDim.x) {
-    int64_t x_offset = vec_idx * vector_size;
+  for (int vec_idx = tid; vec_idx < num_vectors; vec_idx += blockDim.x) {
+    int x_offset = vec_idx * vector_size;
     int64_t X_idx = (int64_t)this_row_idx * (int64_t)cols + (int64_t)x_offset;
 
     // load 16 __nv_fp8_e4m3 elements into a vector
@@ -76,7 +76,7 @@ __global__ void FusedActDequant(const phi::float8_e4m3fn *__restrict__ Xin,
 
   // handle the remaining elements that cannot be vectorized
   if (remaining_elements > 0) {
-    int64_t x_offset = num_vectors * vector_size;
+    int x_offset = num_vectors * vector_size;
     int64_t X_idx = (int64_t)this_row_idx * (int64_t)cols + (int64_t)x_offset;
     int64_t idx = X_idx + tid;
    if (tid < remaining_elements) {
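
For orientation, the two hunks above are the vectorized body and scalar tail of the same dequant loop: each row of FP8 values is scaled by its per-128-column factor from Xscale and written out as bfloat16. A minimal scalar sketch of that computation, assuming the [rows, ceil(cols/128)] Xscale layout implied by Xscale_stride (illustrative only, not the vectorized kernel):

// Scalar FP8(e4m3) -> BF16 dequant sketch; needs <cuda_fp8.h> and <cuda_bf16.h>.
__global__ void DequantRowScalar(const __nv_fp8_e4m3 *__restrict__ x,
                                 const float *__restrict__ x_scale,
                                 __nv_bfloat16 *__restrict__ out,
                                 const int64_t rows, const int cols,
                                 const int scale_stride) {
  const int64_t row = blockIdx.x;
  if (row >= rows) return;
  for (int col = threadIdx.x; col < cols; col += blockDim.x) {
    const int64_t idx = row * cols + col;                      // 64-bit flat index
    const float v = static_cast<float>(x[idx]);                // FP8 -> float
    const float s = x_scale[row * scale_stride + col / 128];   // one scale per 128 cols
    out[idx] = __float2bfloat16(v * s);
  }
}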
@@ -110,6 +110,16 @@ std::vector<paddle::Tensor> fused_act_dequant(const paddle::Tensor &X,
   int64_t rows, cols;
   rows = X.shape()[0];
   cols = X.shape()[1];
+  PADDLE_ENFORCE_LE(
+      rows,
+      std::numeric_limits<int32_t>::max(),
+      common::errors::InvalidArgument(
+          "rows should be less than INT_MAX, received rows: (%ld)", rows));
+  PADDLE_ENFORCE_LE(
+      cols,
+      std::numeric_limits<int32_t>::max(),
+      common::errors::InvalidArgument(
+          "cols should be less than INT_MAX, received cols: (%ld)", cols));
   paddle::Tensor out;
 
   out = paddle::empty({rows, cols}, paddle::DataType::BFLOAT16, X.place());
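
The point of the new PADDLE_ENFORCE_LE checks: once the kernel takes cols as int, a 64-bit extent reaching the launch site would be silently truncated, so the host must reject it up front. A standalone sketch of the failure mode being guarded against (hypothetical values, plain C++ rather than Paddle's macros):

#include <cstdint>
#include <cstdio>
#include <limits>

void kernel_like(int cols) { std::printf("kernel sees cols = %d\n", cols); }

int main() {
  int64_t cols = 3'000'000'000LL;  // > INT_MAX; plausible only for huge tensors
  if (cols > std::numeric_limits<int32_t>::max()) {
    std::printf("rejected: cols (%lld) exceeds INT_MAX\n", (long long)cols);
    return 1;  // what PADDLE_ENFORCE_LE raises as InvalidArgument
  }
  kernel_like(static_cast<int>(cols));  // without the guard: silent truncation
}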
@@ -270,7 +270,7 @@ __global__ void FusedSPAQKernel(const phi::bfloat16 *__restrict__ Xin,
 
   // Phase 3: Compute scales and quantize the outputs
   const float block_max_float = (float)quant_block_amax[quant_block_idx];
-  const int64_t scale_stride = (cols / 2 + 127) / 128;
+  const int scale_stride = (cols / 2 + 127) / 128;
 
   float scale = ComputeScale<float, __nv_fp8_e4m3, using_pow2_scaling>(
       block_max_float, 0.0f);
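
scale_stride here follows the same ceil-division idiom as Xscale_stride above: blockwise FP8 quantization stores one scale per 128 contiguous elements, so a row of n elements carries ceil(n/128) scales. The idiom in isolation (a sketch; kQuantBlock names the 128-wide block size assumed throughout):

constexpr int kQuantBlock = 128;
// ceil(n / kQuantBlock) without floating point: (n + 127) / 128.
inline int NumScalesPerRow(int n) { return (n + kQuantBlock - 1) / kQuantBlock; }
// e.g. n = 300 -> 3 scales; the scale for column c is row_scales[c / kQuantBlock].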
@@ -197,7 +197,7 @@ __global__ void __launch_bounds__(1024)
       for (int j = 0; j < 4; j++) {
         float input_fp32 = static_cast<float>(input[i][j]);
         float output_scaled = input_fp32 * scale_inv;
-        shm[static_cast<size_t>(threadIdx.x) * 4 + j][i * 32 + threadIdx.y] =
+        shm[threadIdx.x * 4 + j][i * 32 + threadIdx.y] =
             static_cast<OutT>(output_scaled);
       }
     }
@@ -207,14 +207,13 @@ __global__ void __launch_bounds__(1024)
   for (size_t i = 0; i < 4; i++) {
     size_t idx_n = blockIdx.z;
     size_t idx_k = block_x * 128 + threadIdx.y + i * 32;
-    size_t idx_m = block_y * 128 + static_cast<size_t>(threadIdx.x) * 4;
+    size_t idx_m = block_y * 128 + threadIdx.x * 4;
     size_t idx = (idx_n * K + idx_k) * M + idx_m;
 
     using StoreT = VecType<OutT, 4>;
     StoreT data;
     for (int j = 0; j < 4; j++) {
-      data[j] =
-          shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
+      data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
     }
     *reinterpret_cast<StoreT*>(out + idx) = data;
   }
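
The two hunks in this file are the write and read halves of a 128x128 shared-memory transpose: threads store scaled values at shm[x][y] and read them back at shm[y][x] so the vectorized global store stays coalesced along M. The indices are bounded by the tile (threadIdx.x * 4 + j < 128), which is why the static_cast<size_t> widening was unnecessary. The same pattern in a minimal standalone form (hypothetical 32x32 kernel, not the PR's code):

// Transpose a 32x32 tile through shared memory; `in` is row-major N x N.
__global__ void TransposeTile(const float *in, float *out, int N) {
  __shared__ float tile[32][33];  // extra column avoids shared-memory bank conflicts
  int x = blockIdx.x * 32 + threadIdx.x;
  int y = blockIdx.y * 32 + threadIdx.y;
  if (x < N && y < N) tile[threadIdx.y][threadIdx.x] = in[(size_t)y * N + x];
  __syncthreads();
  // Swap block coordinates and read the tile transposed: writes remain coalesced.
  x = blockIdx.y * 32 + threadIdx.x;
  y = blockIdx.x * 32 + threadIdx.y;
  if (x < N && y < N) out[(size_t)y * N + x] = tile[threadIdx.x][threadIdx.y];
}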
@@ -185,7 +185,7 @@ __global__ void SwigluProbsGradKernelVec4(
     BFloat16* do1,      // [seq_len*topk, moe_intermediate_size*2]
     float* probs_grad,  // [seq_len*topk, 1]
     BFloat16* o2_s,     // [seq_len*topk, moe_intermediate_size]
-    int64_t moe_intermediate_size) {
+    int moe_intermediate_size) {
   constexpr int numel_per_thread = 4;
   constexpr int k_warp_size = 32;
   const int64_t row_idx = blockIdx.x;
@@ -210,7 +210,7 @@ __global__ void SwigluProbsGradKernelVec4(
 
   float local_probs_grad = 0.0f;
 
-  const int64_t vec_numel = (int64_t)moe_intermediate_size / numel_per_thread;
+  const int vec_numel = (int64_t)moe_intermediate_size / numel_per_thread;
   for (int64_t i = tid; i < vec_numel; i += blockDim.x) {
     float4 lhs_vec4 = load_and_cast_float4(o1_row_left_half_vec4 + i);
     float4 rhs_vec4 = load_and_cast_float4(o1_row_right_half_vec4 + i);
@@ -269,6 +269,12 @@ std::vector<paddle::Tensor> SwigluProbsGradCUDABackward(
 
   const int64_t moe_intermediate_size_2 = o1_dims[o1_dims.size() - 1];
   const int64_t moe_intermediate_size = moe_intermediate_size_2 / 2;
+  PADDLE_ENFORCE_LE(moe_intermediate_size,
+                    std::numeric_limits<int32_t>::max(),
+                    common::errors::InvalidArgument(
+                        "moe_intermediate_size should be less than INT_MAX, "
+                        "received moe_intermediate_size: (%ld)",
+                        moe_intermediate_size));
 
   auto do1 = inplace ? o1 : paddle::empty_like(o1);
   auto probs_grad =
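
For context, SwigluProbsGradKernelVec4 is the backward of a SwiGLU gate applied to rows of o1 = [lhs | rhs], where the forward is o2 = silu(lhs) * rhs with silu(x) = x * sigmoid(x). A scalar sketch of the per-element math the Vec4 loop evaluates (reference only; the probs scaling and reductions of the real kernel are omitted):

#include <cmath>

float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Forward: one element of o2 from the left/right halves of an o1 row.
float Swiglu(float lhs, float rhs) { return lhs * Sigmoid(lhs) * rhs; }

// Backward: gradients w.r.t. lhs and rhs given the upstream gradient g.
void SwigluGrad(float lhs, float rhs, float g, float *dlhs, float *drhs) {
  const float s = Sigmoid(lhs);
  *dlhs = g * rhs * s * (1.0f + lhs * (1.0f - s));  // d silu(x)/dx = s*(1+x*(1-s))
  *drhs = g * lhs * s;                              // silu(lhs)
}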
@@ -60,18 +60,15 @@ __device__ void BlockColumnMax(const __nv_bfloat16 input[4][4],
 
   // Reduce [(32), 32, 4] => [32, 4]
   for (int i = 0; i < 4; i++) {
-    shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] =
-        warp_max[i];
+    shm[threadIdx.y * 128 + i * 32 + threadIdx.x] = warp_max[i];
   }
   __syncthreads();
   for (int offset = 16; offset > 0; offset /= 2) {
     if (threadIdx.y < offset) {
       for (int i = 0; i < 4; i++) {
-        shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 + threadIdx.x] =
-            __hmax(shm[static_cast<size_t>(threadIdx.y) * 128 + i * 32 +
-                       threadIdx.x],
-                   shm[(static_cast<size_t>(threadIdx.y) + offset) * 128 +
-                       i * 32 + threadIdx.x]);
+        shm[threadIdx.y * 128 + i * 32 + threadIdx.x] =
+            __hmax(shm[threadIdx.y * 128 + i * 32 + threadIdx.x],
+                   shm[(threadIdx.y + offset) * 128 + i * 32 + threadIdx.x]);
       }
     }
     __syncthreads();
@@ -130,8 +127,7 @@ __device__ void BlockStoreOut(OutT* out,
     using StoreT = VecType<OutT, VecSize>;
     StoreT data;
     for (int j = 0; j < VecSize; j++) {
-      data[j] =
-          shm[i * 32 + threadIdx.y][static_cast<size_t>(threadIdx.x) * 4 + j];
+      data[j] = shm[i * 32 + threadIdx.y][threadIdx.x * 4 + j];
     }
     *reinterpret_cast<StoreT*>(out + idx) = data;
   }
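
BlockColumnMax is a standard two-stage reduction: a warp-level max into warp_max, then the shared-memory tree above, which halves the active rows with __hmax on each step and synchronizes between steps. The shape of that tree in a minimal float version (a sketch; the real code reduces four bfloat16 values per thread):

// Block-wide max across blockDim.y rows, one independent column per threadIdx.x.
__device__ float BlockColumnMaxSketch(float v,
                                      float *shm /* blockDim.y * blockDim.x */) {
  shm[threadIdx.y * blockDim.x + threadIdx.x] = v;
  __syncthreads();
  for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
    if (threadIdx.y < offset) {
      shm[threadIdx.y * blockDim.x + threadIdx.x] =
          fmaxf(shm[threadIdx.y * blockDim.x + threadIdx.x],
                shm[(threadIdx.y + offset) * blockDim.x + threadIdx.x]);
    }
    __syncthreads();
  }
  return shm[threadIdx.x];  // row 0 now holds the column max
}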
@@ -460,8 +460,8 @@ __global__ void tokens_unzip_gather_kernel(
     int64_t *__restrict__ index_unzipped,
     int64_t unzipped_rows,
     int64_t zipped_rows,
-    int64_t token_length,
-    int64_t scale_length,
+    int token_length,
+    int scale_length,
     int num_experts,
     int expert_id,
     int64_t offset) {
@@ -511,6 +511,12 @@ std::vector<paddle::Tensor> tokens_unzip_gather(
   PD_CHECK(x_shape.size() == 2);
   int64_t zipped_rows = x_shape[0];
   int64_t hidden_size = x_shape[1];
+  PADDLE_ENFORCE_LE(
+      hidden_size,
+      std::numeric_limits<int32_t>::max(),
+      common::errors::InvalidArgument("hidden_size should be less than "
+                                      "INT_MAX, received hidden_size: (%ld)",
+                                      hidden_size));
 
   std::vector<int64_t> x_scale_shape;
   int64_t quanted_hidden_size = 0;
@@ -521,6 +527,12 @@ std::vector<paddle::Tensor> tokens_unzip_gather(
     PD_CHECK(x_scale_shape[0] == x_shape[0]);
     quanted_hidden_size = x_scale_shape[1];
   }
+  PADDLE_ENFORCE_LE(quanted_hidden_size,
+                    std::numeric_limits<int32_t>::max(),
+                    common::errors::InvalidArgument(
+                        "quanted_hidden_size should be less than "
+                        "INT_MAX, received quanted_hidden_size: (%ld)",
+                        quanted_hidden_size));
 
   auto x_unzipped =
       paddle::zeros({padded_num_tokens, hidden_size}, dtype, place);
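
At its core, tokens_unzip_gather_kernel is a row gather: each destination row copies the source row named by an index array, and token_length/scale_length can be int because the host now rejects hidden sizes above INT_MAX. The gather pattern in a minimal standalone form (hypothetical kernel, not the Paddle implementation):

// dst[r, :] = src[row_index[r], :], one block per destination row.
__global__ void GatherRows(const float *__restrict__ src,
                           const int64_t *__restrict__ row_index,
                           float *__restrict__ dst,
                           const int64_t num_rows, const int cols) {
  const int64_t dst_row = blockIdx.x;
  if (dst_row >= num_rows) return;
  const float *src_ptr = src + row_index[dst_row] * cols;  // 64-bit row offset
  float *dst_ptr = dst + dst_row * cols;
  for (int c = threadIdx.x; c < cols; c += blockDim.x) {
    dst_ptr[c] = src_ptr[c];
  }
}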