Commit 4d35825

[Performance Optimization] Replace cross_entropy_with_softmax with c_softmax_with_cross_entropy in dynamic auto mode (#10471)

* Replace cross_entropy_with_softmax with c_softmax_with_cross_entropy in dynamic auto mode
* add sys_path
* add TODO
* add copyright
1 parent acb6e22 commit 4d35825
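
The win here is that c_softmax_with_cross_entropy computes the loss directly on vocabulary-sharded logits: instead of all-gathering the full [batch, seq, vocab] logits before calling cross_entropy_with_softmax, each tensor-parallel rank keeps only its vocab shard, and just a per-token max and a per-token exp-sum cross the model-parallel group. A self-contained sketch of that idea, simulated in a single process with numpy (an illustration of the math, not the actual kernel):

import numpy as np

def sharded_softmax_cross_entropy(logit_shards, label, shard_size):
    # 1. global max over the class dim via an "allreduce(max)" of per-shard maxima
    global_max = max(s.max() for s in logit_shards)
    # 2. global sum(exp) via an "allreduce(sum)" of per-shard partial sums
    sum_exp = sum(np.exp(s - global_max).sum() for s in logit_shards)
    # 3. the target logit lives on exactly one rank; that rank contributes it,
    #    the others contribute 0 (another small reduction in the real op)
    owner, local_idx = divmod(label, shard_size)
    target_logit = logit_shards[owner][local_idx]
    # loss = logsumexp(logits) - logits[label]
    return np.log(sum_exp) + global_max - target_logit

rng = np.random.default_rng(0)
vocab, mp_degree = 8, 2
logits = rng.standard_normal(vocab)
shards = np.split(logits, mp_degree)  # what each tensor-parallel rank would hold
label = 5
shifted = np.exp(logits - logits.max())
reference = -np.log(shifted[label] / shifted.sum())
assert np.isclose(
    sharded_softmax_cross_entropy(shards, label, vocab // mp_degree), reference
)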

File tree: 2 files changed, +246 −0 lines changed


llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 6 additions & 0 deletions

@@ -539,6 +539,7 @@ def main():
     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
     config.sharding_parallel_degree = training_args.sharding_parallel_degree
+    config.to_static = training_args.to_static
 
     if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1:
         pipeline = training_args.strategy.pipeline
@@ -556,6 +557,11 @@ def main():
 
     print("Final pre-training config:", config)
 
+    if "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config and config.tensor_parallel_degree > 1 and config.to_static is False:
+        from llm.utils.replace_ops import replace_cross_entropy
+
+        replace_cross_entropy()
+
     # # Set the dtype for loading model
     # dtype = "float32"
     # if training_args.fp16_opt_level == "O2":
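
A small, self-contained sketch of the gate above (the helper name and the literal arguments are mine, not part of the commit): the swap only happens when the option is requested, the logits are actually sharded (tensor_parallel_degree > 1), and the run stays in dynamic mode (to_static is False), since the replacement targets the dynamic auto-parallel path.

def should_replace_cross_entropy(tensor_parallel_config, tensor_parallel_degree, to_static):
    # tensor_parallel_config is treated as the raw option string here; the
    # membership test behaves the same if it has been parsed into a list.
    return (
        "replace_with_parallel_cross_entropy" in tensor_parallel_config
        and tensor_parallel_degree > 1
        and to_static is False
    )

# Enabled: 2-way tensor parallel, dynamic (non to_static) run.
assert should_replace_cross_entropy("replace_with_parallel_cross_entropy", 2, False)
# Left untouched once the model is converted to a static graph.
assert not should_replace_cross_entropy("replace_with_parallel_cross_entropy", 2, True)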

llm/utils/replace_ops.py

Lines changed: 240 additions & 0 deletions

@@ -0,0 +1,240 @@

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import functools
import math
import operator
from typing import Literal, TypeAlias
import paddle.distributed as dist

from paddle import Tensor
from paddle import _C_ops, base, in_dynamic_mode
from paddle.distributed.fleet.base import topology as tp
from paddle.distributed import collective
from paddle.tensor.manipulation import reshape
from paddle.nn.layer.layers import Layer

_ReduceMode: TypeAlias = Literal['mean', 'sum', 'none']


# TODO: this function is rewritten from paddle.nn.functional.cross_entropy;
# it would be better to merge the two into one.
def parallel_cross_entropy(
    input: Tensor,
    label: Tensor,
    weight: Tensor | None = None,
    ignore_index: int = -100,
    reduction: _ReduceMode = 'mean',
    soft_label: bool = False,
    axis: int = -1,
    use_softmax: bool = True,
    label_smoothing: float = 0.0,
    name: str | None = None,
) -> Tensor:

    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in softmax_cross_entropy "
            f"should be 'sum', 'mean' or 'none', but received {reduction}, which is not allowed."
        )
    if ignore_index > 0 and soft_label:
        raise ValueError(
            "When soft_label == True, the value of 'ignore_index' in softmax_cross_entropy "
            f"should be '-100', but received {ignore_index}, which is not allowed."
        )

    input_dims = len(list(input.shape))
    if input_dims == 0:
        raise ValueError('The dimension of input should be larger than zero!')

    label_dims = len(list(label.shape))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=axis)

    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            f'Expected input_dims - 1 == label_dims or input_dims == label_dims '
            f'(got input_dims {input_dims}, label_dims {label_dims})'
        )

    if label_smoothing > 0.0:
        soft_label = True
        # converting the label to one-hot encoding
        # for 1d case, converting label's shape from [N] to [N, C]
        # for 2d case, converting label's shape from [N, d_1, ..., d_k] to [N, d_1, ..., d_k, C]
        if input_dims - 1 == label_dims:
            label = paddle.squeeze(label, axis=axis)
            label = paddle.nn.functional.one_hot(label, input.shape[-1])

        label = paddle.nn.functional.label_smooth(
            label, epsilon=label_smoothing
        )
        label = label.astype(input.dtype)
        label_dims = len(list(label.shape))

    if not soft_label:
        valid_label = (
            paddle.cast(label != ignore_index, dtype=label.dtype) * label
        )

    if not soft_label and is_tensor_sharded(input):
        # Logits are sharded along the class dim across the model-parallel
        # group: use the fused, communication-aware kernel so the full
        # vocabulary never has to be gathered on one rank.
        group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
        ring_id = group.id
        nranks = group.nranks
        global_rank = collective._get_global_env().rank
        rank = group.get_group_rank(global_rank)
        _, out = _C_ops.c_softmax_with_cross_entropy(
            input, label, ignore_index, ring_id, rank, nranks
        )
    else:
        from paddlenlp.utils.log import logger

        logger.warning(
            "Failed to replace CrossEntropyLoss with ParallelCrossEntropyLoss. Please ensure: \n"
            "1. soft_label=False is set for parallel computation (current value: {}) \n"
            "2. Input tensor is properly sharded (current sharding status: {}) \n".format(
                soft_label,
                input.placements if input.is_dist() else None,
            )
        )

        _, out = _C_ops.cross_entropy_with_softmax(
            input, label, soft_label, use_softmax, True, ignore_index, axis
        )

    if weight is not None:
        # trans weight from class to sample, shape: N or [N,H,W] for 1d and 2d cases.
        if soft_label:
            # chajchaj:
            # weight's shape is C, where C is class num.
            # for 1d case: label's shape is [N,C], weight_gather's shape is N.
            # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
            weight_gather = paddle.matmul(
                x=paddle.cast(label, weight.dtype),
                y=weight,
                transpose_x=False,
                transpose_y=True,
            )
            out_shape = list(out.shape)
            weight_gather_reshape = reshape(weight_gather, shape=out_shape)
            out = paddle.cast(out, weight_gather_reshape.dtype)

            out = _C_ops.multiply(out, weight_gather_reshape)
        else:
            if input.shape[axis] != weight.shape[-1]:
                raise ValueError(
                    f"input's class_dimension({input.shape[axis]}) must equal to "
                    f"weight's class_dimension({weight.shape[-1]}) "
                    "when weight is provided"
                )

            ignore_weight_mask = paddle.cast(
                (label != ignore_index), out.dtype
            )
            if (
                ignore_weight_mask.ndim > 1
                and ignore_weight_mask.shape[axis] == 1
            ):
                # TODO: Temporarily use squeeze instead of squeeze_
                ignore_weight_mask = paddle.squeeze(
                    ignore_weight_mask, axis
                )
            if axis != -1 and axis != valid_label.ndim - 1:
                temp_perm = (
                    list(range(axis % valid_label.ndim))
                    + list(
                        range(
                            (axis % valid_label.ndim + 1), valid_label.ndim
                        )
                    )
                    + [axis % valid_label.ndim]
                )
                weight_gather = _C_ops.gather_nd(
                    weight, valid_label.transpose(temp_perm)
                )
            else:
                weight_gather = _C_ops.gather_nd(weight, valid_label)
            weight_gather = _C_ops.multiply(
                weight_gather, ignore_weight_mask
            )
            input_shape = list(label.shape)
            weight_gather_reshape = reshape(
                weight_gather, shape=input_shape
            )
            out = paddle.cast(out, weight_gather_reshape.dtype)
            out = _C_ops.multiply(out, weight_gather_reshape)

    if reduction == "sum":
        # because of base_softmax_with_cross_entropy op's inner logic,
        # the loss of a sample with class_index == ignore_index is 0 in the out tensor,
        # so reduce_sum over everything directly is OK
        return _C_ops.sum(out, [], None, False)
    elif reduction == "mean":
        # 1. if weight is None,
        #    numerator: reduce_sum of the loss directly is OK because of base_softmax_with_cross_entropy's inner logic
        #    denominator: count of samples with class_index != ignore_index
        # 2. else
        #    numerator: the loss's weighted sum
        #    denominator: the sum of weights where the sample's class_index != ignore_index
        if ignore_index >= 0:  # ignore label
            out_sum = _C_ops.sum(out, [], None, False)
            # for each label[i], set 1 or 0 according to ignore_index:
            # mask[i] = 0, if label[i] == ignore_index
            # mask[i] = 1, otherwise
            mask = label != ignore_index
            if weight is None:
                mask = paddle.cast(mask, dtype=out_sum.dtype)
                count = _C_ops.sum(mask, [], None, False)
                ret = out_sum / (count + (count == 0.0).astype(count.dtype))
            else:
                mask = paddle.cast(mask, weight_gather_reshape.dtype)
                weight_ignored = _C_ops.multiply(
                    mask, weight_gather_reshape
                )
                weight_sum = _C_ops.sum(weight_ignored, [], None, False)
                ret = out_sum / (
                    weight_sum
                    + (weight_sum == 0.0).astype(weight_sum.dtype)
                )
            return ret
        elif weight is not None:
            out_sum = _C_ops.sum(out, [], None, False)
            total_weight = _C_ops.sum(
                weight_gather_reshape, [], None, False
            )
            return out_sum / (
                total_weight
                + (total_weight == 0.0).astype(total_weight.dtype)
            )
        else:
            return _C_ops.mean_all(out)

    else:
        if input_dims - 1 == label_dims:
            out = paddle.squeeze(out, axis=axis)
        return out


# TODO: placement[1] may not be the mp axis.
def is_tensor_sharded(tensor):
    if not tensor.is_dist():
        return False

    placement = tensor.placements
    return placement[1].is_shard()


def replace_cross_entropy():
    # Monkey-patch the functional API so callers of
    # paddle.nn.functional.cross_entropy route through parallel_cross_entropy.
    paddle.nn.functional.cross_entropy = parallel_cross_entropy
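
Since replace_cross_entropy() patches paddle.nn.functional.cross_entropy process-wide, it has to run before the loss is first computed, which is why run_pretrain_auto.py calls it right after the pre-training config is finalized. A minimal usage sketch, assuming the repo root is on sys.path (the commit's "add sys_path" note), a multi-rank launch so the hybrid model-parallel group exists, and logits arriving as a dist tensor sharded on the class axis; compute_loss and the shapes below are illustrative, not part of the commit.

import paddle
import paddle.nn.functional as F

from llm.utils.replace_ops import replace_cross_entropy

# Patch before building/running the model: from here on, every call to
# F.cross_entropy goes through parallel_cross_entropy, which picks
# c_softmax_with_cross_entropy when the logits are sharded and soft_label is
# False, and falls back to the regular kernel (with a warning) otherwise.
replace_cross_entropy()


def compute_loss(logits, labels):
    # logits: [batch, seq_len, vocab / mp_degree] per rank; labels hold global
    # class ids, with ignore_index=-100 marking masked positions.
    return F.cross_entropy(logits, labels, reduction="mean", ignore_index=-100)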
