DeepGEMMPerTensor is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) per tensor, without scaling factors. It supports both normal and Mixture-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library requires no compilation during installation: all kernels are compiled at runtime using a lightweight Just-In-Time (JIT) module.
Currently, DeepGEMMPerTensor exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from CUTLASS and CuTe, it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques.
Despite its lightweight design, DeepGEMMPerTensor's performance matches or exceeds that of expert-tuned libraries across various matrix shapes.
- Hopper architecture GPUs, `sm_90a` must be supported
- Python 3.8 or above
- CUDA 12.3 or above
  - But we highly recommend 12.8 or above for the best performance
- PyTorch 2.1 or above
- CUTLASS 3.6 or above (could be cloned by Git submodule)
```bash
# Submodule must be cloned
git clone https://github.com/Bruce-Lee-LY/DeepGEMMPerTensor.git

# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop

# Test JIT compilation
python tests/test_jit.py

# Test all GEMM implementations (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
```

To install the library:

```bash
python setup.py install
```

Then, import `deep_gemm_per_tensor` in your Python project, and enjoy!
This library exclusively contains GEMM kernels. Transposition and other FP8 casting operations should be implemented or fused into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance; our primary focus is on optimizing the GEMM kernels themselves.
To perform a basic non-grouped FP8 GEMM, call the `deep_gemm_per_tensor.gemm_per_tensor_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation.
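A minimal usage sketch is shown below. The tensor shapes, the FP8/BF16 dtypes, and the `(lhs, rhs, out)` argument order are illustrative assumptions; consult the function documentation for the authoritative signature.

```python
import torch
import deep_gemm_per_tensor

# Assumed NT layout: LHS is (m, k), RHS is (n, k) (the transposed operand),
# output is (m, n) in BF16; no scaling factors are involved (per-tensor, no scales).
m, n, k = 128, 4096, 7168
lhs = torch.randn(m, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
rhs = torch.randn(n, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

# D = A @ B^T, with CUDA-core two-level accumulation handled inside the kernel
deep_gemm_per_tensor.gemm_per_tensor_fp8_fp8_bf16_nt(lhs, rhs, out)
```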
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMMPerTensor groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape.
For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`).
For more information, please refer to the `m_grouped_gemm_per_tensor_fp8_fp8_bf16_nt_contiguous` function documentation.
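The sketch below illustrates how the contiguous layout can be built: each expert's segment is padded up to the M alignment before concatenation. The `m_indices` argument name and the overall signature are assumptions for illustration and may differ from the actual API.

```python
import torch
import deep_gemm_per_tensor as dg

align = dg.get_m_alignment_for_contiguous_layout()   # GEMM M block size
num_groups, n, k = 4, 4096, 7168

# Hypothetical per-expert token counts, each padded up to the alignment
tokens_per_expert = [96, 300, 0, 128]
padded = [(t + align - 1) // align * align for t in tokens_per_expert]
total_m = sum(padded)

lhs = torch.randn(total_m, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
rhs = torch.randn(num_groups, n, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
out = torch.empty(total_m, n, device='cuda', dtype=torch.bfloat16)

# Map each row of the concatenated LHS to the expert (group) that owns it
m_indices = torch.cat([
    torch.full((p,), g, device='cuda', dtype=torch.int32) for g, p in enumerate(padded)
])

dg.m_grouped_gemm_per_tensor_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
```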
During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions.
Use `m_grouped_gemm_per_tensor_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation.
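A hedged sketch of the masked path follows. The batched tensor shapes and the `masked_m` argument are assumptions borrowed from similar grouped-GEMM APIs, not the confirmed signature.

```python
import torch
import deep_gemm_per_tensor as dg

num_groups, max_m, n, k = 4, 256, 4096, 7168

lhs = torch.randn(num_groups, max_m, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
rhs = torch.randn(num_groups, n, k, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
out = torch.empty(num_groups, max_m, n, device='cuda', dtype=torch.bfloat16)

# Number of valid rows per expert, produced on the GPU so the CPU (and a
# captured CUDA graph) never needs to know the real token counts.
masked_m = torch.tensor([64, 17, 0, 200], device='cuda', dtype=torch.int32)

dg.m_grouped_gemm_per_tensor_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m)
```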
The library provides some utility functions besides the above kernels (a short usage sketch follows the list):
- `deep_gemm_per_tensor.set_num_sms`: set the maximum SM count to use
- `deep_gemm_per_tensor.get_num_sms`: get the current SM maximum count
- `deep_gemm_per_tensor.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
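For example (the setter/getter semantics follow the descriptions above; the concrete values are device-dependent):

```python
import deep_gemm_per_tensor as dg

# Cap GEMM kernels at 120 SMs, e.g. to leave headroom for communication kernels
dg.set_num_sms(120)
print(dg.get_num_sms())  # expected to report the cap set above

# Pad each expert segment to a multiple of this value in the contiguous layout
print(dg.get_m_alignment_for_contiguous_layout())
```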
The library also provides some environment variables, which may be useful; a combined example follows the list:
- General
  - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
- JIT cache related
  - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm_per_tensor` by default
  - `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of the cache directory, `0` by default
- NVCC/NVRTC selections
  - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC; faster compilation, but may have lower performance for some cases, `0` by default
  - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default
- Compiler options
  - `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old-version GCC compilers, `20` by default
  - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
  - `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default
  - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default
- Post optimization
  - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
- Heuristic selection
  - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- Testing
  - `DG_NSYS_PROFILING`: `0` or `1`, Nsight Systems compatible testing, `0` by default
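As an example of combining these variables, the snippet below sets a custom cache directory and enables verbose JIT output before the library is imported (assuming the variables are read when the JIT module is first used):

```python
import os

# Must be set before the first kernel is JIT-compiled
os.environ["DG_JIT_CACHE_DIR"] = "/tmp/dg_jit_cache"   # custom kernel cache directory
os.environ["DG_JIT_PRINT_COMPILER_COMMAND"] = "1"      # echo each NVCC command
os.environ["DG_JIT_DEBUG"] = "1"                       # verbose JIT logging

import deep_gemm_per_tensor  # noqa: E402  (imported after the variables are set)
```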
For additional examples and details, please refer to the test code or review the corresponding Python documentation.
We mark the techniques that are not included in CUTLASS with π³.
Following the CUTLASS design, the kernels in DeepGEMMPerTensor are warp-specialized, enabling data movement, tensor-core MMA instructions, and CUDA-core promotion to overlap with one another. A simplified figure illustrating this process is shown below:
The Tensor Memory Accelerator (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. Specifically, we utilize TMA for:
- TMA load for LHS and RHS matrices
- TMA store for the output matrix
- TMA multicast (automatically deciding whether to broadcast LHS or RHS)
- TMA descriptor prefetching
- Utilization of the `stmatrix` PTX instruction
- Register count control tailored for different warpgroups
- Fewer bank conflicts via 3D TMA or swizzling
- Larger block sizes (up to 256x128 π³)
- One scheduler for all non-grouped and grouped kernels
- Rasterization to enhance L2 cache reuse
DeepGEMMPerTensor employs a fully Just-In-Time (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages:
- GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants
  - Saving registers
  - Compilers may do more optimizations
- Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size
  - But without auto-tuning, the optimal one is deterministically selected
- Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities
  - Very important for small shapes
  - Refer to `launch_k_iterations` in the kernel file for details
Overall, JIT significantly improves performance for small shapes, similar to the approach of the Triton compiler.
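A minimal sketch of the general compile-and-cache pattern is shown below. It is not the library's actual JIT module: the kernel template, the cache-key fields, and the placeholder compilation step are all illustrative assumptions.

```python
import functools

# Hypothetical kernel template: shapes, block sizes and the stage count are
# baked in as compile-time constants rather than passed at runtime.
KERNEL_TEMPLATE = """
constexpr int kM = {m}, kN = {n}, kK = {k};
constexpr int kBlockM = {block_m}, kBlockN = {block_n};
constexpr int kNumStages = {num_stages};
// ... kernel body with a fully unrolled MMA pipeline ...
"""

@functools.lru_cache(maxsize=None)  # in-memory analogue of the on-disk JIT cache
def build_kernel(m, n, k, block_m, block_n, num_stages):
    source = KERNEL_TEMPLATE.format(m=m, n=n, k=k, block_m=block_m,
                                    block_n=block_n, num_stages=num_stages)
    # Placeholder for the real NVCC/NVRTC compilation; returning the rendered
    # source keeps this sketch self-contained and runnable.
    return source

# Each distinct (shape, config) tuple is compiled once and reused afterwards.
kernel = build_kernel(m=128, n=4096, k=7168, block_m=128, block_n=128, num_stages=5)
```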
For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with M=256, N=7168, a typical block size assignment of BLOCK_M=128, BLOCK_N=128 results in only (256 / 128) * (7168 / 128) = 112 out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling (256 / 128) * (7168 / 112) = 128 SMs to work in such scenarios.
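The arithmetic above can be checked directly (132 is the SM count quoted in the example; each output tile occupies one SM per wave):

```python
def num_blocks(m, n, block_m, block_n):
    # Ceiling division on both axes gives the number of output tiles
    return -(-m // block_m) * -(-n // block_n)

print(num_blocks(256, 7168, 128, 128))  # 112 tiles -> only 112 of 132 SMs busy
print(num_blocks(256, 7168, 128, 112))  # 128 tiles -> 128 of 132 SMs busy
```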
We observe a performance improvement in the CUTLASS FP8 kernel between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in a series of FADD instructions is flipped in an interleaving pattern.
After referencing some open-source CUDA assembler implementations, we identified that this bit controls yielding, which may enhance warp-level parallelism (just a guess: yielding the current warp lets other warps work).
To leverage this, we develop a similar script to modify the FFMA instructions in the compiled binary. Besides simply modifying the yield bit, we also flip the reuse bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion FFMA instructions.
DeepGEMMPerTensor is inspired by the DeepGEMM and CUTLASS projects. Thanks and respect to the developers!
This code repository is released under the MIT License.
```bibtex
@misc{deepgemm2025,
      title={DeepGEMMPerTensor: clean and efficient FP8 GEMM per tensor kernels without scales},
      author={Bruce-Lee-LY},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/Bruce-Lee-LY/DeepGEMMPerTensor}},
}
```

- Continuous Optimization
