From 7053804c69b5c76230ca0827b7658603b1ff004d Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Mon, 9 Feb 2026 14:29:52 -0600 Subject: [PATCH 1/3] Extend SplitK support to ABQuantGrouped mode - Add AQ offset calculation in SplitKBatchOffset for ABQuant mode - Update mode restriction to support ABQuant alongside BQuant - Add AQ alignment validation for split-K batches - Apply AQ pointer offsetting based on tensor layout - Update block window creation to handle AQ group offsets - Requires non-preshuffle mode for both A and B quantization --- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 150 ++++++++++++------ 1 file changed, 105 insertions(+), 45 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 63682040ad3e..adf7e94f0c3f 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -454,12 +454,50 @@ struct QuantGemmKernel bq_group_offset = 0; bq_k_split_offset = 0; } + + // Compute AQ and BQ offsets for ABQuantGrouped mode + if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant) + { + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); + aq_group_offset = amd_wave_read_first_lane(k_offset / AQuantGroupSize::kK); + bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK); + + if constexpr(std::is_same_v) + { + aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset); + } + else if constexpr(std::is_same_v) + { + aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset * kargs.stride_AQ); + } + + if constexpr(std::is_same_v) + { + const index_t stride_bq = + amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN)); + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq); + } + else if constexpr(std::is_same_v) + { + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); + } + } + else + { + aq_group_offset = 0; + aq_k_split_offset = 0; + } } index_t a_k_split_offset; index_t b_k_split_offset; - index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension) - index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride) + index_t aq_group_offset; // Logical offset in AQ K-groups (K/kK dimension) for ABQuant + index_t aq_k_split_offset; // Memory pointer offset for AQ (accounting for layout/stride) + index_t bq_group_offset; // Logical offset in BQ K-groups (K/kK dimension) + index_t bq_k_split_offset; // Memory pointer offset for BQ (accounting for layout/stride) index_t splitted_k; }; @@ -531,6 +569,7 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr, const QuantGemmKernelArgs& kargs, + const index_t aq_group_offset, const index_t i_m, const index_t i_n) { @@ -591,35 +630,19 @@ struct QuantGemmKernel return make_tensor_view(aq_ptr, aq_merge_pad1_desc); } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !APreshuffleQuant) + else if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + !APreshuffleQuant) { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); - } - else // Column major AQ - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.QK_A, kargs.M), - make_tuple(kargs.stride_AQ, 1), + // For split-K with ABQuant, adjust AQ tensor dimension based on group offset + // This reflects "remaining K-groups from split-K offset position" + const index_t aq_k_groups = kargs.QK_A - aq_group_offset; - number{}, - number<1>{}); - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant) - { if constexpr(std::is_same_v) { return make_naive_tensor_view( aq_ptr, - make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.M, aq_k_groups), make_tuple(kargs.stride_AQ, 1), number{}, number<1>{}); @@ -628,9 +651,8 @@ struct QuantGemmKernel { return make_naive_tensor_view( aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(1, kargs.stride_AQ), - + make_tuple(aq_k_groups, kargs.M), + make_tuple(kargs.stride_AQ, 1), number{}, number<1>{}); } @@ -668,12 +690,19 @@ struct QuantGemmKernel make_tuple(number{}, number{}), {block_m_idx * tile_window_height, 0}); } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !APreshuffleQuant) + else if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + !APreshuffleQuant) { + using AQuantGroupSize = remove_cvref_t; constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK; constexpr auto block_m = TilePartitioner::MPerBlock; - + if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v, + "ABQuantGrouped requires RowMajor AQ layout"); + } if constexpr(std::is_same_v) { return make_tile_window(aq_tensor_view, @@ -687,16 +716,6 @@ struct QuantGemmKernel {0, i_m}); } } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant) - { - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_tensor_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } else if constexpr(kQuantType == QuantType::RowColQuant) { return make_tile_window(aq_tensor_view, @@ -1121,17 +1140,21 @@ struct QuantGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { - // Split-K is supported for BQuantGrouped mode without preshuffle + // Split-K is supported for BQuantGrouped and ABQuantGrouped modes without preshuffle if(kargs.k_batch != 1) { constexpr bool is_bquant_non_preshuffle = (kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant; - if constexpr(!is_bquant_non_preshuffle) + constexpr bool is_abquant_non_preshuffle = + (kQuantType == QuantType::ABQuantGrouped) && !APreshuffleQuant && !BPreshuffleQuant; + + if constexpr(!is_bquant_non_preshuffle && !is_abquant_non_preshuffle) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { CK_TILE_ERROR("Conditions not met for Kbatch >1 ! " - "Split-K only supported for BQuantGrouped without preshuffle."); + "Split-K only supported for BQuantGrouped or ABQuantGrouped " + "without preshuffle."); } return false; } @@ -1175,6 +1198,27 @@ struct QuantGemmKernel } return false; } + + // Constraint 3 (ABQuant only): KRead must also align with AQ group boundaries + // For ABQuantGrouped mode, we have two quantization tensors (AQ and BQ). + // Both must have their K-dimension aligned with the split-K batch size. + if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + using AQuantGroupSize = remove_cvref_t; + if(KRead % AQuantGroupSize::kK != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Split-K batch size must be aligned with AQ quantization group " + "size! KRead=" + + std::to_string(KRead) + + " is not divisible by AQuantGroupSize::kK=" + + std::to_string(AQuantGroupSize::kK)); + } + return false; + } + } } } @@ -1336,7 +1380,10 @@ struct QuantGemmKernel MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + // Note: Pass aq_group_offset so the tensor view dimension reflects + // the remaining K-groups from the split-K offset position for ABQuant. + const auto& aq_block_window = MakeAQBlockWindow( + aq_ptr, kargs, splitk_batch_offset.aq_group_offset, block_idx_m, block_idx_n); // Note: Pass bq_group_offset so the tensor view dimension reflects // the remaining K-groups from the split-K offset position. const auto& bq_block_window = MakeBQBlockWindow( @@ -1467,7 +1514,20 @@ struct QuantGemmKernel static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); + + // AQ pointer: Apply split-K offset for ABQuant mode + const AQDataType* aq_ptr; + if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant) + { + aq_ptr = static_cast(kargs.aq_ptr) + + splitk_batch_offset.aq_k_split_offset; + } + else + { + aq_ptr = static_cast(kargs.aq_ptr); + } + + // BQ pointer: Apply split-K offset for BQuant and ABQuant modes const BQDataType* bq_ptr = static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; CDataType* c_ptr = static_cast(kargs.c_ptr); From 15ad3d263d1dffb621c4d7490a92a38ee4b1832e Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Mon, 9 Feb 2026 15:15:01 -0600 Subject: [PATCH 2/3] Add unit tests for ABQuant SplitK optimization - Create test_gemm_quant_abquant_splitk_decode.cpp for decode workloads - Create test_gemm_quant_abquant_splitk_prefill.cpp for prefill workloads - Add k_batch parameter to ABQuant test fixture run_test_with_validation - Fix missing c_m_n_dev_buf.SetZero() in ABQuant fixture (critical for split-K atomic operations) - Register new test targets in CMakeLists.txt - Tests cover split_k values: 2, 3, 4, 5 - All 16 tests passing (8 FP8, 8 BF8) --- .../ck_tile/gemm_block_scale/CMakeLists.txt | 15 ++++- .../test_gemm_quant_abquant_splitk_decode.cpp | 64 +++++++++++++++++++ ...test_gemm_quant_abquant_splitk_prefill.cpp | 64 +++++++++++++++++++ .../test_gemm_quant_fixtures.hpp | 9 ++- 4 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp create mode 100644 projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp diff --git a/projects/composablekernel/test/ck_tile/gemm_block_scale/CMakeLists.txt b/projects/composablekernel/test/ck_tile/gemm_block_scale/CMakeLists.txt index 2b19053f4175..5088d6ed2c72 100644 --- a/projects/composablekernel/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/projects/composablekernel/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -129,16 +129,27 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) # BQuant split-K tests (no preshuffle) - add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode test_gemm_quant_bquant_splitk_decode.cpp ) target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill test_gemm_quant_bquant_splitk_prefill.cpp ) target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # ABQuant split-K tests (no preshuffle, 2D BQ groups only) + add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode + test_gemm_quant_abquant_splitk_decode.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_abquant_splitk_prefill + test_gemm_quant_abquant_splitk_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant tests (with PreshuffleB) - split into 5 files add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d test_gemm_quant_bquant_preshuffle_decode_1d.cpp diff --git a/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp new file mode 100644 index 000000000000..7798d52b4906 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_decode.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using ABQuantGrouped = + std::integral_constant; + +// Group sizes +using GroupSize1D = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant split-K tests - Decode shape +// Only use 2D BQ groups (1×128×128) to avoid AICK-644 bug +// Tuple format: +// clang-format off +using ABQuantSplitKDecodeTypes = ::testing::Types< + // FP8 with 2D BQ groups + std::tuple, + // BF8 with 2D BQ groups + std::tuple +>; +// clang-format on + +// Test suite for ABQuant split-K Decode +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKDecodeTypes); + +// ABQuant split-K tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4x128 + this->run_test_with_validation(32, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8x128 + this->run_test_with_validation(32, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4x128 + this->run_test_with_validation(32, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK5Test) +{ + // K=2560 for split_k=5: 2560/5=512=4x128 + this->run_test_with_validation(32, 128, 2560, 5); +} diff --git a/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp new file mode 100644 index 000000000000..48bf2c32d4c1 --- /dev/null +++ b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_splitk_prefill.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using ABQuantGrouped = + std::integral_constant; + +// Group sizes +using GroupSize1D = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant split-K tests - Prefill shape +// Only use 2D BQ groups (1×128×128) to avoid AICK-644 bug +// Tuple format: +// clang-format off +using ABQuantSplitKPrefillTypes = ::testing::Types< + // FP8 with 2D BQ groups + std::tuple, + // BF8 with 2D BQ groups + std::tuple +>; +// clang-format on + +// Test suite for ABQuant split-K Prefill +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKPrefillTypes); + +// ABQuant split-K tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4x128 + this->run_test_with_validation(128, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8x128 + this->run_test_with_validation(128, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4x128 + this->run_test_with_validation(128, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK5Test) +{ + // K=2560 for split_k=5: 2560/5=512=4x128 + this->run_test_with_validation(128, 128, 2560, 5); +} diff --git a/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index ca21bc69b7c5..aecfb8b8281b 100644 --- a/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -948,7 +948,10 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{})); @@ -1059,6 +1062,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase Date: Mon, 9 Feb 2026 16:57:44 -0600 Subject: [PATCH 3/3] Fix ABQuant grouped GEMM compatibility and split-K tolerance - Add missing aq_group_offset param to MakeAQBlockWindow call - Use actual k_batch for tolerance calculation in split-K tests --- .../ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp | 4 ++-- .../ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 8b77b01e2fe2..5450aea41f56 100644 --- a/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -451,8 +451,8 @@ struct QuantGroupedGemmKernel Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& aq_block_window = - Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& aq_block_window = Base::MakeAQBlockWindow( + aq_ptr, kargs, splitk_batch_offset.aq_group_offset, block_idx_m, block_idx_n); const auto& bq_block_window = Base::MakeBQBlockWindow( bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); diff --git a/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index aecfb8b8281b..f52eacadc3ba 100644 --- a/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/projects/composablekernel/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -1114,7 +1114,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBasetemplate calculate_rtol_atol( - K, 1, max_accumulated_value); + K, k_batch, max_accumulated_value); // Validate results bool pass = ck_tile::check_err(c_m_n_dev_result, @@ -1124,7 +1124,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase{})); EXPECT_TRUE(pass) << "ABQuantGrouped validation failed with M=" << M << ", N=" << N - << ", K=" << K; + << ", K=" << K << ", k_batch=" << k_batch; if(!pass) {