[FIX] Fix the bug that some gemm configs were not handled #10627

Merged (2 commits, May 22, 2025)
Changes from 1 commit
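
The diff below swaps every `PADDLE_FATAL(...)` in the MoE GEMM dispatch path for `PADDLE_THROW(...)` carrying a typed error object (`phi::errors::Fatal` or `common::errors::InvalidArgument`), so the failure category is explicit and the error flows through Paddle's standard exception path. A minimal standalone sketch of that throw-and-catch flow, using hypothetical stand-in macros rather than Paddle's real ones (which also capture file and line information):

```cpp
#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for PADDLE_THROW; the real macro also records the
// throw site and wraps the error in Paddle's own exception type.
#define DEMO_THROW(err) throw(err)

// Stand-in for a dispatch routine that rejects an unhandled gemm config.
void demo_dispatch(int tile_config) {
  if (tile_config < 0) {
    DEMO_THROW(std::invalid_argument(
        "[dispatch_moe_gemm_to_cutlass] gemm config undefined."));
  }
  std::printf("dispatching tile config %d\n", tile_config);
}

int main() {
  demo_dispatch(3);  // ok
  try {
    demo_dispatch(-1);  // invalid: surfaces as a catchable exception
  } catch (const std::invalid_argument& e) {
    std::fprintf(stderr, "recovered: %s\n", e.what());
  }
  return 0;
}
```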
```diff
@@ -69,7 +69,8 @@ void generic_moe_gemm_kernelLauncher(const T* A,
                                      cudaStream_t stream,
                                      int* kernel_occupancy = nullptr) {
   if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) {
-    PADDLE_FATAL("[MoeGemm] Grouped gemm does not support split-k");
+    PADDLE_THROW(
+        phi::errors::Fatal("[MoeGemm] Grouped gemm does not support split-k"));
   }

 #ifdef PADDLE_CUDA_BF16
```
```diff
@@ -169,9 +170,9 @@ void generic_moe_gemm_kernelLauncher(const T* A,
   int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());

   if (occupancy == 0) {
-    PADDLE_FATAL(
+    PADDLE_THROW(phi::errors::Fatal(
         "[MoE Runner] GPU lacks the shared memory resources to run "
-        "GroupedGEMM kernel");
+        "GroupedGEMM kernel"));
   }
   const int threadblock_count = multi_processor_count * occupancy;
```
```diff
@@ -197,7 +198,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
   if (can_implement != cutlass::Status::kSuccess) {
     std::string err_msg = "MoEFC kernel will fail for params. Error: " +
                           std::string(cutlassGetStatusString(can_implement));
-    PADDLE_FATAL("[MoE Runner] " + err_msg);
+    PADDLE_THROW(phi::errors::Fatal("[MoE Runner] " + err_msg));
   }

   auto init_status = gemm.initialize(args);
```
```diff
@@ -243,7 +244,7 @@ struct dispatch_stages {
     std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " +
                           std::to_string(arch::kMinComputeCapability) +
                           " with stages set to " + std::to_string(Stages);
-    PADDLE_FATAL("[dispatch_stages::dispatch] " + err_msg);
+    PADDLE_THROW(phi::errors::Fatal("[dispatch_stages::dispatch] " + err_msg));
   }
 };
```
```diff
@@ -394,7 +395,8 @@ void dispatch_gemm_config(const T* A,
     default:
       std::string err_msg = "dispatch_gemm_config does not support stages " +
                             std::to_string(gemm_config.stages);
-      PADDLE_FATAL("[MoE][dispatch_gemm_config] " + err_msg);
+      PADDLE_THROW(
+          phi::errors::Fatal("[MoE][dispatch_gemm_config] " + err_msg));
       break;
   }
 }
```
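
For context on the hunk above: `dispatch_gemm_config` selects a `dispatch_stages` specialization from the runtime `gemm_config.stages` value, and the `default` branch shown in the diff is the rejection path. A self-contained sketch of that switch-to-template dispatch pattern (names, stage counts, and the demo logic are assumptions for illustration, not Paddle's actual code):

```cpp
#include <cstdio>
#include <stdexcept>
#include <string>

// Demo analogue of dispatch_stages: one specialization per pipeline depth.
template <int Stages>
struct dispatch_stages_demo {
  static void dispatch() {
    std::printf("launching kernel specialized for %d pipeline stages\n",
                Stages);
  }
};

// Demo analogue of dispatch_gemm_config: maps a runtime stage count to a
// compile-time template specialization, throwing on unsupported values.
void dispatch_gemm_config_demo(int stages) {
  switch (stages) {
    case 2:
      dispatch_stages_demo<2>::dispatch();
      break;
    case 3:
      dispatch_stages_demo<3>::dispatch();
      break;
    default:
      // Mirrors the diff above: unsupported stage counts now raise a typed
      // error instead of hitting a fatal macro.
      throw std::invalid_argument(
          "[MoE][dispatch_gemm_config] does not support stages " +
          std::to_string(stages));
  }
}

int main() {
  dispatch_gemm_config_demo(2);  // ok
  try {
    dispatch_gemm_config_demo(5);  // unsupported: throws
  } catch (const std::invalid_argument& e) {
    std::printf("caught: %s\n", e.what());
  }
  return 0;
}
```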
```diff
@@ -452,17 +454,18 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
     dispatch_gemm_config_macro(64, 128, 64, 32, 64, 64);
     dispatch_gemm_config_macro(128, 128, 64, 64, 32, 64);
     case CutlassTileConfig::Undefined:
-      PADDLE_FATAL("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "[dispatch_moe_gemm_to_cutlass] gemm config undefined."));
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-          "already been set by heuristic.");
+          "already been set by heuristic."));
       break;
     default:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass] Config is invalid for same "
-          "type MoE tensorop GEMM.");
+          "type MoE tensorop GEMM."));
       break;
   }
 }
```
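
Each `dispatch_gemm_config_macro(...)` line above expands to a `case` label plus a tile-shaped template call, which is why explicit `case` keywords appear only for the error branches. A self-contained sketch of that macro pattern (the enum names and shapes are assumptions modeled on the visible configs); the PR's fix amounts to adding missing one-line entries of this kind to the switches:

```cpp
#include <cstdio>

// Hypothetical tile-config enum; real names follow the same Cta/Warp scheme.
enum class TileConfigDemo {
  Undefined,
  CtaShape64x128x64_WarpShape32x64x64,
  CtaShape128x128x64_WarpShape64x32x64,
};

// Demo kernel entry, parameterized by threadblock and warp tile shapes.
template <int CtaM, int CtaN, int CtaK, int WarpM, int WarpN, int WarpK>
void run_gemm_demo() {
  std::printf("cta %dx%dx%d, warp %dx%dx%d\n", CtaM, CtaN, CtaK, WarpM, WarpN,
              WarpK);
}

// One macro invocation generates both the case label and the shaped call,
// so handling an extra tile config is a single line in the switch.
#define DEMO_DISPATCH(CtaM, CtaN, CtaK, WarpM, WarpN, WarpK)                   \
  case TileConfigDemo::CtaShape##CtaM##x##CtaN##x##CtaK##_WarpShape##WarpM##x##WarpN##x##WarpK: \
    run_gemm_demo<CtaM, CtaN, CtaK, WarpM, WarpN, WarpK>();                    \
    break;

void dispatch_demo(TileConfigDemo cfg) {
  switch (cfg) {
    DEMO_DISPATCH(64, 128, 64, 32, 64, 64)
    DEMO_DISPATCH(128, 128, 64, 64, 32, 64)
    default:
      // A config missing from this switch falls through to here, which is
      // the bug class this PR fixes.
      std::printf("gemm config undefined\n");
      break;
  }
}

int main() {
  dispatch_demo(TileConfigDemo::CtaShape64x128x64_WarpShape32x64x64);
  dispatch_demo(TileConfigDemo::Undefined);
  return 0;
}
```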
```diff
@@ -497,40 +500,44 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
     dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
     dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
     case CutlassTileConfig::Undefined:
-      PADDLE_FATAL("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "[dispatch_moe_gemm_to_cutlass] gemm config undefined."));
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-          "already been set by heuristic.");
+          "already been set by heuristic."));
       break;
     default:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
-          "mixed type tensorop GEMM.");
+          "mixed type tensorop GEMM."));
       break;
   }
 } else {
   switch (gemm_config.tile_config) {
     dispatch_gemm_config_macro(16, 128, 64, 16, 32, 64);
     dispatch_gemm_config_macro(16, 256, 64, 16, 64, 64);
     dispatch_gemm_config_macro(64, 64, 64, 32, 32, 64);
     dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64);
     dispatch_gemm_config_macro(128, 64, 64, 64, 32, 64);
     dispatch_gemm_config_macro(64, 128, 64, 64, 64, 64);
     dispatch_gemm_config_macro(128, 128, 64, 64, 64, 64);
     dispatch_gemm_config_macro(128, 128, 64, 128, 32, 64);
     dispatch_gemm_config_macro(128, 256, 64, 64, 64, 64);
     dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64);
     case CutlassTileConfig::Undefined:
-      PADDLE_FATAL("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "[dispatch_moe_gemm_to_cutlass] gemm config undefined."));
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
-          "already been set by heuristic.");
+          "already been set by heuristic."));
       break;
     default:
-      PADDLE_FATAL(
-          "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
-          "mixed type tensorop GEMM.");
+      PADDLE_THROW(common::errors::InvalidArgument(
+          "[dispatch_moe_gemm_to_cutlass] gemm config undefined."));
       break;
   }
 }
```
```diff
@@ -561,19 +568,19 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
   switch (gemm_config.tile_config) {
     dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8);
     case CutlassTileConfig::Undefined:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config "
-          "undefined.");
+          "undefined."));
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config should "
-          "have already been set by heuristic.");
+          "have already been set by heuristic."));
       break;
     default:
-      PADDLE_FATAL(
+      PADDLE_THROW(common::errors::InvalidArgument(
           "[dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config "
-          "for float MoE gemm.");
+          "for float MoE gemm."));
       break;
   }
 }
```