Fixes for dp + ep + tp combinations #78

Open · wants to merge 171 commits into base branch modular-fused-experts

Commits (171)
9a6ee6b  moe refactoring (bnellnm, Apr 1, 2025)
82188dd  module deepgemm moe working (bnellnm, Apr 1, 2025)
fea4fbf  working deep gemm, wip cutlass (bnellnm, Apr 2, 2025)
d78f1c8  working cutlass (bnellnm, Apr 2, 2025)
8e3e6a9  deepgemm working again (bnellnm, Apr 2, 2025)
e41e4bf  cutlass working again (bnellnm, Apr 2, 2025)
0b877ba  cutlass working again (bnellnm, Apr 2, 2025)
5d76ee9  fix inplace, format and name cleanups (bnellnm, Apr 2, 2025)
c2ce01a  fix inplace, format + name cleanups (bnellnm, Apr 2, 2025)
b52b50d  test improvements (bnellnm, Apr 3, 2025)
49a9d11  make modular triton classes, fix edge cases (bnellnm, Apr 3, 2025)
bcac19a  fix outplace bug (bnellnm, Apr 3, 2025)
e2ab4f5  refactor dispatch/combine stuff (bnellnm, Apr 3, 2025)
ecaca4e  initial pplx dispatch/combine class (bnellnm, Apr 3, 2025)
9bcbde0  merge triton dispatch into standard, add some comments (bnellnm, Apr 3, 2025)
73847e0  format (bnellnm, Apr 3, 2025)
b136032  comments (bnellnm, Apr 3, 2025)
62584bf  fix linter (bnellnm, Apr 3, 2025)
cbdc471  fix more linter stuff (bnellnm, Apr 3, 2025)
8e2c5b2  cleanup for review (bnellnm, Apr 3, 2025)
cef98ab  review comments (bnellnm, Apr 4, 2025)
13da7ea  forgot return (bnellnm, Apr 4, 2025)
fb39d50  add dp_rank_num_tokens to DPMetadata (bnellnm, Apr 4, 2025)
9bac87a  better check for fp8 in _fp8_permute (bnellnm, Apr 4, 2025)
9882f97  updates (bnellnm, Apr 28, 2025)
cfcdb70  fix merge issues (bnellnm, Apr 29, 2025)
4664e0f  fix lint (bnellnm, Apr 29, 2025)
42f12d7  add pplx tests (bnellnm, Apr 29, 2025)
dc0a640  lint (bnellnm, Apr 29, 2025)
64acde9  undo random lint changes (bnellnm, Apr 29, 2025)
17e6e00  more lint (bnellnm, Apr 29, 2025)
1039851  more lint nonsense (bnellnm, Apr 29, 2025)
89de35f  WIP torch while (tlrmchlsmth, Mar 15, 2025)
2c12392  wip (tlrmchlsmth, Mar 25, 2025)
8c19435  wip (tlrmchlsmth, Mar 25, 2025)
49d2658  wip (tlrmchlsmth, Mar 27, 2025)
36bb880  wip (tlrmchlsmth, Mar 27, 2025)
09a9813  WIP integration (tlrmchlsmth, Mar 28, 2025)
af8fd7c  Add test for deep gemm matmul (bnellnm, Feb 26, 2025)
3ab5443  fix matmul test (bnellnm, Feb 27, 2025)
187eadf  running (bnellnm, Feb 27, 2025)
45fd37f  wip (bnellnm, Feb 27, 2025)
1c54fa9  wip (bnellnm, Feb 28, 2025)
8752d63  debugging (bnellnm, Feb 28, 2025)
9ac9bfe  debugging (bnellnm, Feb 28, 2025)
2724f05  fix (bnellnm, Feb 28, 2025)
e86cf1d  update deep gemm (bnellnm, Feb 28, 2025)
353687d  update deep gemm + small test case (bnellnm, Mar 1, 2025)
228c054  wip (bnellnm, Mar 2, 2025)
4439c53  wip (bnellnm, Mar 2, 2025)
487e319  problem with scores (bnellnm, Mar 2, 2025)
8d89dc2  some passing tests (bnellnm, Mar 3, 2025)
abf6171  some passing tests (bnellnm, Mar 3, 2025)
c09f42f  topk > 1 doesn't work. prune oom-ing tests (bnellnm, Mar 3, 2025)
2ffac31  fix indices (bnellnm, Mar 3, 2025)
4e81605  enable more tests (bnellnm, Mar 3, 2025)
9f21aa2  format (bnellnm, Mar 3, 2025)
10ba95d  use fused_topk for unit test (bnellnm, Mar 4, 2025)
a46f3d4  every other block correct (bnellnm, Mar 5, 2025)
4cf7770  working (bnellnm, Mar 5, 2025)
65a3ef3  enable more tests (bnellnm, Mar 5, 2025)
65ce6e7  working tests w/permute (bnellnm, Mar 5, 2025)
75b376c  cleanups (bnellnm, Mar 5, 2025)
416dec4  wip (bnellnm, Mar 6, 2025)
55d9efa  not crashing (bnellnm, Mar 6, 2025)
ae402f5  baseline working integration (bnellnm, Mar 6, 2025)
6587ea1  add allow_deep_gemm flag (bnellnm, Mar 6, 2025)
cc7ec3f  wip (bnellnm, Mar 7, 2025)
da0fd3e  better (bnellnm, Mar 7, 2025)
6b08ac7  fix some stuff (bnellnm, Mar 8, 2025)
caa58c0  fix more stuff (bnellnm, Mar 8, 2025)
78034ff  cleanups (bnellnm, Mar 8, 2025)
0549dc2  some integration tests working (bnellnm, Mar 8, 2025)
14d0569  almost all tests passing (bnellnm, Mar 10, 2025)
ac2a339  cleanup temp construction a bit (bnellnm, Mar 10, 2025)
d87b305  fix rest of tests (bnellnm, Mar 10, 2025)
7fcdd1c  cleanups + format (bnellnm, Mar 10, 2025)
ed3610e  do more of output computation in place (bnellnm, Mar 10, 2025)
e39f8c8  add env var (bnellnm, Mar 10, 2025)
adf85f1  formatting, remove some blocking restrictions (bnellnm, Mar 12, 2025)
8e93160  wip (bnellnm, Mar 12, 2025)
d81062b  fix resizing of output (bnellnm, Mar 12, 2025)
b2ea85c  fix resizing of output (bnellnm, Mar 12, 2025)
37053bd  fixes (bnellnm, Mar 12, 2025)
bcb245a  aligned chunking working for deep gemm (bnellnm, Mar 12, 2025)
f585c5d  unaligned chunking for deep gemm (bnellnm, Mar 13, 2025)
6dd17e5  cleanup wip (bnellnm, Mar 13, 2025)
e150caa  clean up some blocking stuff (bnellnm, Mar 13, 2025)
f4d5441  clean up some blocking stuff (bnellnm, Mar 13, 2025)
3b5f459  tweaks (bnellnm, Mar 14, 2025)
d8771fa  fix rebase (bnellnm, Mar 15, 2025)
00ad23a  rebase (bnellnm, Mar 17, 2025)
833182f  refactoring + minor perf improvements (bnellnm, Mar 21, 2025)
29add30  refactoring + perf tweaks (bnellnm, Mar 22, 2025)
b1f5fcf  remove debugging cruft (bnellnm, Mar 24, 2025)
2e19622  cache resize refactoring (bnellnm, Mar 24, 2025)
5d97022  cleanups (bnellnm, Mar 25, 2025)
0c343cf  format (bnellnm, Mar 25, 2025)
f60b4b3  revert test.txt, fix mypy errors (bnellnm, Mar 25, 2025)
856046b  review comments (bnellnm, Mar 26, 2025)
c7f3ddb  review comments (bnellnm, Mar 27, 2025)
f653358  clean up use_dg flags (bnellnm, Mar 27, 2025)
9391c66  remove check for aligned M (bnellnm, Mar 27, 2025)
2351edf  rebase + clean up test (bnellnm, Mar 28, 2025)
d0e81cc  fix format (bnellnm, Mar 28, 2025)
b5fb80c  Clean up diff (tlrmchlsmth, Mar 31, 2025)
204c4d5  [Distributed] Add custom allreduce support for ROCM (#14125) (ilmarkov, Apr 1, 2025)
ad77c5f  [Bugfix][Model] fix mllama multi-image (#14883) (yma11, Apr 1, 2025)
84782a1  module deepgemm moe working (bnellnm, Apr 1, 2025)
d88baaa  working deep gemm, wip cutlass (bnellnm, Apr 2, 2025)
bf9a833  working cutlass (bnellnm, Apr 2, 2025)
ab7ff87  deepgemm working again (bnellnm, Apr 2, 2025)
b1f59a8  fix inplace, format and name cleanups (bnellnm, Apr 2, 2025)
b9542bc  test improvements (bnellnm, Apr 3, 2025)
e974b59  make modular triton classes, fix edge cases (bnellnm, Apr 3, 2025)
1a7bdbd  refactor dispatch/combine stuff (bnellnm, Apr 3, 2025)
ca50521  initial pplx dispatch/combine class (bnellnm, Apr 3, 2025)
a5c8907  merge triton dispatch into standard, add some comments (bnellnm, Apr 3, 2025)
939ef2f  format (bnellnm, Apr 3, 2025)
65f4b55  cleanup for review (bnellnm, Apr 3, 2025)
2672f68  hacking (bnellnm, Apr 4, 2025)
a6df5b7  hacking (bnellnm, Apr 7, 2025)
bddffe7  init stuff (bnellnm, Apr 7, 2025)
1813ae4  call super ctor + fix random stuff (bnellnm, Apr 7, 2025)
d50afb6  fix use_ep bug (tlrmchlsmth, Apr 7, 2025)
207a373  Fix dp_size (tlrmchlsmth, Apr 7, 2025)
ea821e3  add comment (tlrmchlsmth, Apr 7, 2025)
e4acd18  fixes (tlrmchlsmth, Apr 7, 2025)
353151e  get a bit further (bnellnm, Apr 7, 2025)
70fc2a8  hacking in dispatch_combine (bnellnm, Apr 9, 2025)
3b319a1  hook up some wires (bnellnm, Apr 10, 2025)
792d751  seems to be working (bnellnm, Apr 10, 2025)
be24517  wip (bnellnm, Apr 11, 2025)
16092a5  batched moe test (bnellnm, Apr 14, 2025)
1d98c32  simple test (bnellnm, Apr 15, 2025)
0dfd27e  cleanup (bnellnm, Apr 15, 2025)
f6acee6  test pplx w/naive implementation (bnellnm, Apr 15, 2025)
c69354d  test pplx w/naive implementation (bnellnm, Apr 15, 2025)
4971b43  hack fix for chunking loop (bnellnm, Apr 15, 2025)
fedb2d2  wip. add pplx unit test (bnellnm, Apr 16, 2025)
46d09b7  work on unit test (bnellnm, Apr 17, 2025)
7db0061  dispatch/combine unit test (bnellnm, Apr 17, 2025)
cb7320d  forgot file (bnellnm, Apr 17, 2025)
fe1974a  somewhat working unit test (bnellnm, Apr 18, 2025)
86c2055  wip (bnellnm, Apr 18, 2025)
58fe406  fix test (bnellnm, Apr 18, 2025)
4fb31ef  some cleanup (bnellnm, Apr 19, 2025)
e0560d5  wip (bnellnm, Apr 19, 2025)
a876454  wip (bnellnm, Apr 29, 2025)
9396364  undo random changes (bnellnm, Apr 29, 2025)
47f32c7  merge (bnellnm, Apr 29, 2025)
00f8fb2  tweak (bnellnm, Apr 29, 2025)
fd4805f  revert hack (bnellnm, Apr 29, 2025)
be22c57  fixes (bnellnm, Apr 29, 2025)
9018df8  pplx update (bnellnm, Apr 29, 2025)
3433b73  varun's fixes (bnellnm, Apr 29, 2025)
800dde1  varun's fixes (bnellnm, Apr 29, 2025)
918e62b  tweak bound_m (bnellnm, Apr 29, 2025)
b6ae861  run linter (bnellnm, Apr 29, 2025)
448658a  more lint stuff (bnellnm, Apr 29, 2025)
c7ddca4  add guards for pplx import (bnellnm, Apr 30, 2025)
22b988a  fix forward_chunked (Apr 30, 2025)
c09cefd  fix more lint (bnellnm, Apr 30, 2025)
938c516  cleanups (bnellnm, Apr 30, 2025)
c0fc027  cleanups + lint, layer.py wip (bnellnm, Apr 30, 2025)
f74ab61  fix parallel_state lint (bnellnm, Apr 30, 2025)
3e8a0e3  fix M=1 pplx test (bnellnm, May 1, 2025)
886045e  fix M=1 pplx test (bnellnm, May 1, 2025)
5d960df  fix M=1 pplx test (bnellnm, May 1, 2025)
1014679  fixes (May 1, 2025)
ba8f478  zero out attn outputs during profile run (May 7, 2025)
176 changes: 176 additions & 0 deletions tests/kernels/moe/test_batched_moe.py
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import pytest
import torch
import triton.language as tl

from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel)


@dataclass
class BatchedMMConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, N, K], i.e. [E, K, N] in column-major layout
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        # Scale the random inputs down so fp16/bf16 accumulation error stays
        # small relative to the test tolerances.
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.dtype) / 50.0
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.dtype) / 50.0
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
    num_expert_tokens_cpu = num_expert_tokens.clone().to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    # Only the first num_expert_tokens[e] rows of each expert's slab are
    # valid; the remaining rows of C keep their zero initialization.
    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e]
        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

    return C


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [512])
@pytest.mark.parametrize("K", [256])
@pytest.mark.parametrize("N", [512])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):

config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)

test_output = tensors.C
ref_output = test_output.clone()

compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
test_output,
tensors.num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
None,
# Quantization schemes
False,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})

ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
#torch.cuda.synchronize()
#print (f"ref output {ref_output}")
#print (f"test output {test_output}")

torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)


@dataclass
class BatchedSiluMulConfig:
    dtype: torch.dtype
    num_experts: int
    max_tokens_per_expert: int
    D: int


@dataclass
class BatchedSiluMulTensors:
    input: torch.Tensor
    output: torch.Tensor
    expert_num_tokens: torch.Tensor

    @staticmethod
    def make_tensors(config: BatchedSiluMulConfig):
        input = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.D * 2),
            device="cuda",
            dtype=config.dtype) / 50.0
        output = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.D),
            device="cuda",
            dtype=config.dtype)
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedSiluMulTensors(input, output, num_expert_tokens)


def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor,
                         num_expert_tokens: torch.Tensor) -> None:
    num_expert_tokens_cpu = num_expert_tokens.clone().to(device="cpu")
    num_experts = num_expert_tokens.size(0)

    # Apply silu_and_mul in place on the valid rows of each expert's slab.
    for e in range(num_experts):
        num_tokens = num_expert_tokens_cpu[e].item()
        out_part = output[e, :num_tokens, :]
        in_part = input[e, :num_tokens, :]
        torch.ops._C.silu_and_mul(out_part, in_part)


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [128])
@pytest.mark.parametrize("D", [128, 256])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int,
                          dtype: torch.dtype):
    config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D)
    tensors = BatchedSiluMulTensors.make_tensors(config)

    test_out = tensors.output
    ref_out = torch.zeros_like(test_out)

    ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens)

    invoke_batched_silu_and_mul(test_out, tensors.input,
                                tensors.expert_num_tokens)

    torch.testing.assert_close(test_out, ref_out)
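For reference, a minimal standalone sketch (illustrative only, not part of this PR) of the ragged-batch convention both tests above assume: each expert e owns a [max_tokens, ...] slab, only the first num_expert_tokens[e] rows are valid, and untouched rows keep their zero initialization.

# Standalone sketch of the ragged-batch contract; runs on CPU with plain
# PyTorch, no vLLM kernels involved. All names here are local to the sketch.
import torch

E, max_tokens, K, N = 4, 8, 16, 32
A = torch.randn(E, max_tokens, K)
B = torch.randn(E, N, K)  # [E, N, K], i.e. [E, K, N] column major
C = torch.zeros(E, max_tokens, N)
num_expert_tokens = torch.randint(0, max_tokens, (E, ))

# Per-expert matmul over only the valid rows.
for e in range(E):
    n = int(num_expert_tokens[e])
    C[e, :n] = A[e, :n] @ B[e].transpose(0, 1)

# Rows past each expert's token count were never written and stay zero,
# which is why the tests compare full [E, max_tokens, N] tensors directly.
for e in range(E):
    n = int(num_expert_tokens[e])
    assert torch.all(C[e, n:] == 0)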
22 changes: 9 additions & 13 deletions tests/kernels/moe/test_cutlass_moe.py
@@ -30,6 +30,11 @@
     (224, 3072, 1536),
 ]
 
+vllm_config = VllmConfig(parallel_config=ParallelConfig(
+    pipeline_parallel_size=1))
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
 
 @dataclasses.dataclass
 class MOETensors:
@@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
         'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
         'topk_weights': topk_weights,
-        'topk_ids_': topk_ids,
+        'topk_ids': topk_ids,
         'ab_strides1': moe_tensors.ab_strides1,
         'c_strides1': moe_tensors.c_strides1,
         'ab_strides2': moe_tensors.ab_strides2,
@@ -231,10 +236,7 @@ def test_cutlass_moe_8_bit_no_graph(
     per_out_ch: bool,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
-
+    with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_ch)
@@ -276,10 +278,7 @@ def test_cutlass_moe_8_bit_cuda_graph(
     per_out_ch: bool,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
-
+    with set_current_vllm_config(vllm_config):
         dtype = torch.half
 
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
@@ -334,10 +333,7 @@ def test_cutlass_moe_8_bit_EP(
     ep_size: int,
 ):
     current_platform.seed_everything(7)
-    with set_current_vllm_config(
-            VllmConfig(parallel_config=ParallelConfig(
-                pipeline_parallel_size=1))):
-
+    with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_channel)
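The recurring change in this diff hoists config construction out of each test; a sketch of the resulting pattern, assuming VllmConfig, ParallelConfig, and set_current_vllm_config come from vllm.config (the test_example name below is hypothetical):

# Sketch of the pattern the diff applies: one VllmConfig built at module
# scope and shared by every test, instead of a fresh config inside each
# `with` statement. Assumes vLLM's vllm.config module; test name is made up.
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config

vllm_config = VllmConfig(parallel_config=ParallelConfig(
    pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def test_example():
    # Tensor construction and kernel calls all run under the shared config.
    with set_current_vllm_config(vllm_config):
        ...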