(large file diff not rendered)

@@ -37,6 +37,7 @@ class MoeGemmRunner {
                          int64_t gemm_k,
                          int num_experts,
                          std::string activation_type,
+                         const int32_t weightonly_group_size,
                          cudaStream_t stream);

   void moe_gemm(const T* A,
@@ -48,10 +49,11 @@ class MoeGemmRunner {
                 int64_t gemm_n,
                 int64_t gemm_k,
                 int num_experts,
+                int group_size,
                 cudaStream_t stream);

  private:
-  template <typename EpilogueTag>
+  template <typename EpilogueTag, bool FineGrained>
   void dispatch_to_arch(const T* A,
                         const WeightType* B,
                         const T* weight_scales,
@@ -62,11 +64,12 @@ class MoeGemmRunner {
                         int64_t gemm_n,
                         int64_t gemm_k,
                         int num_experts,
+                        int group_size,
                         CutlassGemmConfig gemm_config,
                         cudaStream_t stream,
                         int* occupancy = nullptr);

-  template <typename EpilogueTag>
+  template <typename EpilogueTag, bool FineGrained>
   void run_gemm(const T* A,
                 const WeightType* B,
                 const T* weight_scales,
@@ -77,6 +80,7 @@ class MoeGemmRunner {
                 int64_t gemm_n,
                 int64_t gemm_k,
                 int num_experts,
+                int group_size,
                 cudaStream_t stream);

  private:
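The header above adds both a runtime group size and a compile-time `bool FineGrained` template flag. A minimal sketch of how the two plausibly meet — the runner's implementation file is collapsed above, so `run_gemm_stub` and `dispatch_fine_grained` are illustrative names, not part of this PR:

```cpp
#include <cstdio>

// Sketch only: a runtime group_size selecting the compile-time FineGrained
// flag. group_size == -1 denotes channel-wise scales (one scale group per
// output channel); a positive value denotes group-wise scales.
template <bool FineGrained>
void run_gemm_stub(int group_size) {
  std::printf("FineGrained=%d group_size=%d\n", int(FineGrained), group_size);
}

void dispatch_fine_grained(int group_size) {
  if (group_size > 0) {
    run_gemm_stub<true>(group_size);   // group-wise quantized weights
  } else {
    run_gemm_stub<false>(group_size);  // channel-wise or unquantized weights
  }
}

int main() {
  dispatch_fine_grained(-1);   // channel-wise: coarse path
  dispatch_fine_grained(128);  // group-wise: fine-grained path
}
```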

(large file diff not rendered)

diff --git a/csrc/gpu/moe/fused_moe/fused_moe.cu b/csrc/gpu/moe/fused_moe/fused_moe.cu
@@ -57,6 +57,7 @@ void FusedMoeKernel(const paddle::Tensor& input,
                     const paddle::optional<paddle::Tensor>& ffn2_scale,
                     const paddle::optional<paddle::Tensor>& ffn2_bias,
                     const std::string& quant_method,
+                    const int weightonly_group_size,
                     const int moe_topk,
                     const bool group_moe,
                     const bool norm_topk_prob,
@@ -86,6 +87,7 @@ void FusedMoeKernel(const paddle::Tensor& input,
       ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
       ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
       nullptr,
+      weightonly_group_size,
       moe_topk,
       group_moe,
       norm_topk_prob,
@@ -105,6 +107,7 @@ std::vector<paddle::Tensor> FusedExpertMoe(
     const paddle::optional<paddle::Tensor>& ffn2_bias,
     const paddle::optional<paddle::Tensor>& ffn2_scale,
     const std::string& quant_method,
+    const int weightonly_group_size,
     const int moe_topk,
     const bool norm_topk_prob,
     const bool group_moe) {
@@ -122,6 +125,7 @@ std::vector<paddle::Tensor> FusedExpertMoe(
         ffn2_scale,
         ffn2_bias,
         quant_method,
+        weightonly_group_size,
         moe_topk,
         group_moe,
         norm_topk_prob,
@@ -137,6 +141,7 @@ std::vector<paddle::Tensor> FusedExpertMoe(
         ffn2_scale,
         ffn2_bias,
         quant_method,
+        weightonly_group_size,
         moe_topk,
         group_moe,
         norm_topk_prob,
@@ -184,6 +189,7 @@ PD_BUILD_OP(fused_expert_moe)
             paddle::Optional("ffn2_scale")})
     .Outputs({"output"})
     .Attrs({"quant_method:std::string",
+            "weightonly_group_size:int",
             "moe_topk:int",
             "norm_topk_prob:bool",
             "group_moe:bool"})
diff --git a/csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h b/csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h
@@ -126,6 +126,7 @@ class MoeHelper {
             const paddle::Tensor *ffn2_scale,
             const paddle::Tensor *ffn2_bias,
             const paddle::Tensor *moe_token_type_ids,
+            const int weightonly_group_size,
             const int moe_topk,
             const bool group_moe,
             const bool norm_topk_prob,
@@ -304,6 +305,7 @@ class MoeHelper {
           hidden_size,
           num_experts,
           "none",
+          weightonly_group_size,
           stream);
     } else if (gemm_method_ == "weight_only_int4") {
       int4_moe_gemm_runner_->moe_gemm_bias_act(
@@ -319,6 +321,7 @@ class MoeHelper {
           hidden_size,
           num_experts,
           "none",
+          weightonly_group_size,
           stream);
     } else {
       fp16_moe_gemm_runner_->moe_gemm_bias_act(
@@ -333,6 +336,7 @@ class MoeHelper {
           hidden_size,
           num_experts,
           "none",
+          weightonly_group_size,
           stream);
     }

@@ -356,6 +360,7 @@ class MoeHelper {
           hidden_size,
           inter_size / 2,
           num_experts,
+          weightonly_group_size,
           stream);
     } else if (gemm_method_ == "weight_only_int4") {
       int4_moe_gemm_runner_->moe_gemm(
@@ -369,6 +374,7 @@ class MoeHelper {
           hidden_size,
           inter_size / 2,
           num_experts,
+          weightonly_group_size,
           stream);
     } else {
       fp16_moe_gemm_runner_->moe_gemm(
@@ -381,6 +387,7 @@ class MoeHelper {
           hidden_size,
           inter_size / 2,
           num_experts,
+          weightonly_group_size,
           stream);
     }
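All three runner branches (`weight_only_int8`, `weight_only_int4`, fp16) now receive the same `weightonly_group_size`, so the legal values could be checked once before dispatch. A hypothetical guard, not part of this PR, assuming the legal values are exactly those named in the group-size comment in moe_ffn.cu below (-1, 64, 128):

```cpp
#include <stdexcept>
#include <string>

// Hypothetical validator; the PR itself does not add one.
// -1 selects channel-wise (or no) quantization; 64 and 128 are the
// group-wise sizes named in the moe_ffn.cu comment.
inline void check_weightonly_group_size(int group_size) {
  if (group_size != -1 && group_size != 64 && group_size != 128) {
    throw std::invalid_argument(
        "weightonly_group_size must be -1, 64, or 128, got " +
        std::to_string(group_size));
  }
}
```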
diff --git a/csrc/gpu/moe/fused_moe/moe_ffn.cu b/csrc/gpu/moe/fused_moe/moe_ffn.cu
@@ -26,6 +26,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
                   const paddle::optional<paddle::Tensor>& ffn1_scale,
                   const paddle::optional<paddle::Tensor>& ffn2_scale,
                   const std::string& quant_method,
+                  const int32_t weightonly_group_size,
                   paddle::Tensor ffn_out) {
   typedef PDTraits<T> traits_;
   typedef typename traits_::DataType DataType_;
@@ -62,8 +63,16 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
   ffn1_bias
       ? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
       : nullptr;

+  /*
+    Group-size conventions:
+      no quantization    : group_size = -1
+      channel-wise quant : group_size = -1
+      group-wise quant   : group_size = 64 or 128
+    (no-quant and channel-wise quant share the same sentinel value)
+  */
+  int group_size = weightonly_group_size;
   if (quant_method == "weight_only_int8") {
+    group_size = group_size == -1 ? hidden_size : group_size;
     int8_moe_gemm_runner.moe_gemm_bias_act(
         reinterpret_cast<const NvType*>(permuted_data),
         reinterpret_cast<const uint8_t*>(ffn1_weight.data<int8_t>()),
@@ -77,8 +86,10 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         hidden_size,
         num_experts,
         "none",
+        group_size,
         stream);
   } else if (quant_method == "weight_only_int4") {
+    group_size = group_size == -1 ? hidden_size : group_size;
     int4_moe_gemm_runner.moe_gemm_bias_act(
         reinterpret_cast<const NvType*>(permuted_data),
         reinterpret_cast<const cutlass::uint4b_t*>(ffn1_weight.data<int8_t>()),
@@ -92,6 +103,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         hidden_size,
         num_experts,
         "none",
+        group_size,
         stream);
   } else {
     fp16_moe_gemm_runner.moe_gemm_bias_act(
@@ -106,13 +118,16 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         hidden_size,
         num_experts,
         "none",
+        -1,
         stream);
   }

   auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
   auto act_out = act_out_tensor.data<data_t>();

+  // reset group_size
+  group_size = weightonly_group_size;
   if (quant_method == "weight_only_int8") {
+    group_size = group_size == -1 ? inter_size / 2 : group_size;
     int8_moe_gemm_runner.moe_gemm(
         reinterpret_cast<const NvType*>(act_out),
         reinterpret_cast<const uint8_t*>(ffn2_weight.data<int8_t>()),
@@ -124,9 +139,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         hidden_size,
         inter_size / 2,
         num_experts,
+        group_size,
         stream);

   } else if (quant_method == "weight_only_int4") {
+    group_size = group_size == -1 ? inter_size / 2 : group_size;
     int4_moe_gemm_runner.moe_gemm(
         reinterpret_cast<const NvType*>(act_out),
         reinterpret_cast<const cutlass::uint4b_t*>(ffn2_weight.data<int8_t>()),
@@ -138,6 +155,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         hidden_size,
         inter_size / 2,
         num_experts,
+        group_size,
         stream);
   } else {
     fp16_moe_gemm_runner.moe_gemm(
@@ -150,6 +168,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input,
         hidden_size,
         inter_size / 2,
         num_experts,
+        -1,
         stream);
   }
 }
@@ -162,7 +181,8 @@ std::vector<paddle::Tensor> MoeExpertFFN(
     const paddle::optional<paddle::Tensor>& ffn1_bias,
     const paddle::optional<paddle::Tensor>& ffn1_scale,
     const paddle::optional<paddle::Tensor>& ffn2_scale,
-    const std::string& quant_method) {
+    const std::string& quant_method,
+    const int32_t weightonly_group_size) {
   const auto input_type = permute_input.dtype();
   auto ffn_out = paddle::empty_like(permute_input);

@@ -176,6 +196,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
           ffn1_scale,
           ffn2_scale,
           quant_method,
+          weightonly_group_size,
           ffn_out);
       break;
     case paddle::DataType::FLOAT16:
@@ -187,6 +208,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
           ffn1_scale,
           ffn2_scale,
           quant_method,
+          weightonly_group_size,
           ffn_out);
       break;
     default:
@@ -226,7 +248,7 @@ PD_BUILD_OP(moe_expert_ffn)
             paddle::Optional("ffn1_scale"),
             paddle::Optional("ffn2_scale")})
     .Outputs({"output_tensor"})
.Attrs({"quant_method:std::string"})
.Attrs({"quant_method:std::string", "weightonly_group_size: int"})
     .SetKernelFn(PD_KERNEL(MoeExpertFFN))
     .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
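The resolution rule MoeFFNKernel applies before each GEMM can be read directly off the hunks above: the -1 sentinel is widened to the GEMM's reduction dimension, so channel-wise quantization ends up with a single scale group spanning all of k. A hypothetical helper expressing the rule the PR writes inline (name and factoring are illustrative):

```cpp
#include <cstdint>

// Hypothetical factoring of the inline logic above. Channel-wise quant
// (group_size == -1) is widened to the full reduction dimension k:
// hidden_size for the ffn1 GEMM, inter_size / 2 for the ffn2 GEMM.
inline int resolve_group_size(int weightonly_group_size, int64_t gemm_k) {
  return weightonly_group_size == -1 ? static_cast<int>(gemm_k)
                                     : weightonly_group_size;
}
```

The fp16 branches bypass this and pass -1 unchanged, since unquantized weights carry no scale groups.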