Skip to content

Commit 5406b5e

Browse files
authored
update DeepGEMM (#10429)
* merge * Update m_grouped_gemm.py
1 parent 670cbd9 commit 5406b5e

File tree

12 files changed

+449
-964
lines changed

12 files changed

+449
-964
lines changed

ops/csrc/fp8/deep_gemm/include/deep_gemm/fp8_gemm.cuh

Lines changed: 202 additions & 90 deletions
Large diffs are not rendered by default.

ops/csrc/fp8/deep_gemm/include/deep_gemm/mma_utils.cuh

Lines changed: 50 additions & 751 deletions
Large diffs are not rendered by default.

ops/csrc/fp8/deep_gemm/include/deep_gemm/scheduler.cuh

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ enum class GemmType {
3030
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
3131
template <GemmType kGemmType,
3232
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
33-
uint32_t kNumGroups, uint32_t kNumTMAMulticast,
33+
uint32_t kNumGroups,
34+
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
3435
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
35-
uint32_t kNumNBlocksPerGroup = 16>
36+
uint32_t kNum1DBlocksPerGroup = 16>
3637
struct Scheduler {
3738
int current_iter = -1;
3839
uint32_t num_aligned_m_blocks;
@@ -61,16 +62,27 @@ struct Scheduler {
6162
}
6263

6364
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
64-
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
65+
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
6566

6667
// Swizzle for better L2 usages
67-
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
68-
auto group_idx = block_idx / num_blocks_per_group;
69-
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
70-
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
71-
auto in_group_idx = block_idx % num_blocks_per_group;
72-
m_block_idx = in_group_idx / num_n_blocks_in_group;
73-
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
68+
// TODO: unify these 2 branches
69+
if constexpr (kIsTMAMulticastOnA) {
70+
auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
71+
auto group_idx = block_idx / num_blocks_per_group;
72+
auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
73+
auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
74+
auto in_group_idx = block_idx % num_blocks_per_group;
75+
m_block_idx = in_group_idx / num_n_blocks_in_group;
76+
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
77+
} else {
78+
auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
79+
auto group_idx = block_idx / num_blocks_per_group;
80+
auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
81+
auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
82+
auto in_group_idx = block_idx % num_blocks_per_group;
83+
m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
84+
n_block_idx = in_group_idx / num_m_blocks_in_group;
85+
}
7486
}
7587

7688
template <bool kIgnoreGroupedForGroupedContiguous=true>
@@ -116,6 +128,7 @@ struct Scheduler {
116128
return true;
117129
}
118130
};
131+
119132
#pragma clang diagnostic pop
120133

121134
} // namespace deep_gemm

ops/csrc/fp8/deep_gemm/include/deep_gemm/tma_utils.cuh

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType() {
5858
}
5959
}
6060

