@@ -1178,26 +1178,24 @@ def forward(self, prediction_scores, masked_lm_labels):
                 masked_lm_labels.unsqueeze(2),
             )
 
-            # XPU dose not support allgather mask with bool dtype, so we use LocalLayer here.
+            # XPU does not support allgather mask with bool dtype, so we use local_map here.
             if get_env_device() == "xpu":
 
-                class LocalLossLayer(paddle.distributed.LocalLayer):
-                    def __init__(self, out_dist_attrs, grad_dist_attrs):
-                        super().__init__(out_dist_attrs, grad_dist_attrs)
-
-                    def forward(self, x, mask):
-                        masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
-                        loss = paddle.mean(masked_lm_loss).unsqueeze(0)
-                        return loss.unsqueeze(0)
+                def calculate_loss(x, mask):
+                    masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
+                    loss = paddle.mean(masked_lm_loss).unsqueeze(0)
+                    return loss.unsqueeze(0)
 
                 out_dist_attrs = [
-                    (masked_lm_loss.process_mesh, [dist.Shard(0), dist.Replicate()]),
+                    [dist.Shard(0), dist.Replicate()],
                 ]
                 grad_dist_attrs = [
-                    (masked_lm_loss.process_mesh, [dist.Shard(0), dist.Replicate()]),
+                    [dist.Shard(0), dist.Replicate()],
                     None,
                 ]
-                loss_func = LocalLossLayer(out_dist_attrs, grad_dist_attrs)
+                loss_func = dist.local_map(
+                    calculate_loss, out_dist_attrs, grad_dist_attrs, masked_lm_loss.process_mesh, reshard_inputs=True
+                )
 
                 loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
                 loss = loss.mean()
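For context, below is a minimal sketch of the pattern this hunk introduces: the loss computation becomes a plain function that operates on each rank's local shard, and dist.local_map wraps it in place of the old paddle.distributed.LocalLayer subclass. The helper names (masked_mean_loss, build_loss_func) are illustrative only, and the meaning of the positional arguments to dist.local_map is assumed from the call shown in the diff rather than taken from verified API documentation.

import paddle
import paddle.distributed as dist


def masked_mean_loss(x, mask):
    # Runs on each rank's local tensor shard, so the bool mask is only used
    # locally and never needs an allgather (the XPU limitation noted above).
    selected = paddle.masked_select(x, mask).astype("float32")
    return paddle.mean(selected).unsqueeze(0).unsqueeze(0)


def build_loss_func(process_mesh):
    # Placement of the single output: sharded along dim 0, replicated on the
    # second mesh dimension (mirrors out_dist_attrs in the diff).
    out_dist_attrs = [
        [dist.Shard(0), dist.Replicate()],
    ]
    # One entry per input: the gradient for x keeps the same placement;
    # None for the bool mask, which carries no gradient.
    grad_dist_attrs = [
        [dist.Shard(0), dist.Replicate()],
        None,
    ]
    # Same positional order as the call in the diff; reshard_inputs=True is
    # assumed to let local_map reshard inputs to the expected placements.
    return dist.local_map(
        masked_mean_loss, out_dist_attrs, grad_dist_attrs, process_mesh, reshard_inputs=True
    )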