Skip to content

Commit 9d34781

Browse files
author
Sigrid Jin (Sionic AI)
committed
Merge remote-tracking branch 'origin/main' into jina-support
2 parents 5c45015 + 1bf6513 commit 9d34781

File tree

168 files changed

+7163
-6333
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

168 files changed

+7163
-6333
lines changed

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ export VLLM_XLA_CACHE_PATH=
7070
echo "Using VLLM V1"
7171
7272
echo "--- Hardware Information ---"
73-
tpu-info
73+
# tpu-info
7474
echo "--- Starting Tests ---"
7575
set +e
7676
overall_script_exit_code=0

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
4545
# requirements.txt files and should be kept consistent. The ROCm torch
4646
# versions are derived from docker/Dockerfile.rocm
4747
#
48-
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0")
48+
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1")
4949
set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0")
5050

5151
#

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def bench_run(
8080
a, score, topk, renormalize=False
8181
)
8282

83+
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
84+
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
85+
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
86+
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
87+
8388
def run_triton_moe(
8489
a: torch.Tensor,
8590
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def run_cutlass_moe(
111116
w2: torch.Tensor,
112117
w1_scale: torch.Tensor,
113118
w2_scale: torch.Tensor,
119+
ab_strides1: torch.Tensor,
120+
ab_strides2: torch.Tensor,
121+
c_strides1: torch.Tensor,
122+
c_strides2: torch.Tensor,
114123
topk_weights: torch.Tensor,
115124
topk_ids: torch.Tensor,
116125
per_act_token: bool,
@@ -125,6 +134,10 @@ def run_cutlass_moe(
125134
topk_ids,
126135
w1_scale,
127136
w2_scale,
137+
ab_strides1,
138+
ab_strides2,
139+
c_strides1,
140+
c_strides2,
128141
per_act_token,
129142
a1_scale=None,
130143
)
@@ -136,6 +149,10 @@ def run_cutlass_from_graph(
136149
w2_q: torch.Tensor,
137150
w1_scale: torch.Tensor,
138151
w2_scale: torch.Tensor,
152+
ab_strides1: torch.Tensor,
153+
ab_strides2: torch.Tensor,
154+
c_strides1: torch.Tensor,
155+
c_strides2: torch.Tensor,
139156
topk_weights: torch.Tensor,
140157
topk_ids: torch.Tensor,
141158
):
@@ -150,6 +167,10 @@ def run_cutlass_from_graph(
150167
topk_ids,
151168
w1_scale,
152169
w2_scale,
170+
ab_strides1,
171+
ab_strides2,
172+
c_strides1,
173+
c_strides2,
153174
per_act_token,
154175
a1_scale=None,
155176
)
@@ -194,6 +215,10 @@ def replay_graph(graph, num_repeats):
194215
w2_q,
195216
w1_scale,
196217
w2_scale,
218+
ab_strides1,
219+
ab_strides2,
220+
c_strides1,
221+
c_strides2,
197222
topk_weights,
198223
topk_ids,
199224
)
@@ -231,6 +256,10 @@ def replay_graph(graph, num_repeats):
231256
"w1_scale": w1_scale,
232257
"w2_scale": w2_scale,
233258
"per_act_token": per_act_token,
259+
"ab_strides1": ab_strides1,
260+
"ab_strides2": ab_strides2,
261+
"c_strides1": c_strides1,
262+
"c_strides2": c_strides2,
234263
# cuda graph params
235264
"cutlass_graph": cutlass_graph,
236265
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def replay_graph(graph, num_repeats):
289318
w2_q,
290319
w1_scale,
291320
w2_scale,
321+
ab_strides1,
322+
ab_strides2,
323+
c_strides1,
324+
c_strides2,
292325
topk_weights,
293326
topk_ids,
294327
per_act_token,
@@ -297,7 +330,7 @@ def replay_graph(graph, num_repeats):
297330

298331
results.append(
299332
benchmark.Timer(
300-
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
333+
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
301334
globals=globals,
302335
label=label,
303336
sub_label=sub_label,

benchmarks/kernels/benchmark_moe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,11 @@ def main(args: argparse.Namespace):
586586
topk = config.num_experts_per_tok
587587
intermediate_size = config.moe_intermediate_size
588588
shard_intermediate_size = 2 * intermediate_size // args.tp_size
589+
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
590+
E = config.num_experts
591+
topk = config.moe_topk[0]
592+
intermediate_size = config.moe_intermediate_size[0]
593+
shard_intermediate_size = 2 * intermediate_size // args.tp_size
589594
else:
590595
# Support for llama4
591596
config = config.get_text_config()

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
160160
}
161161
}
162162

163+
// Element-wise row gather: destination row `blockIdx.x` of `output` is filled
// from row `dst2src_map[blockIdx.x]` of `input`, one element at a time.
//
// Launch layout: 1-D grid indexed by destination row, 1-D block whose threads
// stride across the `num_cols` elements of that row. Blocks whose index is at
// or beyond `num_dst_rows` exit without touching memory. No shared memory is
// used.
//
// Parameters:
//   input        - source matrix, rows of `num_cols` contiguous elements.
//   dst2src_map  - for each destination row, the index of the source row to
//                  copy; read at index `blockIdx.x`.
//   output       - destination matrix, rows of `num_cols` contiguous elements.
//   num_src_rows - not read by this kernel.
//   num_dst_rows - number of destination rows to produce.
//   num_cols     - number of elements per row.
template <typename T>
__global__ void shuffleInputRowsKernelSlow(const T* input,
                                           const int32_t* dst2src_map,
                                           T* output, int64_t num_src_rows,
                                           int64_t num_dst_rows,
                                           int64_t num_cols) {
  int64_t const dest_row_idx = blockIdx.x;

  // Bounds-check BEFORE indexing dst2src_map so that a grid larger than
  // num_dst_rows never reads past the end of the map.
  if (dest_row_idx >= num_dst_rows) {
    return;
  }

  int64_t const source_row_idx = dst2src_map[dest_row_idx];

  // Duplicate and permute rows
  auto const* source_row_ptr = input + source_row_idx * num_cols;
  auto* dest_row_ptr = output + dest_row_idx * num_cols;

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = blockDim.x;

  // The counter is 64-bit to match num_cols; a 32-bit counter could wrap
  // before reaching the end of a very wide row.
  for (int64_t elem_index = start_offset; elem_index < num_cols;
       elem_index += stride) {
    dest_row_ptr[elem_index] = source_row_ptr[elem_index];
  }
}
186+
163187
void shuffle_rows(const torch::Tensor& input_tensor,
164188
const torch::Tensor& dst2src_map,
165189
torch::Tensor& output_tensor) {
@@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
173197
int64_t const num_src_rows = input_tensor.size(0);
174198
int64_t const num_cols = input_tensor.size(1);
175199

176-
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
177-
"num_cols must be divisible by 128 / "
178-
"sizeof(input_tensor.scalar_type()) / 8");
179-
180-
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
181-
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
182-
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
183-
dst2src_map.data_ptr<int32_t>(),
184-
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
185-
num_dest_rows, num_cols);
186-
});
200+
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
201+
// use slow kernel if num_cols can't be aligned to 128 bits
202+
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
203+
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
204+
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
205+
dst2src_map.data_ptr<int32_t>(),
206+
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
207+
num_dest_rows, num_cols);
208+
});
209+
} else {
210+
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
211+
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
212+
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
213+
dst2src_map.data_ptr<int32_t>(),
214+
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
215+
num_dest_rows, num_cols);
216+
});
217+
}
187218
}
188219

