Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,50 @@ struct QuantGemmKernel
bq_group_offset = 0;
bq_k_split_offset = 0;
}

// Compute AQ and BQ offsets for ABQuantGrouped mode
if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
using BQuantGroupSize = remove_cvref_t<typename GemmPipeline::BQuantGroupSize>;

const index_t k_offset = amd_wave_read_first_lane(k_id * KRead);
aq_group_offset = amd_wave_read_first_lane(k_offset / AQuantGroupSize::kK);
bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK);

if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AQLayout>)
{
aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AQLayout>)
{
aq_k_split_offset = amd_wave_read_first_lane(aq_group_offset * kargs.stride_AQ);
}

if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BQLayout>)
{
const index_t stride_bq =
amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN));
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq);
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BQLayout>)
{
bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset);
}
}
else
{
aq_group_offset = 0;
aq_k_split_offset = 0;
}
}

index_t a_k_split_offset;
index_t b_k_split_offset;
index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension)
index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride)
index_t aq_group_offset; // Logical offset in AQ K-groups (K/kK dimension) for ABQuant
index_t aq_k_split_offset; // Memory pointer offset for AQ (accounting for layout/stride)
index_t bq_group_offset; // Logical offset in BQ K-groups (K/kK dimension)
index_t bq_k_split_offset; // Memory pointer offset for BQ (accounting for layout/stride)
index_t splitted_k;
};

Expand Down Expand Up @@ -531,6 +569,7 @@ struct QuantGemmKernel

CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr,
const QuantGemmKernelArgs& kargs,
const index_t aq_group_offset,
const index_t i_m,
const index_t i_n)
{
Expand Down Expand Up @@ -591,35 +630,19 @@ struct QuantGemmKernel

return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !APreshuffleQuant)
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
else // Column major AQ
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.QK_A, kargs.M),
make_tuple(kargs.stride_AQ, 1),
// For split-K with ABQuant, adjust AQ tensor dimension based on group offset
// This reflects "remaining K-groups from split-K offset position"
const index_t aq_k_groups = kargs.QK_A - aq_group_offset;

number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(kargs.M, aq_k_groups),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
Expand All @@ -628,9 +651,8 @@ struct QuantGemmKernel
{
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK_A),
make_tuple(1, kargs.stride_AQ),

make_tuple(aq_k_groups, kargs.M),
make_tuple(kargs.stride_AQ, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}
Expand Down Expand Up @@ -668,12 +690,19 @@ struct QuantGemmKernel
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_m_idx * tile_window_height, 0});
}
else if constexpr(kQuantType == QuantType::AQuantGrouped && !APreshuffleQuant)
else if constexpr((kQuantType == QuantType::AQuantGrouped ||
kQuantType == QuantType::ABQuantGrouped) &&
!APreshuffleQuant)
{

using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto aqk_per_block = TilePartitioner::KPerBlock / AQuantGroupSize::kK;
constexpr auto block_m = TilePartitioner::MPerBlock;

if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>,
"ABQuantGrouped requires RowMajor AQ layout");
}
if constexpr(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(aq_tensor_view,
Expand All @@ -687,16 +716,6 @@ struct QuantGemmKernel
{0, i_m});
}
}
else if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
constexpr auto block_m = TilePartitioner::MPerBlock;
constexpr auto block_k = TilePartitioner::KPerBlock;
return make_tile_window(
aq_tensor_view,
make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
{i_m, 0});
}
else if constexpr(kQuantType == QuantType::RowColQuant)
{
return make_tile_window(aq_tensor_view,
Expand Down Expand Up @@ -1121,17 +1140,21 @@ struct QuantGemmKernel

CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
{
// Split-K is supported for BQuantGrouped mode without preshuffle
// Split-K is supported for BQuantGrouped and ABQuantGrouped modes without preshuffle
if(kargs.k_batch != 1)
{
constexpr bool is_bquant_non_preshuffle =
(kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant;
if constexpr(!is_bquant_non_preshuffle)
constexpr bool is_abquant_non_preshuffle =
(kQuantType == QuantType::ABQuantGrouped) && !APreshuffleQuant && !BPreshuffleQuant;

if constexpr(!is_bquant_non_preshuffle && !is_abquant_non_preshuffle)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met for Kbatch >1 ! "
"Split-K only supported for BQuantGrouped without preshuffle.");
"Split-K only supported for BQuantGrouped or ABQuantGrouped "
"without preshuffle.");
}
return false;
}
Expand Down Expand Up @@ -1175,6 +1198,27 @@ struct QuantGemmKernel
}
return false;
}

// Constraint 3 (ABQuant only): KRead must also align with AQ group boundaries
// For ABQuantGrouped mode, we have two quantization tensors (AQ and BQ).
// Both must have their K-dimension aligned with the split-K batch size.
if constexpr(kQuantType == QuantType::ABQuantGrouped)
{
using AQuantGroupSize = remove_cvref_t<typename GemmPipeline::AQuantGroupSize>;
if(KRead % AQuantGroupSize::kK != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Split-K batch size must be aligned with AQ quantization group "
"size! KRead=" +
std::to_string(KRead) +
" is not divisible by AQuantGroupSize::kK=" +
std::to_string(AQuantGroupSize::kK));
}
return false;
}
}
}
}

