
Commit a381674

Add moe flex dispatcher (#9977)
* add wip flex dispatcher
* wip fix PyLayer segfault
* remove redundant args
* refine
* add group as args
* add license, fix a bug
* refine codestyle
* refine
1 parent 2c1b106 commit a381674


3 files changed: +603 −0 lines changed


paddlenlp/transformers/fused_a2a.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 DeepSeek
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    import paddle.distributed.communication.deep_ep as deep_ep

    HAVE_DEEP_EP = True
except ImportError:
    HAVE_DEEP_EP = False

import paddle
from paddle.autograd import PyLayer
from paddle.distributed.communication.group import Group

_buffer = None


def get_hidden_bytes(x: paddle.Tensor) -> int:
    """Calculate the number of hidden bytes for a tensor.

    Args:
        x (paddle.Tensor): Input tensor

    Returns:
        int: Number of hidden bytes
    """
    return x.shape[1] * max(x.element_size(), 2)


def get_buffer(group: Group, hidden_bytes: int):
    """Get or create a buffer for all-to-all communication.

    Args:
        group (paddle.distributed.ProcessGroup): Process group for communication
        hidden_bytes (int): Number of hidden bytes needed

    Returns:
        Buffer: Communication buffer
    """
    global _buffer
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (
        deep_ep.Buffer.get_dispatch_config(group.world_size),
        deep_ep.Buffer.get_combine_config(group.world_size),
    ):
        # Split long line for PEP8 compliance
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.world_size), num_nvl_bytes)
        # TODO(umiswing): support internode
        # num_rdma_bytes = max(
        #     config.get_rdma_buffer_size_hint(hidden_bytes, group.world_size), num_rdma_bytes
        # )

    # Allocate buffer if not existed or not enough buffer
    # NOTES: the adaptive routing configuration of the network **must be off**
    if (
        _buffer is None
        or _buffer.group != group
        or _buffer.num_nvl_bytes < num_nvl_bytes
        or _buffer.num_rdma_bytes < num_rdma_bytes
    ):
        _buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


class FusedDispatch(PyLayer):
    """Fused dispatch operation for MoE routing combining computation and communication."""

    @staticmethod
    def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None):
        """Forward pass of fused dispatch."""
        # Calculate layout before actual dispatch
        buffer = get_buffer(group, get_hidden_bytes(x))
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = buffer.get_dispatch_layout(
            token_indices,
            num_experts,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        # Do MoE dispatch
        # NOTES: the CPU will wait for GPU's signal to arrive,
        # so this is not compatible with CUDA graph
        (
            recv_x,
            recv_token_indices,
            recv_token_probs,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        ) = buffer.dispatch(
            x,
            topk_idx=token_indices,
            topk_weights=token_probs.cast(paddle.float32),
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        ctx.group = group
        ctx.handle = handle
        ctx.event = event
        tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list)

        states = dict()
        states["dispatched_indices"] = recv_token_indices
        states["tokens_per_expert"] = tokens_per_expert
        states["handle"] = handle

        return recv_x, recv_token_probs, states

    @staticmethod
    def backward(ctx, grad_output, grad_token_probs):
        """Backward pass of fused dispatch."""
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        handle = ctx.handle

        grad_x, grad_token_probs, event = buffer.combine(
            grad_output.contiguous(),
            handle,
            topk_weights=grad_token_probs.cast(paddle.float32),
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )
        return grad_x, None, grad_token_probs


class FusedCombine(PyLayer):
    """Fused combine operation for MoE output combining computation and communication."""

    @staticmethod
    def forward(ctx, x, group, states, previous_event=None):
        """Forward pass of fused combine."""
        handle = states["handle"]
        buffer = get_buffer(group, get_hidden_bytes(x))
        combined_x, _, event = buffer.combine(
            x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False
        )
        ctx.handle = handle
        ctx.group = group
        ctx.previous_event = previous_event

        return combined_x

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass of fused combine."""
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        grad_x, _, _, _, _, event = buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            previous_event=ctx.previous_event,
            async_finish=False,
            allocate_on_comm_stream=False,
        )
        return grad_x


if HAVE_DEEP_EP:

    def fused_dispatch(x, token_indices, token_probs, num_experts, group: Group, previous_event=None):
        """Perform fused dispatch operation if deep_ep is available.

        Args:
            x: Input tensor [num_tokens, hidden_size]
            token_indices: Token routing indices [num_tokens, topk]
            token_probs: Token routing probabilities [num_tokens, topk]
            num_experts: Number of experts
            group: Process group
            previous_event: Previous CUDA event

        Returns:
            Result of FusedDispatch
        """
        return FusedDispatch.apply(x.contiguous(), token_indices, token_probs, num_experts, group, previous_event)

    def fused_combine(x, group, handle, previous_event=None):
        """Perform fused combine operation if deep_ep is available.

        Args:
            x: Input tensor
            group: Process group
            handle: Communication handle
            previous_event: Previous CUDA event

        Returns:
            Result of FusedCombine
        """
        states = dict()
        states["handle"] = handle
        return FusedCombine.apply(x, group, states, previous_event)

else:
    fused_dispatch = None
    fused_combine = None
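
The module exposes fused_dispatch and fused_combine only when deep_ep imports successfully; otherwise both names are None. Below is a minimal usage sketch that is not part of the commit: the process-group setup, router, shapes, and dtype are illustrative assumptions, and running it requires a multi-GPU environment with deep_ep installed.

# Hypothetical usage sketch -- shapes, router, and group setup are assumptions, not part of the commit.
import paddle
import paddle.distributed as dist

from paddlenlp.transformers.fused_a2a import fused_combine, fused_dispatch

dist.init_parallel_env()
group = dist.new_group(ranks=list(range(dist.get_world_size())))

num_tokens, hidden_size, num_experts, topk = 8, 128, 4, 2
# Assumption: DeepEP's normal-mode kernels are typically fed bfloat16 activations.
x = paddle.randn([num_tokens, hidden_size]).cast("bfloat16")
router_logits = paddle.randn([num_tokens, num_experts])
token_probs, token_indices = paddle.topk(
    paddle.nn.functional.softmax(router_logits, axis=-1), k=topk, axis=-1
)

# All-to-all dispatch: each rank receives the tokens routed to its local experts.
recv_x, recv_probs, states = fused_dispatch(x, token_indices, token_probs, num_experts, group)

expert_out = recv_x  # placeholder for the local expert computation
# All-to-all combine: expert outputs are sent back to the ranks that own the tokens.
combined = fused_combine(expert_out, group, states["handle"])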

paddlenlp/transformers/moe_utils.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 DeepSeek
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import paddle


def permute(
    tokens,
    routing_map,
    num_out_tokens: Optional[int] = None,
    drop_and_pad: bool = False,
):
    """Permute the tokens and probs based on the mask.
    Tokens with the same designated expert will be grouped together.
    The shape of mask is [tokens, num_experts], it indicates which experts were selected
    by each token.

    Args:
        tokens (paddle.Tensor): The input token tensor, [num_tokens, hidden].
        routing_map (paddle.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
        num_out_tokens (int, optional): The number of output tokens. If None, it's set to
            the number of input tokens.
        drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
            and pads the number of tokens to the expert capacity.
    """
    assert not drop_and_pad, "token-drop and pads is not supported"
    num_tokens, hidden = tokens.shape
    num_experts = routing_map.shape[1]

    # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
    routing_map = routing_map.cast(paddle.bool).T.contiguous()

    # Create a dense expert-to-token mapping from the sparse token-to-expert mapping
    token_indices = paddle.arange(num_tokens).unsqueeze(0).expand([num_experts, -1])
    sorted_indices = token_indices.masked_select(routing_map)

    # use the mapping to permute the tokens
    permuted_input = tokens.index_select(axis=0, index=sorted_indices)

    return permuted_input, sorted_indices


def unpermute(
    permuted_tokens: paddle.Tensor,
    sorted_indices: paddle.Tensor,
    restore_shape: paddle.shape,
    probs: paddle.Tensor = None,
    routing_map: paddle.Tensor = None,
    drop_and_pad: bool = False,
):
    """
    Restore the original order of tokens after permutation. If probs are provided, it
    will also apply them to the tokens before restoring the order.

    Args:
        permuted_tokens (paddle.Tensor): The permuted token tensor.
        sorted_indices (paddle.Tensor): The indices used to sort the tokens.
        restore_shape (paddle.shape): The shape of the unpermuted tensor.
        probs (paddle.Tensor, optional): The unpermuted probs tensor.
        routing_map (paddle.Tensor, optional): Token to expert mapping, shape
            [num_tokens, num_experts].
        drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
            and pads the number of tokens to the expert capacity.

    Returns:
        paddle.Tensor: The tokens restored to their original order.
    """
    assert not drop_and_pad, "token-drop and pads is not supported"
    _, hidden = restore_shape

    if probs is not None:
        assert routing_map is not None, "Mask must be provided to permute the probs."
        permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
        permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)

    # Create an output tensor filled with zeros
    output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype)
    # Scatter add the permuted_input back to the original positions
    output_tokens.put_along_axis_(
        axis=0,
        indices=sorted_indices.unsqueeze(1).expand([-1, hidden]),
        values=permuted_tokens,
        reduce="add",
        include_self=True,
    )
    return output_tokens
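
For reference, here is a small self-contained sketch of how permute and unpermute round-trip a routing map; the values are illustrative and not part of the commit.

# Hypothetical sketch -- the routing map and shapes below are made-up examples.
import paddle

from paddlenlp.transformers.moe_utils import permute, unpermute

num_tokens, hidden, num_experts = 4, 8, 3
tokens = paddle.randn([num_tokens, hidden])

# Boolean token-to-expert mask: token 1 is routed to two experts, the rest to one.
routing_map = paddle.to_tensor(
    [[1, 0, 0],
     [0, 1, 1],
     [1, 0, 0],
     [0, 0, 1]]
).cast("bool")

# Per-token routing probabilities that sum to 1 over the selected experts.
probs = routing_map.cast("float32")
probs = probs / probs.sum(axis=-1, keepdim=True)

# Group token copies by expert, then scatter-add them back weighted by probs.
permuted, sorted_indices = permute(tokens, routing_map)
restored = unpermute(permuted, sorted_indices, tokens.shape, probs=probs, routing_map=routing_map)
# With probs normalized per token, `restored` should match `tokens` up to float error.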