189220
#else

csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,36 @@ struct sm90_fp8_config_default {
2929

3030
template <typename InType, typename OutType,
3131
template <typename, typename, typename> typename Epilogue>
32-
struct sm90_fp8_config_M16 {
33-
// M in [1, 16]
32+
struct sm90_fp8_config_M4 {
33+
// M in [1, 4]
3434
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
3535
using KernelSchedule =
3636
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
3737
using EpilogueSchedule =
3838
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
39-
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
40-
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
39+
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
40+
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
4141

4242
using Cutlass3xGemm =
4343
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
44-
KernelSchedule, EpilogueSchedule>;
44+
KernelSchedule, EpilogueSchedule, true>;
45+
};
46+
47+
// SM90 FP8 grouped-GEMM configuration for problems with M in (4, 64].
// Exposes the assembled kernel type as `Cutlass3xGemm`.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // This configuration is only valid for e4m3 FP8 inputs.
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>::value);

  // Tile and cluster geometry.
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;

  // Mainloop and epilogue schedules.
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;

  // NOTE(review): the trailing `true` is presumably the swap_ab switch of
  // cutlass_3x_group_gemm — confirm against that template's declaration.
  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule, true>;
};
};
4663

4764
template <typename InType, typename OutType,
@@ -102,7 +119,9 @@ void run_cutlass_moe_mm_sm90(
102119
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
103120
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
104121
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
105-
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
122+
using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
123+
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
124+
using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
106125
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
107126
using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
108127
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@@ -111,18 +130,24 @@ void run_cutlass_moe_mm_sm90(
111130
uint32_t const n = out_tensors.size(1);
112131
uint32_t const k = a_tensors.size(1);
113132

114-
if (n >= 8192) {
115-
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
133+
// Use swap_ab for M <= 64 by default to reduce padding
134+
if (m <= 4) {
135+
cutlass_group_gemm_caller<Cutlass3xGemmM4>(
116136
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
117137
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
118138
per_out_ch);
119-
} else if (k >= 8192) {
120-
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
139+
} else if (m <= 64) {
140+
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
121141
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
122142
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
123143
per_out_ch);
124-
} else if (m <= 16) {
125-
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
144+
} else if (n >= 8192) {
145+
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
146+
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
147+
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
148+
per_out_ch);
149+
} else if (k >= 8192) {
150+
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
126151
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
127152
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
128153
per_out_ch);

0 commit comments

Comments
 (0)