Expand Down Expand Up @@ -1336,7 +1380,10 @@ struct QuantGemmKernel
MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
// Note: Pass aq_group_offset so the tensor view dimension reflects
// the remaining K-groups from the split-K offset position for ABQuant.
const auto& aq_block_window = MakeAQBlockWindow(
aq_ptr, kargs, splitk_batch_offset.aq_group_offset, block_idx_m, block_idx_n);
// Note: Pass bq_group_offset so the tensor view dimension reflects
// the remaining K-groups from the split-K offset position.
const auto& bq_block_window = MakeBQBlockWindow(
Expand Down Expand Up @@ -1466,7 +1513,20 @@ struct QuantGemmKernel
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);

// AQ pointer: Apply split-K offset for ABQuant mode
const AQDataType* aq_ptr;
if constexpr(kQuantType == QuantType::ABQuantGrouped && !APreshuffleQuant)
{
aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr) +
splitk_batch_offset.aq_k_split_offset;
}
else
{
aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
}

// BQ pointer: Apply split-K offset for BQuant and ABQuant modes
const BQDataType* bq_ptr =
static_cast<const BQDataType*>(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset;
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ struct QuantGroupedGemmKernel
Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& aq_block_window =
Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n);
const auto& aq_block_window = Base::MakeAQBlockWindow(
aq_ptr, kargs, splitk_batch_offset.aq_group_offset, block_idx_m, block_idx_n);
const auto& bq_block_window = Base::MakeBQBlockWindow(
bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,27 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})

# BQuant split-K tests (no preshuffle)
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode
test_gemm_quant_bquant_splitk_decode.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})

add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill
add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill
test_gemm_quant_bquant_splitk_prefill.cpp
)
target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})

# ABQuant split-K tests (no preshuffle, 2D BQ groups only)
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
test_gemm_quant_abquant_splitk_decode.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})

add_gtest_executable(test_tile_gemm_quant_abquant_splitk_prefill
test_gemm_quant_abquant_splitk_prefill.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})

# BQuant tests (with PreshuffleB) - split into 5 files
add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d
test_gemm_quant_bquant_preshuffle_decode_1d.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"

#include <gtest/gtest.h>
#include <memory>

#include "test_gemm_quant_fixtures.hpp"

// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
// Wrap the QuantType enumerator in a type so it can be carried inside the
// ::testing::Types tuples below (gtest typed tests take types, not values).
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;

// Quantization group shapes. The sequence is presumably <M, N, K> group
// extents (TODO confirm against QuantGroupShape): 1D = groups of 128 along K
// only, 2D = 128x128 tiles spanning both N and K.
using GroupSize1D = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;

// Type combinations for ABQuant split-K tests - Decode shape.
// Only use 2D BQ groups (1×128×128) to avoid AICK-644 bug.
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// Both entries use RowMajor A / ColumnMajor B with float quant scales and
// half-precision output; they differ only in the 8-bit element type (FP8 vs BF8).
// clang-format off
using ABQuantSplitKDecodeTypes = ::testing::Types<
// FP8 with 2D BQ groups
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigDecodeIntrawave, GroupSize1D, GroupSize2D128N, ColumnMajor>,
// BF8 with 2D BQ groups
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigDecodeIntrawave, GroupSize1D, GroupSize2D128N, ColumnMajor>
>;
// clang-format on

// Instantiate the TestCkTileGemmABQuant fixture once per tuple above.
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantSplitKDecodeTypes);

// ABQuant split-K tests
// ABQuant split-K with k_batch = 2: each split covers 1024 / 2 = 512 = 4 x 128,
// so the per-split K stays aligned to the 128-wide quant groups.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK2Test)
{
    constexpr int kM      = 32;
    constexpr int kN      = 128;
    constexpr int kK      = 1024;
    constexpr int kSplitK = 2;
    this->run_test_with_validation(kM, kN, kK, kSplitK);
}

// ABQuant split-K with k_batch = 3: each split covers 3072 / 3 = 1024 = 8 x 128,
// keeping every split aligned to the 128-wide quant groups.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK3Test)
{
    constexpr int kM      = 32;
    constexpr int kN      = 128;
    constexpr int kK      = 3072;
    constexpr int kSplitK = 3;
    this->run_test_with_validation(kM, kN, kK, kSplitK);
}

// ABQuant split-K with k_batch = 4: each split covers 2048 / 4 = 512 = 4 x 128,
// keeping every split aligned to the 128-wide quant groups.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK4Test)
{
    constexpr int kM      = 32;
    constexpr int kN      = 128;
    constexpr int kK      = 2048;
    constexpr int kSplitK = 4;
    this->run_test_with_validation(kM, kN, kK, kSplitK);
}

// ABQuant split-K with an odd k_batch = 5: each split covers 2560 / 5 = 512
// = 4 x 128, keeping every split aligned to the 128-wide quant groups.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedSplitK5Test)
{
    constexpr int kM      = 32;
    constexpr int kN      = 128;
    constexpr int kK      = 2560;
    constexpr int kSplitK = 5;
    this->run_test_with_validation(kM, kN, kK, kSplitK);
}
Loading
Loading