Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,10 @@ struct GenRunner {
};

// Dispatch macros for different element types
// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types.
#define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \
[&] { \
if (DTYPE == at::kHalf) { \
using ELEMENT_TYPE = cutlass::half_t; \
return __VA_ARGS__(); \
} else if (DTYPE == at::kBFloat16) { \
using ELEMENT_TYPE = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} else if (DTYPE == at::kFloat8_e4m3fn) { \
if (DTYPE == at::kFloat8_e4m3fn) { \
using ELEMENT_TYPE = cutlass::float_e4m3_t; \
return __VA_ARGS__(); \
} else { \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,14 +535,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));

auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));

// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem

int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
}
pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp);
typename CollectiveMainloop::PipelineE pipeline_corr_epi(
shared_storage.pipelines.corr_epi,
pipeline_corr_epi_params,
Expand Down
Loading