Skip to content

Commit 6ede865

Browse files
authored
fix zcc ema GPU alloc bug (#11165)
1 parent 29f97bb commit 6ede865

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

paddlenlp/trainer/utils/zero_cost_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,7 +190,7 @@ def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
190 190
self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
191 191
for index, ema_buf in self.ema_buffer_model_params.items():
192 192
_, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
193-
updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
193+
updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf.cpu()
194 194
self.ema_buffer_model_params[index] = updated_ema
195 195
logger.info(
196 196
f"[ZCC EMA] accmulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done"

0 commit comments

Comments
 (0)