
Commit c3e25ca

[AutoParallel] Using local_map replace LocalLayer (#10309)
1 parent 99e1c10 commit c3e25ca

File tree: 1 file changed (+10, -12 lines)


paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 10 additions & 12 deletions
@@ -1178,26 +1178,24 @@ def forward(self, prediction_scores, masked_lm_labels):
                 masked_lm_labels.unsqueeze(2),
             )
 
-            # XPU dose not support allgather mask with bool dtype, so we use LocalLayer here.
+            # XPU dose not support allgather mask with bool dtype, so we use local_map here.
             if get_env_device() == "xpu":
 
-                class LocalLossLayer(paddle.distributed.LocalLayer):
-                    def __init__(self, out_dist_attrs, grad_dist_attrs):
-                        super().__init__(out_dist_attrs, grad_dist_attrs)
-
-                    def forward(self, x, mask):
-                        masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
-                        loss = paddle.mean(masked_lm_loss).unsqueeze(0)
-                        return loss.unsqueeze(0)
+                def coculate_loss(x, mask):
+                    masked_lm_loss = paddle.masked_select(x, mask).astype("float32")
+                    loss = paddle.mean(masked_lm_loss).unsqueeze(0)
+                    return loss.unsqueeze(0)
 
                 out_dist_attrs = [
-                    (masked_lm_loss.process_mesh, [dist.Shard(0), dist.Replicate()]),
+                    [dist.Shard(0), dist.Replicate()],
                 ]
                 grad_dist_attrs = [
-                    (masked_lm_loss.process_mesh, [dist.Shard(0), dist.Replicate()]),
+                    [dist.Shard(0), dist.Replicate()],
                     None,
                 ]
-                loss_func = LocalLossLayer(out_dist_attrs, grad_dist_attrs)
+                loss_func = dist.local_map(
+                    coculate_loss, out_dist_attrs, grad_dist_attrs, masked_lm_loss.process_mesh, reshard_inputs=True
+                )
 
                 loss = loss_func(masked_lm_loss, masked_lm_loss > 0)
                 loss = loss.mean()
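For readers unfamiliar with the pattern, here is a minimal sketch of the dist.local_map call introduced above: the wrapped function runs on each rank's local shard and the declared placements are stamped back onto the outputs, which avoids all-gathering the bool mask on XPU. It mirrors the diff's argument order (function, output placements, gradient placements per input, process mesh, reshard_inputs=True); the mesh layout, tensor shapes, and the demo_loss name are illustrative assumptions, not taken from the repository, and the script is meant to be launched with paddle.distributed.launch on two cards.

# Sketch of the local_map pattern from this commit. Only the call pattern comes
# from the diff; the mesh, shapes, and demo_loss are illustrative assumptions.
import paddle
import paddle.distributed as dist


def demo_loss(x, mask):
    # Runs on each rank's local shard, like coculate_loss in the diff.
    masked = paddle.masked_select(x, mask).astype("float32")
    return paddle.mean(masked).unsqueeze(0).unsqueeze(0)


mesh = dist.ProcessMesh([[0], [1]], dim_names=["dp", "mp"])  # assumed 2-rank mesh
logits = dist.shard_tensor(
    paddle.rand([8, 1]), mesh, [dist.Shard(0), dist.Replicate()]
)

out_dist_attrs = [[dist.Shard(0), dist.Replicate()]]         # placements for the single output
grad_dist_attrs = [[dist.Shard(0), dist.Replicate()], None]  # None: the bool mask needs no gradient

# Same argument order as the diff: func, output placements, gradient placements,
# process mesh, reshard_inputs=True; argument roles are inferred from the
# variable names used in the commit.
loss_fn = dist.local_map(
    demo_loss, out_dist_attrs, grad_dist_attrs, logits.process_mesh, reshard_inputs=True
)
loss = loss_fn(logits, logits > 0).mean()
print(loss)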
