# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Needed so the `Tensor | None` annotations below also work on Python < 3.10.
from __future__ import annotations

import paddle
from paddle import nn
import functools
import math
import operator
from typing import Literal, TypeAlias
import paddle.distributed as dist

from paddle import Tensor
from paddle import _C_ops, base, in_dynamic_mode
from paddle.distributed.fleet.base import topology as tp
from paddle.distributed import collective
from paddle.tensor.manipulation import reshape
from paddle.nn.layer.layers import Layer

_ReduceMode: TypeAlias = Literal['mean', 'sum', 'none']


# TODO: this function is rewritten from paddle.nn.functional.cross_entropy;
# ideally the two should be merged into a single implementation.
def parallel_cross_entropy(
    input: Tensor,
    label: Tensor,
    weight: Tensor | None = None,
    ignore_index: int = -100,
    reduction: _ReduceMode = 'mean',
    soft_label: bool = False,
    axis: int = -1,
    use_softmax: bool = True,
    label_smoothing: float = 0.0,
    name: str | None = None,
) -> Tensor:
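    """Tensor-parallel drop-in for ``paddle.nn.functional.cross_entropy``.

    When ``soft_label`` is False and ``input`` is a distributed tensor whose
    class dimension is sharded (see ``is_tensor_sharded``), the fused
    ``c_softmax_with_cross_entropy`` kernel is used, so the loss is computed
    without gathering the full logits on any single rank. Otherwise the
    function falls back to the regular ``cross_entropy_with_softmax`` kernel
    and logs a warning. The arguments mirror
    ``paddle.nn.functional.cross_entropy``.
    """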
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "The value of 'reduction' in softmax_cross_entropy "
            f"should be 'sum', 'mean' or 'none', but received {reduction}, which is not allowed."
        )
    if ignore_index > 0 and soft_label:
        raise ValueError(
            "When soft_label == True, the value of 'ignore_index' in softmax_cross_entropy "
            f"should be '-100', but received {ignore_index}, which is not allowed."
        )

    input_dims = len(list(input.shape))
    if input_dims == 0:
        raise ValueError('The dimension of input should be larger than zero!')

    label_dims = len(list(label.shape))
    if input_dims - 1 == label_dims:
        label = paddle.unsqueeze(label, axis=axis)

    if input_dims - 1 != label_dims and input_dims != label_dims:
        raise ValueError(
            'Expected input_dims - 1 == label_dims or input_dims == label_dims '
            f'(got input_dims: {input_dims}, label_dims: {label_dims})'
        )

    if label_smoothing > 0.0:
        soft_label = True
        # Convert the label to one-hot encoding:
        # for the 1-d case, label's shape goes from [N] to [N, C];
        # for the 2-d case, from [N, d_1, ..., d_k] to [N, d_1, ..., d_k, C].
        if input_dims - 1 == label_dims:
            label = paddle.squeeze(label, axis=axis)
            label = paddle.nn.functional.one_hot(label, input.shape[-1])

        label = paddle.nn.functional.label_smooth(
            label, epsilon=label_smoothing
        )
        label = label.astype(input.dtype)
        label_dims = len(list(label.shape))

    if not soft_label:
        # Replace ignored positions with class 0 so they remain valid indices
        # for gather_nd; their contribution is masked out separately below.
        valid_label = (
            paddle.cast(label != ignore_index, dtype=label.dtype) * label
        )

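    # Model-parallel fast path: the logits may arrive sharded along the class
    # (vocabulary) dimension across the model-parallel group. The fused
    # c_softmax_with_cross_entropy kernel computes the loss from the local
    # shard plus cross-rank communication, so the full logits never need to
    # be gathered on a single rank. Hard labels are required for this path.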
    if not soft_label and is_tensor_sharded(input):
        group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
        ring_id = group.id
        nranks = group.nranks
        global_rank = collective._get_global_env().rank
        rank = group.get_group_rank(global_rank)
        _, out = _C_ops.c_softmax_with_cross_entropy(
            input, label, ignore_index, ring_id, rank, nranks
        )
    else:
        from paddlenlp.utils.log import logger

        input_placement = input.placements if input.is_dist() else None
        logger.warning(
            "Failed to replace CrossEntropyLoss with ParallelCrossEntropyLoss. Please ensure:\n"
            "1. soft_label=False is set for parallel computation (current value: {})\n"
            "2. Input tensor is properly sharded (current sharding status: {})\n".format(
                soft_label,
                input_placement,
            )
        )

        _, out = _C_ops.cross_entropy_with_softmax(
            input, label, soft_label, use_softmax, True, ignore_index, axis
        )

    if weight is not None:
        # Convert the per-class weight into a per-sample weight with shape [N]
        # (1-d case) or [N, H, W] (2-d case).
        if soft_label:
            # weight's shape is [C], where C is the number of classes.
            # 1-d case: label's shape is [N, C], weight_gather's shape is [N].
            # 2-d case: label's shape is [N, H, W, C], weight_gather's shape is [N, H, W].
            weight_gather = paddle.matmul(
                x=paddle.cast(label, weight.dtype),
                y=weight,
                transpose_x=False,
                transpose_y=True,
            )
            out_shape = list(out.shape)
            weight_gather_reshape = reshape(weight_gather, shape=out_shape)
            out = paddle.cast(out, weight_gather_reshape.dtype)

            out = _C_ops.multiply(out, weight_gather_reshape)
        else:
            if input.shape[axis] != weight.shape[-1]:
                raise ValueError(
                    f"input's class_dimension({input.shape[axis]}) must be equal to "
                    f"weight's class_dimension({weight.shape[-1]}) "
                    "when weight is provided"
                )

            ignore_weight_mask = paddle.cast(
                (label != ignore_index), out.dtype
            )
            if (
                ignore_weight_mask.ndim > 1
                and ignore_weight_mask.shape[axis] == 1
            ):
                # TODO: Temporarily use squeeze instead of squeeze_
                ignore_weight_mask = paddle.squeeze(
                    ignore_weight_mask, axis
                )
            if axis != -1 and axis != valid_label.ndim - 1:
                temp_perm = (
                    list(range(axis % valid_label.ndim))
                    + list(
                        range(
                            (axis % valid_label.ndim + 1), valid_label.ndim
                        )
                    )
                    + [axis % valid_label.ndim]
                )
                weight_gather = _C_ops.gather_nd(
                    weight, valid_label.transpose(temp_perm)
                )
            else:
                weight_gather = _C_ops.gather_nd(weight, valid_label)
            weight_gather = _C_ops.multiply(
                weight_gather, ignore_weight_mask
            )
            input_shape = list(label.shape)
            weight_gather_reshape = reshape(
                weight_gather, shape=input_shape
            )
            out = paddle.cast(out, weight_gather_reshape.dtype)
            out = _C_ops.multiply(out, weight_gather_reshape)

    if reduction == "sum":
        # Because of the underlying softmax_with_cross_entropy op's inner
        # logic, the loss of a sample whose class_index == ignore_index is
        # already 0 in `out`, so summing everything directly is fine.
        return _C_ops.sum(out, [], None, False)
    elif reduction == "mean":
        # 1. if weight is None:
        #    numerator: summing all losses directly is fine (see note above);
        #    denominator: the number of samples with class_index != ignore_index.
        # 2. else:
        #    numerator: the weighted sum of the losses;
        #    denominator: the sum of the weights of samples with class_index != ignore_index.
        if ignore_index >= 0:  # ignore label
            out_sum = _C_ops.sum(out, [], None, False)
            # For each label[i], set mask[i] according to ignore_index:
            # mask[i] = 0 if label[i] == ignore_index, and 1 otherwise.
            mask = label != ignore_index
            if weight is None:
                mask = paddle.cast(mask, dtype=out_sum.dtype)
                count = _C_ops.sum(mask, [], None, False)
                ret = out_sum / (count + (count == 0.0).astype(count.dtype))
            else:
                mask = paddle.cast(mask, weight_gather_reshape.dtype)
                weight_ignored = _C_ops.multiply(
                    mask, weight_gather_reshape
                )
                weight_sum = _C_ops.sum(weight_ignored, [], None, False)
                ret = out_sum / (
                    weight_sum
                    + (weight_sum == 0.0).astype(weight_sum.dtype)
                )
            return ret
        elif weight is not None:
            out_sum = _C_ops.sum(out, [], None, False)
            total_weight = _C_ops.sum(
                weight_gather_reshape, [], None, False
            )
            return out_sum / (
                total_weight
                + (total_weight == 0.0).astype(total_weight.dtype)
            )
        else:
            return _C_ops.mean_all(out)

    else:
        if input_dims - 1 == label_dims:
            out = paddle.squeeze(out, axis=axis)
        return out


# TODO: placements[1] may not be the mp (model-parallel) axis.
def is_tensor_sharded(tensor):
    # True only for distributed tensors that are sharded along the second mesh
    # dimension, which is assumed here to be the model-parallel axis.
    if not tensor.is_dist():
        return False

    placements = tensor.placements
    return len(placements) > 1 and placements[1].is_shard()


def replace_cross_entropy():
    paddle.nn.functional.cross_entropy = parallel_cross_entropy
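

# Usage sketch (illustrative assumptions, not part of this module): call
# replace_cross_entropy() once, before the loss is computed, so that later
# calls to paddle.nn.functional.cross_entropy resolve to parallel_cross_entropy.
# The parallel fast path additionally assumes an initialized hybrid-parallel
# environment (tp._HYBRID_PARALLEL_GROUP) and logits sharded along the class
# dimension; otherwise the standard kernel is used and a warning is logged.
#
#     replace_cross_entropy()
#     loss = paddle.nn.functional.cross_entropy(logits, labels)  # hypothetical tensors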
0 commit comments