Skip to content

Commit 09ab2f3

Browse files
committed
fix
1 parent 1aae565 commit 09ab2f3

File tree

1 file changed

+61
-67
lines changed

1 file changed

+61
-67
lines changed

csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h

Lines changed: 61 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -664,8 +664,8 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
664664
cudaStream_t stream) {
665665
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
666666
static constexpr bool only_simt_configs = std::is_same<T, float>::value;
667-
std::vector<CutlassGemmConfig> candidate_configs =
668-
get_candidate_configs(sm_, group_size, is_weight_only, only_simt_configs, true);
667+
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(
668+
sm_, group_size, is_weight_only, only_simt_configs, true);
669669
static constexpr int warm_time = 5;
670670
static constexpr int test_time = 10;
671671
auto& gemmConfigManager = GemmConfigManager::Instance();
@@ -684,69 +684,66 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
684684
int profile_total_rows =
685685
std::min(gemmConfigManager.nextPowerOfTwo(total_rows),
686686
gemmConfigManager.getMaxProfileM());
687-
chosen_config = candidate_configs[0];
688-
// bool find_one = false;
689-
// for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
690-
// try {
691-
// for (int i = 0; i < warm_time; i++) {
692-
// dispatch_to_arch<EpilogueTag>(A,
693-
// B,
694-
// weight_scales,
695-
// biases,
696-
// C,
697-
// total_rows_before_expert,
698-
// total_rows,
699-
// gemm_n,
700-
// gemm_k,
701-
// num_experts,
702-
// candidate_configs[ii],
703-
// stream);
704-
// }
705-
// cudaEvent_t start;
706-
// cudaEvent_t stop;
707-
// check_cuda_error(cudaEventCreate(&start));
708-
// check_cuda_error(cudaEventCreate(&stop));
709-
// check_cuda_error(cudaStreamSynchronize(stream));
710-
// check_cuda_error(cudaEventRecord(start, stream));
711-
// for (int i = 0; i < test_time; i++) {
712-
// dispatch_to_arch<EpilogueTag>(A,
713-
// B,
714-
// weight_scales,
715-
// biases,
716-
// C,
717-
// total_rows_before_expert,
718-
// total_rows,
719-
// gemm_n,
720-
// gemm_k,
721-
// num_experts,
722-
// candidate_configs[ii],
723-
// stream);
724-
// }
725-
// check_cuda_error(cudaEventRecord(stop, stream));
726-
// check_cuda_error(cudaEventSynchronize(stop));
727-
// float elapsed;
728-
// check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
729-
// check_cuda_error(cudaEventDestroy(start));
730-
// check_cuda_error(cudaEventDestroy(stop));
731-
// if (elapsed < best_time) {
732-
// best_time = elapsed;
733-
// best_config = candidate_configs[ii];
734-
// }
735-
// find_one = true;
736-
// } catch (const std::exception& e) {
737-
// std::cerr << "MOE config[" << ii << "] Caught exception: " <<
738-
// e.what()
739-
// << std::endl;
740-
// }
741-
// }
742-
// if (find_one) {
743-
// gemmConfigManager.addBestConfig(gemmId, profile_total_rows,
744-
// best_config); chosen_config = best_config;
745-
// } else {
746-
// PADDLE_FATAL("[MoE Configure Search] find no one avaliable config.");
747-
// }
687+
bool find_one = false;
688+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
689+
try {
690+
for (int i = 0; i < warm_time; i++) {
691+
dispatch_to_arch<EpilogueTag>(A,
692+
B,
693+
weight_scales,
694+
biases,
695+
C,
696+
total_rows_before_expert,
697+
total_rows,
698+
gemm_n,
699+
gemm_k,
700+
num_experts,
701+
candidate_configs[ii],
702+
stream);
703+
}
704+
cudaEvent_t start;
705+
cudaEvent_t stop;
706+
check_cuda_error(cudaEventCreate(&start));
707+
check_cuda_error(cudaEventCreate(&stop));
708+
check_cuda_error(cudaStreamSynchronize(stream));
709+
check_cuda_error(cudaEventRecord(start, stream));
710+
for (int i = 0; i < test_time; i++) {
711+
dispatch_to_arch<EpilogueTag>(A,
712+
B,
713+
weight_scales,
714+
biases,
715+
C,
716+
total_rows_before_expert,
717+
total_rows,
718+
gemm_n,
719+
gemm_k,
720+
num_experts,
721+
candidate_configs[ii],
722+
stream);
723+
}
724+
check_cuda_error(cudaEventRecord(stop, stream));
725+
check_cuda_error(cudaEventSynchronize(stop));
726+
float elapsed;
727+
check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop));
728+
check_cuda_error(cudaEventDestroy(start));
729+
check_cuda_error(cudaEventDestroy(stop));
730+
if (elapsed < best_time) {
731+
best_time = elapsed;
732+
best_config = candidate_configs[ii];
733+
}
734+
find_one = true;
735+
} catch (const std::exception& e) {
736+
std::cerr << "MOE config[" << ii << "] Caught exception: " << e.what()
737+
<< std::endl;
738+
}
739+
}
740+
if (find_one) {
741+
gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config);
742+
chosen_config = best_config;
743+
} else {
744+
PADDLE_FATAL("[MoE Configure Search] find no one avaliable config.");
745+
}
748746
}
749-
try {
750747
dispatch_to_arch<EpilogueTag>(A,
751748
B,
752749
weight_scales,
@@ -760,9 +757,6 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
760757
group_size,
761758
chosen_config,
762759
stream);
763-
} catch (const std::exception& e) {
764-
std::cerr << "MOE best config Caught exception: " << e.what() << std::endl;
765-
}
766760
}
767761

768762
template <typename T, typename WeightType>

0 commit comments

Comments
 (0)