Skip to content

update DeepGEMM #10429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 202 additions & 90 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh

Large diffs are not rendered by default.

801 changes: 50 additions & 751 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh

Large diffs are not rendered by default.

33 changes: 23 additions & 10 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ enum class GemmType {
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumGroups,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
uint32_t kNum1DBlocksPerGroup = 16>
struct Scheduler {
int current_iter = -1;
uint32_t num_aligned_m_blocks;
Expand Down Expand Up @@ -61,16 +62,27 @@ struct Scheduler {
}

__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

// Swizzle for better L2 usages
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
// TODO: unify these 2 branches
if constexpr (kIsTMAMulticastOnA) {
auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
} else {
auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
n_block_idx = in_group_idx / num_m_blocks_in_group;
}
}

template <bool kIgnoreGroupedForGroupedContiguous=true>
Expand Down Expand Up @@ -116,6 +128,7 @@ struct Scheduler {
return true;
}
};

#pragma clang diagnostic pop

} // namespace deep_gemm
11 changes: 5 additions & 6 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType() {
}
}

PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
// Get pointer to `cuTensorMapEncodeTiled`
cudaDriverEntryPointQueryResult driver_status;
void* cuTensorMapEncodeTiled_ptr = nullptr;
Expand All @@ -81,16 +81,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
uint64_t stride_in_bytes, uint32_t smem_dim[2],
CUtensorMapSwizzle swizzle_type,
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
CUtensorMap tensor_map{};
constexpr uint32_t rank = 2;
uint64_t global_stride[rank - 1] = {stride_in_bytes};
uint32_t elem_strides[rank] = {1, 1};
CUtensorMap tensor_map = {};
uint64_t global_stride[1] = {stride_in_bytes};
uint32_t elem_strides[2] = {1, 1};

if (encode_func == nullptr)
encode_func = get_cuTensorMapEncodeTiled();

auto result = encode_func(
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
Expand Down
5 changes: 5 additions & 0 deletions ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,9 @@ do {
// Integer ceiling division: returns ceil(a / b) for integral T.
// Used for computing block/grid counts from shapes (e.g. SHAPE_N / BLOCK_N).
// NOTE(review): implemented as (a + b - 1) / b, which assumes b > 0 and can
// wrap when a + b - 1 exceeds T's maximum — confirm callers' value ranges
// stay well below the overflow threshold.
template <typename T>
__device__ __host__ constexpr T ceil_div(T a, T b) {
return (a + b - 1) / b;
}

// Greatest common divisor, usable in both host and device code and in
// constant expressions. Iterative Euclidean algorithm (valid in constexpr
// under C++14 relaxed rules, satisfied by this project's builds).
// Matches gcd(x, 0) == x: when b is 0 the current a is returned as-is.
template <typename T>
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
    while (b != 0) {
        T remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
4 changes: 2 additions & 2 deletions ops/csrc/fp8/deep_gemm/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

from .compiler import build, get_nvcc_compiler
from .runtime import Runtime
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .runtime import Runtime
31 changes: 14 additions & 17 deletions ops/csrc/fp8/deep_gemm/jit/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

import functools
import hashlib
import functools
import os
import re
import subprocess
Expand Down Expand Up @@ -75,9 +75,7 @@ def get_nvcc_compiler() -> Tuple[str, str]:
match = version_pattern.search(os.popen(f"{path} --version").read())
version = match.group(1)
assert match, f"Cannot get the version of NVCC compiler {path}"
assert (
version >= least_version_required
), f"NVCC {path} version {version} is lower than {least_version_required}"
assert version >= least_version_required, f"NVCC {path} version {version} is lower than {least_version_required}"
return path, version
raise RuntimeError("Cannot find any available NVCC compiler")

Expand Down Expand Up @@ -117,18 +115,13 @@ def put(path, data, is_binary=False):

def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Compiler flags
nvcc_flags = [
"-std=c++17",
"-shared",
"-O3",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"-gencode=arch=compute_90a,code=sm_90a",
"--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
"--diag-suppress=177,174,940",
]
cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi"]
cpp_standard = int(os.getenv("DG_NVCC_OVERRIDE_CPP_STANDARD", 20))
nvcc_flags = [f"-std=c++{cpp_standard}", "-shared", "-O3", "--expt-relaxed-constexpr", "--expt-extended-lambda",
"-gencode=arch=compute_90a,code=sm_90a",
"--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
"--diag-suppress=39,174,177,940"]
cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi", "-fconcepts"]
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]

Expand All @@ -155,8 +148,12 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Compile into a temporary SO file
so_path = f"{path}/kernel.so"
tmp_so_path = f"{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so"

# Compile
command = [get_nvcc_compiler()[0], src_path, "-o", tmp_so_path, *flags, *[f"-I{d}" for d in include_dirs]]
command = [get_nvcc_compiler()[0],
src_path, "-o", tmp_so_path,
*flags,
*[f"-I{d}" for d in include_dirs]]
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_JIT_PRINT_NVCC_COMMAND", False):
print(f"Compiling JIT runtime {name} with command {command}")
return_code = subprocess.check_call(command)
Expand Down
2 changes: 1 addition & 1 deletion ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def parse_registers(line):


def modify_segment(m, name, ffma_lines):
num_lines = len(ffma_lines)
num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
assert num_lines % 2 == 0

le_bytes, new_le_bytes = [], []
Expand Down
Loading
Loading