I'm following this link.
It works well, but it cannot resume from checkpoints.
Here is my short version:
```python
from copy import deepcopy
from typing import Any, Dict, Optional

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class ExponentialMovingAveraging(Callback):
    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.original_state_dict = {}
        self.ema_state_dict = {}

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # Copy the model before it is moved to the accelerator device.
        if self.ema_state_dict == {}:
            with pl_module._prevent_trainer_and_dataloaders_deepcopy():
                self.ema_state_dict = deepcopy(pl_module.state_dict())
            if torch.cuda.is_available():
                # Pin the copies so the non-blocking updates below can run asynchronously.
                self.ema_state_dict = {k: v.pin_memory() for k, v in self.ema_state_dict.items()}

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.ema_state_dict = {k: v.to(device=pl_module.device) for k, v in self.ema_state_dict.items()}

    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *args, **kwargs) -> None:
        # Update the EMA weights in place.
        with torch.no_grad():
            for key, value in pl_module.state_dict().items():
                ema_value = self.ema_state_dict[key]
                ema_value.copy_(self.decay * ema_value + (1.0 - self.decay) * value, non_blocking=True)

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap in the EMA weights for evaluation.
        self.original_state_dict = deepcopy(pl_module.state_dict())
        pl_module.load_state_dict(self.ema_state_dict, strict=True)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap the training weights back in.
        pl_module.load_state_dict(self.original_state_dict, strict=True)

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.original_state_dict = deepcopy(pl_module.state_dict())
        pl_module.load_state_dict(self.ema_state_dict, strict=True)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap the training weights back in.
        pl_module.load_state_dict(self.original_state_dict, strict=True)

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.original_state_dict = deepcopy(pl_module.state_dict())
        pl_module.load_state_dict(self.ema_state_dict, strict=True)

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Swap the training weights back in.
        pl_module.load_state_dict(self.original_state_dict, strict=True)

    def state_dict(self) -> Dict[str, Any]:
        return {"ema_state_dict": self.ema_state_dict}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.ema_state_dict = state_dict["ema_state_dict"]
```
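For reference, this is roughly how I run it (`LitModel` and the checkpoint path are placeholders, not part of the callback):

```python
# Usage sketch; assumes the callback class defined above.
# LitModel is a placeholder LightningModule that defines its own dataloaders.
import pytorch_lightning as pl

ema = ExponentialMovingAveraging(decay=0.999)
trainer = pl.Trainer(max_epochs=10, callbacks=[ema])

# Fresh run: Trainer checkpoints include the callback state via state_dict().
trainer.fit(LitModel())

# Resuming training restores ema_state_dict through load_state_dict().
trainer.fit(LitModel(), ckpt_path="last.ckpt")

# This is the path that fails for me: validating from a checkpoint.
trainer.validate(LitModel(), ckpt_path="last.ckpt")
```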
I hope to resume from checkpoints by loading `original_state_dict` and `ema_state_dict` from the `state_dict` of the checkpoint and of the callback, respectively.
The problem occurs when using `trainer.validate`: `__init__` resets `self.original_state_dict` and `self.ema_state_dict` to `{}` instead of loading them from the checkpoint file.
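A workaround I'm considering is restoring the callback state by hand before calling `trainer.validate`. A sketch, assuming the checkpoint stores callback states under `ckpt["callbacks"]`, keyed by the callback's `state_key` (the class qualname by default):

```python
# Workaround sketch: manually restore the EMA state, then validate.
# Assumes Lightning keeps callback states in ckpt["callbacks"] keyed by state_key.
import torch
import pytorch_lightning as pl

ema = ExponentialMovingAveraging(decay=0.999)
ckpt = torch.load("last.ckpt", map_location="cpu")
cb_state = ckpt.get("callbacks", {}).get(ema.state_key)
if cb_state is not None:
    ema.load_state_dict(cb_state)

trainer = pl.Trainer(callbacks=[ema])
trainer.validate(LitModel(), ckpt_path="last.ckpt")  # LitModel is a placeholder
```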