|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# DeepSpeed Team |
| 5 | + |
| 6 | +import torch |
| 7 | +import deepspeed.comm as dist |
| 8 | +from .sp_dp_registry import get_group, is_setup, sp_size |
| 9 | + |
| 10 | + |
| 11 | +@torch.library.custom_op("autosp::all_to_all", mutates_args=()) |
| 12 | +def all_to_all( |
| 13 | + input: torch.Tensor, |
| 14 | + scatter_idx: int, |
| 15 | + gather_idx: int, |
| 16 | + name: str, |
| 17 | +) -> torch.Tensor: |
| 18 | + """ |
| 19 | + All-to-all collective for SDPA tensors [B, N, S, H]. |
| 20 | +
|
| 21 | + For QKV (scatter_idx=1, gather_idx=2): |
| 22 | + [B, N, S/P, H] -> [B, N/P, S, H] |
| 23 | + For O (scatter_idx=2, gather_idx=1): |
| 24 | + [B, N/P, S, H] -> [B, N, S/P, H] |
| 25 | + """ |
| 26 | + assert is_setup(), 'Incorrect initialization of SP/DP mesh.' |
| 27 | + B, dim1, dim2, H = input.shape |
| 28 | + gid = dist.get_rank() // sp_size() |
| 29 | + group = get_group(gid) |
| 30 | + |
| 31 | + if scatter_idx == 1: |
| 32 | + N, local_S = dim1, dim2 |
| 33 | + input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H) |
| 34 | + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() |
| 35 | + |
| 36 | + output = torch.empty_like(input_t) |
| 37 | + dist.all_to_all_single(output, input_t, group=group) |
| 38 | + |
| 39 | + output = output.permute(1, 2, 0, 3, 4).contiguous() |
| 40 | + output = output.reshape(B, N // sp_size(), sp_size() * local_S, H) |
| 41 | + else: |
| 42 | + local_N, S = dim1, dim2 |
| 43 | + input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H) |
| 44 | + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() |
| 45 | + |
| 46 | + output = torch.empty_like(input_t) |
| 47 | + dist.all_to_all_single(output, input_t, group=group) |
| 48 | + |
| 49 | + output = output.permute(1, 0, 2, 3, 4).contiguous() |
| 50 | + output = output.reshape(B, sp_size() * local_N, S // sp_size(), H) |
| 51 | + |
| 52 | + return output |
| 53 | + |
| 54 | + |
| 55 | +@torch.library.register_fake("autosp::all_to_all") |
| 56 | +def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): |
| 57 | + B, dim1, dim2, H = input.shape |
| 58 | + if scatter_idx == 1: |
| 59 | + return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H) |
| 60 | + else: |
| 61 | + return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) |
| 62 | + |
| 63 | + |
| 64 | +def _all_to_all_backward_setup(ctx, inputs, output): |
| 65 | + _, scatter_idx, gather_idx, name = inputs |
| 66 | + ctx.scatter_idx = gather_idx |
| 67 | + ctx.gather_idx = scatter_idx |
| 68 | + ctx.name = name + "_grad" |
| 69 | + |
| 70 | + |
| 71 | +def _all_to_all_backward(ctx, grad): |
| 72 | + return (all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), None, None, None) |
| 73 | + |
| 74 | + |
| 75 | +torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup) |
0 commit comments