
[Perf] Reduce memory usage by splitting tokens in fused_experts #1729

Open
wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
@@ -33,6 +33,7 @@ The following table lists the additional configuration options available in vLLM
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `fused_moe_max_chunk_size` | int | `max_num_batched_tokens * data_parallel_size` | The maximum token chunk size for the fused MoE operation. Input exceeding this size is split into multiple chunks for processing. |
Collaborator comment:

Need to test different cases of data_parallel_size to make sure this change works as expected.


The details of each config option are as follows:

@@ -76,6 +77,7 @@ An example of additional configuration is as follows:
"enable_chunked_prefill": True,
},
"expert_tensor_parallel_size": 1,
"fused_moe_max_chunk_size": 8192,
"refresh": False,
}
```
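For context, here is a minimal sketch of how the new option could be passed at engine construction time. It assumes vLLM's `additional_config` engine argument is forwarded to the Ascend plugin as in the example above; the model path and chunk value are placeholders, not recommendations.

```python
# Minimal sketch, assuming additional_config reaches AscendConfig as in the
# docs example above; the model path and chunk value are illustrative only.
from vllm import LLM

llm = LLM(
    model="path/to/your-moe-model",  # placeholder model
    additional_config={
        # Cap the number of tokens handled per fused_experts call.
        "fused_moe_max_chunk_size": 8192,
    },
)
outputs = llm.generate(["Hello, MoE!"])
```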
68 changes: 68 additions & 0 deletions tests/e2e/singlecard/ops/test_fused_moe.py
@@ -98,3 +98,71 @@ def test_fused_experts(
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
torch.npu.empty_cache()


@pytest.mark.parametrize("chunk_size", [256, 4096])
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_split_fused_experts(
chunk_size: int,
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
device: str,
):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10

score = torch.randn((m, e), device=device, dtype=dtype)

if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device=device,
dtype=torch.int32)
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None

score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)

a_clone = a.clone()
a_list = a_clone.split(chunk_size)
topk_weights_list = topk_weights.split(chunk_size)
topk_ids_list = topk_ids.split(chunk_size)
num_chunks = len(a_list)
assert num_chunks == len(topk_weights_list) == len(topk_ids_list)
output = a_clone
for i in range(len(a_list)):
h_chunk = fused_experts(a_list[i], w1, w2, topk_weights_list[i],
topk_ids_list[i], topk, e_map)
if num_chunks > 1:
# use inplace copy to save memory
a_list[i].copy_(h_chunk)
del h_chunk
else:
# num_chunks == 1, return the result directly
output = h_chunk
break

torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
torch.npu.empty_cache()
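The test (and the op change below) relies on the fact that `torch.split` returns views into the parent tensor, so copying each chunk's result back in place reuses the input buffer instead of allocating a second full-size output. A minimal standalone sketch in plain PyTorch, with illustrative shapes:

```python
# Minimal sketch of the in-place split-and-copy trick; shapes are illustrative.
import torch

x = torch.randn(10, 4)
chunks = x.split(4)          # views of shape (4, 4), (4, 4), (2, 4)
for chunk in chunks:
    result = chunk * 2.0     # stands in for one fused_experts call
    chunk.copy_(result)      # writes the result back into x's storage
    del result               # drop the temporary before the next chunk

assert chunks[0].data_ptr() == x.data_ptr()  # the chunks alias x, no new buffer
```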
5 changes: 5 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -50,6 +50,11 @@ def __init__(self, vllm_config):
self.expert_map_path = additional_config.get("expert_map_path", None)
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
dp_size = vllm_config.parallel_config.data_parallel_size
self.fused_moe_max_chunk_size = int(
additional_config.get("fused_moe_max_chunk_size",
max_num_tokens * dp_size))


class TorchairGraphConfig:
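As a quick illustration of the fallback above, the default chunk size is simply `max_num_batched_tokens * data_parallel_size`; the concrete numbers below are example values, not anything mandated by the config.

```python
# Illustrative arithmetic for the default shown above; example values only.
max_num_batched_tokens = 4096   # scheduler_config.max_num_batched_tokens
data_parallel_size = 2          # parallel_config.data_parallel_size
default_chunk_size = max_num_batched_tokens * data_parallel_size
print(default_chunk_size)       # 8192: at most this many tokens per fused_experts call
```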
32 changes: 24 additions & 8 deletions vllm_ascend/ops/fused_moe.py
@@ -953,7 +953,7 @@
self.max_model_len = vllm_config.model_config.max_model_len

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.fused_moe_max_chunk_size = ascend_config.fused_moe_max_chunk_size

try:
device_group = self.ep_group.device_group
@@ -1049,13 +1049,29 @@
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
max_chunk_size = self.fused_moe_max_chunk_size
x_list = x.split(max_chunk_size)
topk_weights_list = topk_weights.split(max_chunk_size)
topk_ids_list = topk_ids.split(max_chunk_size)
num_chunks = len(x_list)
assert num_chunks == len(topk_weights_list) == len(topk_ids_list)
for i in range(len(x_list)):
hidden_states_chunk = fused_experts(
hidden_states=x_list[i],
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights_list[i],
topk_ids=topk_ids_list[i],
top_k=top_k,
expert_map=expert_map)
if num_chunks > 1:
# use inplace copy to save memory
x_list[i].copy_(hidden_states_chunk)
del hidden_states_chunk
else:
# num_chunks == 1, return the result directly
return hidden_states_chunk
return x
elif MOE_ALL2ALL_BUFFER:
return fused_experts_with_all2all_buffer(
hidden_states=x,
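To make the memory argument concrete, here is a rough back-of-the-envelope sketch. It assumes the dominant intermediate is the gate/up projection output of roughly `(tokens * top_k, 2 * n)` elements, inferred from the `w1`/`w13` weight layout `(e, 2 * n, k)` used in the test above; the kernel's actual workspace may differ.

```python
# Rough, illustrative estimate of the intermediate activation that chunking
# bounds; the (tokens * top_k, 2 * n) shape is an assumption based on the
# w1/w13 layout (e, 2 * n, k), not an exact accounting of the NPU kernel.
def moe_intermediate_bytes(tokens: int, top_k: int, n: int,
                           bytes_per_elem: int = 2) -> int:
    return tokens * top_k * (2 * n) * bytes_per_elem

full = moe_intermediate_bytes(tokens=128 * 1024, top_k=8, n=2048)
per_chunk = moe_intermediate_bytes(tokens=8192, top_k=8, n=2048)
print(f"unchunked ~{full / 2**30:.1f} GiB, per 8192-token chunk ~{per_chunk / 2**30:.2f} GiB")
```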
36 changes: 26 additions & 10 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -634,7 +634,7 @@
self.ep_group = get_ep_group()

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.fused_moe_max_chunk_size = ascend_config.fused_moe_max_chunk_size

try:
device_group = self.ep_group.device_group
@@ -783,15 +783,31 @@
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
max_chunk_size = self.fused_moe_max_chunk_size
x_list = x.split(max_chunk_size)
topk_weights_list = topk_weights.split(max_chunk_size)
topk_ids_list = topk_ids.split(max_chunk_size)
num_chunks = len(x_list)
assert num_chunks == len(topk_weights_list) == len(topk_ids_list)
for i in range(num_chunks):
hidden_states_chunk = fused_experts(
hidden_states=x_list[i],
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights_list[i],
topk_ids=topk_ids_list[i],
top_k=top_k,
expert_map=expert_map)
if num_chunks > 1:
# use inplace copy to save memory
x_list[i].copy_(hidden_states_chunk)
del hidden_states_chunk
else:
# num_chunks == 1, return the result directly
return hidden_states_chunk
return x
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are fed into the fused_moe module.
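Since the unquantized path and this w8a8 path now repeat the same split-and-copy loop, a shared helper could factor it out. The sketch below is purely illustrative: `chunked_apply` and its signature are hypothetical and not part of this PR or the repository.

```python
# Hypothetical refactor sketch (not in this PR): run any fused_experts-like
# callable over token chunks and write the results back into x in place.
from typing import Callable

import torch


def chunked_apply(fn: Callable[..., torch.Tensor], x: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  max_chunk_size: int, **kwargs) -> torch.Tensor:
    x_list = x.split(max_chunk_size)
    weights_list = topk_weights.split(max_chunk_size)
    ids_list = topk_ids.split(max_chunk_size)
    if len(x_list) == 1:
        # Single chunk: return the kernel output directly, no extra copy.
        return fn(hidden_states=x, topk_weights=topk_weights,
                  topk_ids=topk_ids, **kwargs)
    for x_chunk, w_chunk, id_chunk in zip(x_list, weights_list, ids_list):
        out = fn(hidden_states=x_chunk, topk_weights=w_chunk,
                 topk_ids=id_chunk, **kwargs)
        x_chunk.copy_(out)  # split() returns views, so this updates x in place
        del out
    return x
```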