|
| 1 | +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import paddle |
| 15 | +import paddle.distributed as dist |
| 16 | +from paddle.distributed import fleet |
| 17 | +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( |
| 18 | + _check_environment_for_overlap, |
| 19 | +) |
| 20 | +from paddle.framework import core |
| 21 | + |
| 22 | +from paddlenlp.transformers.llama.modeling_auto import get_mesh |
| 23 | + |
| 24 | + |
| 25 | +def is_fused_matmul_bias_supported(): |
| 26 | + if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu(): |
| 27 | + return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue") |
| 28 | + else: |
| 29 | + return False |
| 30 | + |
| 31 | + |
| 32 | +ipp = None |
| 33 | +id2ipp = {} |
| 34 | + |
| 35 | +paddle_nn_functional_linear = paddle.nn.functional.linear |
| 36 | +if is_fused_matmul_bias_supported(): |
| 37 | + paddle_incubate_nn_functional_fused_linear = paddle.incubate.nn.functional.fused_linear |
| 38 | + |
| 39 | + |
| 40 | +# modify from Paddle/python/paddle/distributed/auto_parallel/moe_utils.py |
| 41 | +def _dist_reshape( |
| 42 | + dist_tensor, |
| 43 | + global_shape, |
| 44 | + mesh, |
| 45 | + placements, |
| 46 | +): |
| 47 | + local_tensor = dist_tensor._local_value() |
| 48 | + tgt_global_shape = [dist_tensor.shape[0] * dist_tensor.shape[1], dist_tensor.shape[2]] |
| 49 | + tgt_local_shape = [local_tensor.shape[0] * local_tensor.shape[1], local_tensor.shape[2]] |
| 50 | + |
| 51 | + place = paddle.framework._current_expected_place() |
| 52 | + place = paddle.framework._get_paddle_place(place) |
| 53 | + |
| 54 | + local_tensor = local_tensor.reshape(tgt_local_shape) |
| 55 | + |
| 56 | + if placements[1].is_shard(): |
| 57 | + new_placements = [dist.Shard(0), dist.Shard(1)] |
| 58 | + else: |
| 59 | + new_placements = [dist.Shard(0), dist.Replicate()] |
| 60 | + |
| 61 | + out = paddle.Tensor( |
| 62 | + local_tensor, |
| 63 | + dims=tgt_global_shape, |
| 64 | + process_mesh=mesh, |
| 65 | + placements=new_placements, |
| 66 | + place=place, |
| 67 | + ) |
| 68 | + out.stop_gradient = dist_tensor.stop_gradient |
| 69 | + return out |
| 70 | + |
| 71 | + |
| 72 | +if is_fused_matmul_bias_supported(): |
| 73 | + origin_linear = paddle.incubate.nn.functional.fused_linear |
| 74 | +else: |
| 75 | + origin_linear = paddle.nn.functional.linear |
| 76 | + |
| 77 | + |
| 78 | +class FusedLinearWithReduceScatter(paddle.autograd.PyLayer): |
| 79 | + @staticmethod |
| 80 | + def forward(ctx, x, weight, bias=None, name=None): |
| 81 | + global ipp |
| 82 | + input_parallel = dist.reshard( |
| 83 | + x, |
| 84 | + get_mesh(ipp), |
| 85 | + [dist.Shard(1), dist.Replicate()], |
| 86 | + ) |
| 87 | + y = origin_linear(input_parallel, weight, bias) |
| 88 | + ctx.save_for_backward(weight, bias, input_parallel) |
| 89 | + |
| 90 | + return y |
| 91 | + |
| 92 | + @staticmethod |
| 93 | + def backward(ctx, dy): |
| 94 | + weight, bias, input_parallel = ctx.saved_tensor() |
| 95 | + |
| 96 | + # compute dx |
| 97 | + if dy.dtype == weight.dtype: |
| 98 | + dinput_parallel = paddle.matmul(dy, weight, transpose_y=True) |
| 99 | + else: |
| 100 | + dinput_parallel = paddle.matmul(dy, paddle.cast(weight, dtype=dy.dtype), transpose_y=True) |
| 101 | + |
| 102 | + hcg = fleet.get_hybrid_communicate_group() |
| 103 | + model_parallel_group = hcg.get_model_parallel_group() |
| 104 | + parallelism = model_parallel_group.nranks |
| 105 | + |
| 106 | + assert ( |
| 107 | + dinput_parallel.shape[0] % parallelism == 0 |
| 108 | + ), f"Input sequence length {dinput_parallel.shape[0]} can't be divided exactly by sequence parallelism {parallelism}" |
| 109 | + |
| 110 | + # reduce-scatter dx |
| 111 | + dx_global_shape = dinput_parallel.shape |
| 112 | + dx_global_shape[0] = dx_global_shape[0] // parallelism |
| 113 | + dinput_parallel_local = dinput_parallel._local_value() |
| 114 | + dx_local_shape = dinput_parallel_local.shape |
| 115 | + dx_local_shape[0] = dx_local_shape[0] // parallelism |
| 116 | + dx_local = paddle.empty(shape=dx_local_shape, dtype=dinput_parallel.dtype) |
| 117 | + task = dist.stream.reduce_scatter( |
| 118 | + dx_local, |
| 119 | + dinput_parallel_local, |
| 120 | + op=dist.ReduceOp.SUM, |
| 121 | + group=model_parallel_group, |
| 122 | + sync_op=False, |
| 123 | + ) |
| 124 | + |
| 125 | + # compute dw and dbias |
| 126 | + _check_environment_for_overlap() |
| 127 | + dy = _dist_reshape(dy, [-1, dy.shape[-1]], dy.process_mesh, dy.placements) |
| 128 | + input_parallel = _dist_reshape( |
| 129 | + input_parallel, [-1, input_parallel.shape[-1]], input_parallel.process_mesh, input_parallel.placements |
| 130 | + ) |
| 131 | + dw = paddle.matmul( |
| 132 | + input_parallel, |
| 133 | + dy, |
| 134 | + transpose_x=True, |
| 135 | + ) |
| 136 | + if bias is None: |
| 137 | + task.wait() |
| 138 | + place = paddle.framework._current_expected_place() |
| 139 | + place = paddle.framework._get_paddle_place(place) |
| 140 | + |
| 141 | + dx = paddle.Tensor( |
| 142 | + dx_local, |
| 143 | + dims=dx_global_shape, |
| 144 | + process_mesh=dinput_parallel.process_mesh, |
| 145 | + placements=[dist.Shard(1), dist.Shard(0)], |
| 146 | + place=place, |
| 147 | + ) |
| 148 | + dx.stop_gradient = dx.stop_gradient |
| 149 | + return dx, dw |
| 150 | + else: |
| 151 | + dbias = paddle.sum(dy, axis=0) |
| 152 | + task.wait() |
| 153 | + place = paddle.framework._current_expected_place() |
| 154 | + place = paddle.framework._get_paddle_place(place) |
| 155 | + |
| 156 | + dx = paddle.Tensor( |
| 157 | + dx_local, |
| 158 | + dims=dx_global_shape, |
| 159 | + process_mesh=dinput_parallel.process_mesh, |
| 160 | + placements=[dist.Shard(1), dist.Shard(0)], |
| 161 | + place=place, |
| 162 | + ) |
| 163 | + dx.stop_gradient = dx.stop_gradient |
| 164 | + return dx, dw, dbias |
| 165 | + |
| 166 | + |
| 167 | +def forward_pre_hook(layer, input): |
| 168 | + paddle.nn.functional.linear = FusedLinearWithReduceScatter.apply |
| 169 | + if is_fused_matmul_bias_supported(): |
| 170 | + paddle.incubate.nn.functional.fused_linear = FusedLinearWithReduceScatter.apply |
| 171 | + global ipp, id2ipp |
| 172 | + ipp = id2ipp[id(layer)] |
| 173 | + |
| 174 | + |
| 175 | +def forward_post_hook(layer, input, ouput): |
| 176 | + paddle.nn.functional.linear = paddle_nn_functional_linear |
| 177 | + if is_fused_matmul_bias_supported(): |
| 178 | + paddle.incubate.nn.functional.fused_linear = paddle_incubate_nn_functional_fused_linear |
| 179 | + |
| 180 | + |
| 181 | +def mock_layers_sp_async_reduce_scatter(model): |
| 182 | + global ipp, id2ipp |
| 183 | + for name, layer in model.named_sublayers(): |
| 184 | + if name.endswith("self_attn") or name.startswith("mlp"): |
| 185 | + ipp = layer.ipp |
| 186 | + for n in ["qkv_proj", "q_proj", "k_proj", "v_proj", "gate_up_fused_proj", "gate_proj", "up_proj"]: |
| 187 | + if name.endswith(n): |
| 188 | + id2ipp[id(layer)] = ipp |
| 189 | + layer.register_forward_pre_hook(forward_pre_hook) |
| 190 | + layer.register_forward_post_hook(forward_post_hook) |
0 commit comments