61-
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
61+
inline PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
6262
// Get pointer to `cuTensorMapEncodeTiled`
6363
cudaDriverEntryPointQueryResult driver_status;
6464
void* cuTensorMapEncodeTiled_ptr = nullptr;
@@ -81,16 +81,15 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
8181
uint64_t stride_in_bytes, uint32_t smem_dim[2],
8282
CUtensorMapSwizzle swizzle_type,
8383
PFN_cuTensorMapEncodeTiled encode_func = nullptr) {
84-
CUtensorMap tensor_map{};
85-
constexpr uint32_t rank = 2;
86-
uint64_t global_stride[rank - 1] = {stride_in_bytes};
87-
uint32_t elem_strides[rank] = {1, 1};
84+
CUtensorMap tensor_map = {};
85+
uint64_t global_stride[1] = {stride_in_bytes};
86+
uint32_t elem_strides[2] = {1, 1};
8887

8988
if (encode_func == nullptr)
9089
encode_func = get_cuTensorMapEncodeTiled();
9190

9291
auto result = encode_func(
93-
&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
92+
&tensor_map, get_CUtensorMapDataType<std::remove_cv_t<T>>(), 2,
9493
global_address, gmem_dim, global_stride, smem_dim, elem_strides,
9594
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
9695
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,

ops/csrc/fp8/deep_gemm/include/deep_gemm/utils.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,9 @@ do {
6363
template <typename T>
6464
__device__ __host__ constexpr T ceil_div(T a, T b) {
6565
return (a + b - 1) / b;
66+
}
67+
68+
template <typename T>
69+
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
70+
return b == 0 ? a : constexpr_gcd(b, a % b);
6671
}

ops/csrc/fp8/deep_gemm/jit/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
# Copyright (c) 2025 DeepSeek
1717
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
1818

19-
from .compiler import build, get_nvcc_compiler
20-
from .runtime import Runtime
19+
from .compiler import get_nvcc_compiler, build
2120
from .template import cpp_format, generate
21+
from .runtime import Runtime

ops/csrc/fp8/deep_gemm/jit/compiler.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
# Copyright (c) 2025 DeepSeek
1717
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
1818

19-
import functools
2019
import hashlib
20+
import functools
2121
import os
2222
import re
2323
import subprocess
@@ -75,9 +75,7 @@ def get_nvcc_compiler() -> Tuple[str, str]:
7575
match = version_pattern.search(os.popen(f"{path} --version").read())
7676
version = match.group(1)
7777
assert match, f"Cannot get the version of NVCC compiler {path}"
78-
assert (
79-
version >= least_version_required
80-
), f"NVCC {path} version {version} is lower than {least_version_required}"
78+
assert version >= least_version_required, f"NVCC {path} version {version} is lower than {least_version_required}"
8179
return path, version
8280
raise RuntimeError("Cannot find any available NVCC compiler")
8381

@@ -117,18 +115,13 @@ def put(path, data, is_binary=False):
117115

118116
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
119117
# Compiler flags
120-
nvcc_flags = [
121-
"-std=c++17",
122-
"-shared",
123-
"-O3",
124-
"--expt-relaxed-constexpr",
125-
"--expt-extended-lambda",
126-
"-gencode=arch=compute_90a,code=sm_90a",
127-
"--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""),
128-
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
129-
"--diag-suppress=177,174,940",
130-
]
131-
cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi"]
118+
cpp_standard = int(os.getenv("DG_NVCC_OVERRIDE_CPP_STANDARD", 20))
119+
nvcc_flags = [f"-std=c++{cpp_standard}", "-shared", "-O3", "--expt-relaxed-constexpr", "--expt-extended-lambda",
120+
"-gencode=arch=compute_90a,code=sm_90a",
121+
"--ptxas-options=--register-usage-level=10" + (",--verbose" if "DG_PTXAS_VERBOSE" in os.environ else ""),
122+
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
123+
"--diag-suppress=39,174,177,940"]
124+
cxx_flags = ["-fPIC", "-O3", "-Wno-deprecated-declarations", "-Wno-abi", "-fconcepts"]
132125
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
133126
include_dirs = [get_jit_include_dir()]
134127

@@ -155,8 +148,12 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
155148
# Compile into a temporary SO file
156149
so_path = f"{path}/kernel.so"
157150
tmp_so_path = f"{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so"
151+
158152
# Compile
159-
command = [get_nvcc_compiler()[0], src_path, "-o", tmp_so_path, *flags, *[f"-I{d}" for d in include_dirs]]
153+
command = [get_nvcc_compiler()[0],
154+
src_path, "-o", tmp_so_path,
155+
*flags,
156+
*[f"-I{d}" for d in include_dirs]]
160157
if os.getenv("DG_JIT_DEBUG", None) or os.getenv("DG_JIT_PRINT_NVCC_COMMAND", False):
161158
print(f"Compiling JIT runtime {name} with command {command}")
162159
return_code = subprocess.check_call(command)

ops/csrc/fp8/deep_gemm/jit/interleave_ffma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def parse_registers(line):
9393

9494

9595
def modify_segment(m, name, ffma_lines):
96-
num_lines = len(ffma_lines)
96+
num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
9797
assert num_lines % 2 == 0
9898

9999
le_bytes, new_le_bytes = [], []

0 commit comments

Comments
 (0)