
Commit 054565c

henrylhtsang authored and facebook-github-bot committed

Restrict to fp8, and patch 4.2.0 release changes

Differential Revision: D82792657

1 parent 53f9e51 · commit 054565c

File tree

3 files changed: +6 −11 lines changed


fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu

Lines changed: 2 additions & 7 deletions

@@ -262,15 +262,10 @@ struct GenRunner {
 };
 
 // Dispatch macros for different element types
+// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types.
 #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \
   [&] { \
-    if (DTYPE == at::kHalf) { \
-      using ELEMENT_TYPE = cutlass::half_t; \
-      return __VA_ARGS__(); \
-    } else if (DTYPE == at::kBFloat16) { \
-      using ELEMENT_TYPE = cutlass::bfloat16_t; \
-      return __VA_ARGS__(); \
-    } else if (DTYPE == at::kFloat8_e4m3fn) { \
+    if (DTYPE == at::kFloat8_e4m3fn) { \
       using ELEMENT_TYPE = cutlass::float_e4m3_t; \
       return __VA_ARGS__(); \
     } else { \
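
For readers unfamiliar with this dispatch pattern: the macro appears to wrap its body in an immediately-invoked lambda that binds the CUTLASS element type and calls the supplied functor, in the style of PyTorch's AT_DISPATCH macros. A minimal usage sketch under that assumption; the function and kernel names below are hypothetical, not from this commit:

    // Hypothetical caller. After this commit only fp8 (e4m3) dispatches;
    // every other dtype falls through to the final else branch.
    at::Tensor gen_attention_fwd(const at::Tensor& q) {
      return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] {
        // Inside the lambda, Element is cutlass::float_e4m3_t.
        return run_gen_kernel<Element>(q);  // hypothetical kernel launcher
      });
    }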

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp

Lines changed: 3 additions & 3 deletions

@@ -535,14 +535,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
     Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
 
-    auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
+    auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
     Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
     tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
     // Each thread owns a single row
-    using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
-    using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
     using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
 
     int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
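
Two things change in this hunk: size<1> replaces get<1> (in CuTe, size<1> yields the integral extent of mode 1 even when that mode is itself a nested shape, whereas get<1> returns the raw shape element), and the TMEM copy atoms are now selected at compile time so that QK tiles narrower than 128 columns use the 8-column atoms. A self-contained sketch of the same conditional_t selection pattern, using stand-in types rather than the real CUTLASS atoms:

    #include <type_traits>

    // Stand-ins for the CUTLASS TMEM copy atoms named in the diff.
    struct SM100_TMEM_LOAD_32dp32b8x {};   // narrower atom: 8 columns per access
    struct SM100_TMEM_LOAD_32dp32b32x {};  // wider atom: 32 columns per access

    // Compile-time pick, mirroring
    // conditional_t<size<1>(TileShapeQK{}) < _128{}, ...> from the diff.
    template <int TileN>
    using TMEM_LOAD = std::conditional_t<(TileN < 128),
                                         SM100_TMEM_LOAD_32dp32b8x,
                                         SM100_TMEM_LOAD_32dp32b32x>;

    static_assert(std::is_same_v<TMEM_LOAD<64>,  SM100_TMEM_LOAD_32dp32b8x>);
    static_assert(std::is_same_v<TMEM_LOAD<128>, SM100_TMEM_LOAD_32dp32b32x>);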

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp

Lines changed: 1 addition & 1 deletion

@@ -366,7 +366,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
       pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
     }
     pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
-    pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
+    pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp);
     typename CollectiveMainloop::PipelineE pipeline_corr_epi(
         shared_storage.pipelines.corr_epi,
         pipeline_corr_epi_params,
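
This one-line change guards the degenerate configuration where NumWarpsEpilogue is 0: a pipeline barrier whose consumer arrival count is 0 would likely never synchronize correctly, so the count is clamped to at least 1. A tiny illustration of the clamp in plain C++; the constant is a stand-in for the CUTLASS value:

    #include <algorithm>

    constexpr int kNumThreadsPerWarp = 32;  // stand-in for cutlass::NumThreadsPerWarp

    // Mirrors cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp).
    constexpr int consumer_arv_count(int num_warps_epilogue) {
      return std::max(1, num_warps_epilogue * kNumThreadsPerWarp);
    }

    static_assert(consumer_arv_count(0) == 1);   // zero-warp config stays valid
    static_assert(consumer_arv_count(2) == 64);  // normal case unchanged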
