Skip to content

Commit 53b85cb

Browse files
authored
fix EMA (#11177)
1 parent 9904cfc commit 53b85cb

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,11 @@ def _load(self, resume_from_checkpoint):
10421042
if not os.path.exists(ema_path):
10431043
return
10441044

1045+
success, err_msg = self.sharding_io.check_same_strategy(resume_from_checkpoint)
1046+
if not success:
1047+
logger.info(f"Cannot load EMA because: {err_msg}")
1048+
return
1049+
10451050
logger.info(f"Loading EMA checkpoint from {resume_from_checkpoint} ...")
10461051
with device_guard("cpu"):
10471052
ema_state_dict = paddle.load(ema_path)

0 commit comments

Comments
 (0)