
Commit 5c7c1bd

Add sp_async_reduce_scatter in dygraph auto mode
1 parent 5c482b6 commit 5c7c1bd

File tree: 3 files changed, +213 -4 lines

  llm/auto_parallel/llama/run_pretrain_auto.py
  llm/utils/sp_async_reduce_scatter.py
  paddlenlp/trainer/training_args.py

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 22 additions & 4 deletions

@@ -86,7 +86,8 @@ class PreTrainingArguments(AutoTrainingArguments):
     )
     sr: Optional[int] = field(default=0, metadata={"help": "The count of chunks without recompute."})
     virtual_pipeline_seg_method: str = field(
-        default="LlamaDecoderLayerAuto", metadata={"help": "The seg method of splitting pp layer for virtual pipeline."}
+        default="LlamaDecoderLayerAuto",
+        metadata={"help": "The seg method of splitting pp layer for virtual pipeline."},
     )
     # NOTE(gongenlei): new add autotuner_benchmark
     autotuner_benchmark: bool = field(
@@ -449,8 +450,14 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    if training_args.enable_linear_fused_grad_add:
-        from fused_layers import mock_layers
+    do_enable_sp_async_reduce_scatter = (
+        training_args.enable_auto_parallel
+        and training_args.tensor_parallel_degree > 1
+        and training_args.sequence_parallel
+        and "enable_sp_async_reduce_scatter" in training_args.tensor_parallel_config
+    )
+    if training_args.enable_linear_fused_grad_add and not do_enable_sp_async_reduce_scatter:
+        from llm.utils.fused_layers import mock_layers

         mock_layers()

@@ -557,7 +564,11 @@ def main():

     print("Final pre-training config:", config)

-    if "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config and config.tensor_parallel_degree > 1 and config.to_static is False:
+    if (
+        "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config
+        and config.tensor_parallel_degree > 1
+        and config.to_static is False
+    ):
         from llm.utils.replace_ops import replace_cross_entropy

         replace_cross_entropy()
@@ -582,6 +593,13 @@ def fn(layer):

     model.apply(fn)

+    if do_enable_sp_async_reduce_scatter:
+        from llm.utils.sp_async_reduce_scatter import (
+            mock_layers_sp_async_reduce_scatter,
+        )
+
+        mock_layers_sp_async_reduce_scatter(model)
+
     # Create the learning_rate scheduler and optimizer
     if training_args.decay_steps is None:
         training_args.decay_steps = training_args.max_steps
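Note: the new overlap path is taken only when every condition in the do_enable_sp_async_reduce_scatter gate above holds. A minimal sketch of trainer-argument values that would satisfy it (the field names come from the diff; the concrete tensor-parallel degree and the assignment style are illustrative assumptions, not part of the commit):

    # Illustrative only: values that satisfy the gate above (assumptions, not from the commit).
    training_args.enable_auto_parallel = True   # dygraph auto-parallel mode
    training_args.tensor_parallel_degree = 4    # any degree > 1
    training_args.sequence_parallel = True      # sequence parallelism enabled
    training_args.tensor_parallel_config = "enable_sp_async_reduce_scatter"

With these settings, run_pretrain_auto.py bypasses the fused_layers mock (even if enable_linear_fused_grad_add is set) and instead calls mock_layers_sp_async_reduce_scatter(model) after the model is built, as shown in the hunk above.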

llm/utils/sp_async_reduce_scatter.py

Lines changed: 190 additions & 0 deletions

@@ -0,0 +1,190 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.distributed as dist
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+    _check_environment_for_overlap,
+)
+from paddle.framework import core
+
+from paddlenlp.transformers.llama.modeling_auto import get_mesh
+
+
+def is_fused_matmul_bias_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu():
+        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
+    else:
+        return False
+
+
+ipp = None
+id2ipp = {}
+
+paddle_nn_functional_linear = paddle.nn.functional.linear
+if is_fused_matmul_bias_supported():
+    paddle_incubate_nn_functional_fused_linear = paddle.incubate.nn.functional.fused_linear
+
+
+# modify from Paddle/python/paddle/distributed/auto_parallel/moe_utils.py
+def _dist_reshape(
+    dist_tensor,
+    global_shape,
+    mesh,
+    placements,
+):
+    local_tensor = dist_tensor._local_value()
+    tgt_global_shape = [dist_tensor.shape[0] * dist_tensor.shape[1], dist_tensor.shape[2]]
+    tgt_local_shape = [local_tensor.shape[0] * local_tensor.shape[1], local_tensor.shape[2]]
+
+    place = paddle.framework._current_expected_place()
+    place = paddle.framework._get_paddle_place(place)
+
+    local_tensor = local_tensor.reshape(tgt_local_shape)
+
+    if placements[1].is_shard():
+        new_placements = [dist.Shard(0), dist.Shard(1)]
+    else:
+        new_placements = [dist.Shard(0), dist.Replicate()]
+
+    out = paddle.Tensor(
+        local_tensor,
+        dims=tgt_global_shape,
+        process_mesh=mesh,
+        placements=new_placements,
+        place=place,
+    )
+    out.stop_gradient = dist_tensor.stop_gradient
+    return out
+
+
+if is_fused_matmul_bias_supported():
+    origin_linear = paddle.incubate.nn.functional.fused_linear
+else:
+    origin_linear = paddle.nn.functional.linear
+
+
+class FusedLinearWithReduceScatter(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, x, weight, bias=None, name=None):
+        global ipp
+        input_parallel = dist.reshard(
+            x,
+            get_mesh(ipp),
+            [dist.Shard(1), dist.Replicate()],
+        )
+        y = origin_linear(input_parallel, weight, bias)
+        ctx.save_for_backward(weight, bias, input_parallel)
+
+        return y
+
+    @staticmethod
+    def backward(ctx, dy):
+        weight, bias, input_parallel = ctx.saved_tensor()
+
+        # compute dx
+        if dy.dtype == weight.dtype:
+            dinput_parallel = paddle.matmul(dy, weight, transpose_y=True)
+        else:
+            dinput_parallel = paddle.matmul(dy, paddle.cast(weight, dtype=dy.dtype), transpose_y=True)
+
+        hcg = fleet.get_hybrid_communicate_group()
+        model_parallel_group = hcg.get_model_parallel_group()
+        parallelism = model_parallel_group.nranks
+
+        assert (
+            dinput_parallel.shape[0] % parallelism == 0
+        ), f"Input sequence length {dinput_parallel.shape[0]} can't be divided exactly by sequence parallelism {parallelism}"
+
+        # reduce-scatter dx
+        dx_global_shape = dinput_parallel.shape
+        dx_global_shape[0] = dx_global_shape[0] // parallelism
+        dinput_parallel_local = dinput_parallel._local_value()
+        dx_local_shape = dinput_parallel_local.shape
+        dx_local_shape[0] = dx_local_shape[0] // parallelism
+        dx_local = paddle.empty(shape=dx_local_shape, dtype=dinput_parallel.dtype)
+        task = dist.stream.reduce_scatter(
+            dx_local,
+            dinput_parallel_local,
+            op=dist.ReduceOp.SUM,
+            group=model_parallel_group,
+            sync_op=False,
+        )
+
+        # compute dw and dbias
+        _check_environment_for_overlap()
+        dy = _dist_reshape(dy, [-1, dy.shape[-1]], dy.process_mesh, dy.placements)
+        input_parallel = _dist_reshape(
+            input_parallel, [-1, input_parallel.shape[-1]], input_parallel.process_mesh, input_parallel.placements
+        )
+        dw = paddle.matmul(
+            input_parallel,
+            dy,
+            transpose_x=True,
+        )
+        if bias is None:
+            task.wait()
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+
+            dx = paddle.Tensor(
+                dx_local,
+                dims=dx_global_shape,
+                process_mesh=dinput_parallel.process_mesh,
+                placements=[dist.Shard(1), dist.Shard(0)],
+                place=place,
+            )
+            dx.stop_gradient = dx.stop_gradient
+            return dx, dw
+        else:
+            dbias = paddle.sum(dy, axis=0)
+            task.wait()
+            place = paddle.framework._current_expected_place()
+            place = paddle.framework._get_paddle_place(place)
+
+            dx = paddle.Tensor(
+                dx_local,
+                dims=dx_global_shape,
+                process_mesh=dinput_parallel.process_mesh,
+                placements=[dist.Shard(1), dist.Shard(0)],
+                place=place,
+            )
+            dx.stop_gradient = dx.stop_gradient
+            return dx, dw, dbias
+
+
+def forward_pre_hook(layer, input):
+    paddle.nn.functional.linear = FusedLinearWithReduceScatter.apply
+    if is_fused_matmul_bias_supported():
+        paddle.incubate.nn.functional.fused_linear = FusedLinearWithReduceScatter.apply
+    global ipp, id2ipp
+    ipp = id2ipp[id(layer)]
+
+
+def forward_post_hook(layer, input, ouput):
+    paddle.nn.functional.linear = paddle_nn_functional_linear
+    if is_fused_matmul_bias_supported():
+        paddle.incubate.nn.functional.fused_linear = paddle_incubate_nn_functional_fused_linear
+
+
+def mock_layers_sp_async_reduce_scatter(model):
+    global ipp, id2ipp
+    for name, layer in model.named_sublayers():
+        if name.endswith("self_attn") or name.startswith("mlp"):
+            ipp = layer.ipp
+        for n in ["qkv_proj", "q_proj", "k_proj", "v_proj", "gate_up_fused_proj", "gate_proj", "up_proj"]:
+            if name.endswith(n):
+                id2ipp[id(layer)] = ipp
+                layer.register_forward_pre_hook(forward_pre_hook)
+                layer.register_forward_post_hook(forward_post_hook)
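The heart of the new file is FusedLinearWithReduceScatter.backward: the input gradient dx is reduce-scattered asynchronously across the model-parallel group while dw (and dbias) are computed, so the communication is hidden behind the matmul. A stripped-down sketch of that overlap pattern on plain local tensors (the commit itself operates on dist tensors with placements, as above; the function and variable names here are illustrative):

    # Minimal sketch of the async reduce-scatter / dw-compute overlap.
    # Assumes a multi-rank job launched with paddle.distributed.launch and that
    # mp_group is the model-parallel communication group; all names are illustrative.
    import paddle
    import paddle.distributed as dist


    def linear_backward_with_overlap(dy, x, weight, mp_group):
        # Input gradient over the full (gathered) sequence; each rank holds a partial sum.
        dx_full = paddle.matmul(dy, weight, transpose_y=True)

        # Start the reduce-scatter of dx along the sequence dimension without blocking.
        nranks = mp_group.nranks
        out_shape = list(dx_full.shape)
        out_shape[0] //= nranks
        dx_local = paddle.empty(shape=out_shape, dtype=dx_full.dtype)
        task = dist.stream.reduce_scatter(
            dx_local, dx_full, op=dist.ReduceOp.SUM, group=mp_group, sync_op=False
        )

        # Overlap: compute the weight gradient while the collective is in flight.
        dw = paddle.matmul(
            x.reshape([-1, x.shape[-1]]), dy.reshape([-1, dy.shape[-1]]), transpose_x=True
        )

        task.wait()  # dx_local is valid only after the collective completes
        return dx_local, dw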

paddlenlp/trainer/training_args.py

Lines changed: 1 addition & 0 deletions

@@ -1647,6 +1647,7 @@ def is_segment_parallel_supported():
                         "replace_with_c_embedding",
                         # "enable_mp_fused_linear_param_grad_add",
                         "replace_with_parallel_cross_entropy",
+                        "enable_sp_async_reduce_scatter",
                     ]:
                         raise ValueError(
                             f"Found unknown tensor parallell config {x}, "
