# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 DeepSeek
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    import paddle.distributed.communication.deep_ep as deep_ep

    HAVE_DEEP_EP = True
except ImportError:
    HAVE_DEEP_EP = False

import paddle
from paddle.autograd import PyLayer
from paddle.distributed.communication.group import Group

# Module-level cache for the DeepEP communication buffer, reused across calls.
_buffer = None


def get_hidden_bytes(x: paddle.Tensor) -> int:
    """Calculate the number of bytes occupied by the hidden dimension of one token.

    Args:
        x (paddle.Tensor): Input tensor of shape [num_tokens, hidden_size]

    Returns:
        int: Number of hidden bytes per token
    """
    # Reserve at least 2 bytes per element, so sub-2-byte dtypes (e.g. FP8)
    # still get a buffer sized for a 16-bit representation.
    return x.shape[1] * max(x.element_size(), 2)
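
# Worked example: with hidden_size = 4096 and a BF16 tensor (2-byte elements),
# get_hidden_bytes returns 4096 * 2 = 8192; with an FP8 tensor (1-byte
# elements), max(x.element_size(), 2) still reserves 8192 bytes.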


def get_buffer(group: Group, hidden_bytes: int):
    """Get or create a buffer for all-to-all communication.

    Args:
        group (Group): Process group for communication
        hidden_bytes (int): Number of hidden bytes needed

    Returns:
        deep_ep.Buffer: Communication buffer
    """
    global _buffer
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (
        deep_ep.Buffer.get_dispatch_config(group.world_size),
        deep_ep.Buffer.get_combine_config(group.world_size),
    ):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.world_size), num_nvl_bytes)
        # TODO(umiswing): support internode
        # num_rdma_bytes = max(
        #     config.get_rdma_buffer_size_hint(hidden_bytes, group.world_size), num_rdma_bytes
        # )

    # Allocate a new buffer if none exists, the group changed, or the existing
    # buffer is too small.
    # NOTES: the adaptive routing configuration of the network **must be off**
    if (
        _buffer is None
        or _buffer.group != group
        or _buffer.num_nvl_bytes < num_nvl_bytes
        or _buffer.num_rdma_bytes < num_rdma_bytes
    ):
        _buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer

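# Usage note (a minimal sketch; the group, shapes, and dtype below are
# illustrative placeholders): the cached buffer is reused as long as the group
# matches and the requested sizes still fit, so repeated calls are cheap.
#
#   group = paddle.distributed.new_group(list(range(8)))
#   x = paddle.randn([1024, 4096], dtype="bfloat16")
#   buf = get_buffer(group, get_hidden_bytes(x))   # allocates on first call
#   buf2 = get_buffer(group, get_hidden_bytes(x))  # returns the cached buffer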

class FusedDispatch(PyLayer):
    """Fused dispatch operation for MoE routing, combining computation and communication."""

    @staticmethod
    def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None):
        """Forward pass of fused dispatch."""
        # Calculate the layout before the actual dispatch.
        # NOTE: this synchronous path passes previous_event=None below and does
        # not chain on the caller's event.
        buffer = get_buffer(group, get_hidden_bytes(x))
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = buffer.get_dispatch_layout(
            token_indices,
            num_experts,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        # Do the MoE dispatch
        # NOTES: the CPU will wait for the GPU's signal to arrive,
        # so this is not compatible with CUDA graph
        (
            recv_x,
            recv_token_indices,
            recv_token_probs,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        ) = buffer.dispatch(
            x,
            topk_idx=token_indices,
            topk_weights=token_probs.cast(paddle.float32),
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        ctx.group = group
        ctx.handle = handle
        ctx.event = event
        tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list)

        states = dict()
        states["dispatched_indices"] = recv_token_indices
        states["tokens_per_expert"] = tokens_per_expert
        states["handle"] = handle

        return recv_x, recv_token_probs, states

    @staticmethod
    def backward(ctx, grad_output, grad_token_probs):
        """Backward pass of fused dispatch: a combine that routes gradients back."""
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        handle = ctx.handle

        grad_x, grad_token_probs, event = buffer.combine(
            grad_output.contiguous(),
            handle,
            topk_weights=grad_token_probs.cast(paddle.float32),
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )
        # Gradients for (x, token_indices, token_probs); token_indices is
        # integer-valued and receives no gradient.
        return grad_x, None, grad_token_probs


class FusedCombine(PyLayer):
    """Fused combine operation for MoE output, combining computation and communication."""

    @staticmethod
    def forward(ctx, x, group, states, previous_event=None):
        """Forward pass of fused combine."""
        handle = states["handle"]
        buffer = get_buffer(group, get_hidden_bytes(x))
        combined_x, _, event = buffer.combine(
            x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False
        )
        ctx.handle = handle
        ctx.group = group
        ctx.previous_event = previous_event

        return combined_x

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass of fused combine: a dispatch that scatters gradients back out."""
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        grad_x, _, _, _, _, event = buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            previous_event=ctx.previous_event,
            async_finish=False,
            allocate_on_comm_stream=False,
        )
        return grad_x

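# Dispatch and combine are communication transposes of each other, so a
# dispatch/combine round trip is differentiable end to end. A minimal sketch
# (the tensors and `num_experts` below are illustrative placeholders):
#
#   recv_x, recv_probs, states = FusedDispatch.apply(
#       x, token_indices, token_probs, num_experts, group
#   )
#   y = FusedCombine.apply(recv_x, group, states)
#   y.mean().backward()  # gradients flow back through combine, then dispatch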

if HAVE_DEEP_EP:

    def fused_dispatch(x, token_indices, token_probs, num_experts, group: Group, previous_event=None):
        """Perform the fused dispatch operation if deep_ep is available.

        Args:
            x: Input tensor [num_tokens, hidden_size]
            token_indices: Token routing indices [num_tokens, topk]
            token_probs: Token routing probabilities [num_tokens, topk]
            num_experts: Number of experts
            group: Process group
            previous_event: Previous CUDA event

        Returns:
            Result of FusedDispatch: (recv_x, recv_token_probs, states)
        """
        return FusedDispatch.apply(x.contiguous(), token_indices, token_probs, num_experts, group, previous_event)

    def fused_combine(x, group, handle, previous_event=None):
        """Perform the fused combine operation if deep_ep is available.

        Args:
            x: Input tensor
            group: Process group
            handle: Communication handle
            previous_event: Previous CUDA event

        Returns:
            Result of FusedCombine: the combined tensor
        """
        states = dict()
        states["handle"] = handle
        return FusedCombine.apply(x, group, states, previous_event)

else:
    fused_dispatch = None
    fused_combine = None
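

# --- Illustrative usage (a minimal sketch, not part of this module) ---
# Everything named below (`router_logits`, `topk`, `run_experts`, the shapes)
# is a hypothetical placeholder; in practice the MoE layer owns these calls.
#
#   if HAVE_DEEP_EP:
#       group = paddle.distributed.new_group(list(range(paddle.distributed.get_world_size())))
#       probs, indices = paddle.topk(router_logits, k=topk, axis=-1)   # [num_tokens, topk]
#       recv_x, recv_probs, states = fused_dispatch(x, indices, probs, num_experts, group)
#       expert_out = run_experts(recv_x, states["tokens_per_expert"])  # hypothetical expert FFN
#       y = fused_combine(expert_out, group, states["handle"])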