Currently, when resuming from a previous run that utilizes a learning rate scheduler, we do NOT load a state dict from the scheduler.
But wait, does that mean our code is BROKEN?
Actually, because we save the state dict for the optimizer that initializes the learning rate scheduler AND we track the global steps taken, which tells the learning rate scheduler where in its schedule it is, the behavior is largely the same. However, if you were to inspect the learning rate scheduler state dict before and after resuming, they would not be the same because the parameter _step_count is not updated.
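To make the mismatch concrete, here is a minimal sketch. It assumes the resume path rebuilds the scheduler from the restored optimizer and passes the tracked step count through last_epoch; the helper names (make_optimizer, decay), the toy model, and the global_step - 1 convention are illustrative assumptions, not our actual recipe code.

```python
import copy

import torch
from torch.optim.lr_scheduler import LambdaLR


def make_optimizer():
    # Hypothetical toy model; stands in for whatever the recipe trains.
    model = torch.nn.Linear(4, 2)
    return torch.optim.SGD(model.parameters(), lr=0.1)


def decay(step):
    # Hypothetical schedule: multiply the base lr by 0.95 per step.
    return 0.95 ** step


# Original run: take 5 steps.
optimizer = make_optimizer()
scheduler = LambdaLR(optimizer, lr_lambda=decay)
for _ in range(5):
    optimizer.step()
    scheduler.step()

# What gets checkpointed today: optimizer state and the global step only.
saved_optimizer_state = copy.deepcopy(optimizer.state_dict())
global_step = 5

# Resume (assumed pattern): restore the optimizer, which carries "initial_lr"
# in its param groups, then rebuild the scheduler at the tracked step.
resumed_optimizer = make_optimizer()
resumed_optimizer.load_state_dict(saved_optimizer_state)
resumed_scheduler = LambdaLR(
    resumed_optimizer, lr_lambda=decay, last_epoch=global_step - 1
)

# The schedule position and learning rate match, but _step_count does not.
print(scheduler.last_epoch, resumed_scheduler.last_epoch)
print(scheduler.get_last_lr(), resumed_scheduler.get_last_lr())
print(scheduler.state_dict()["_step_count"],
      resumed_scheduler.state_dict()["_step_count"])
```

In this sketch the rebuilt scheduler starts its _step_count over from construction, while last_epoch and the learning rate line up with the original run.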
Does this matter?
_step_count looks like it's mostly there for debugging purposes. It's only actually used in one standard PyTorch learning rate scheduler: CosineAnnealingLR here. I'm opening this issue because even though our training code works fine for 99% of use cases, we really should utilize the state dict to cover all the cases.
Goal
Update our recipes to save lr_scheduler.state_dict() to the intermediate state dict and, upon resuming from a checkpoint, call lr_scheduler.load_state_dict() on the learning rate scheduler.
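A rough sketch of what that could look like is below. The checkpoint keys ("optimizer", "lr_scheduler"), the build helper, and the toy model are assumptions made for illustration; the real recipes will have their own checkpoint layout. The deepcopy stands in for a save/load round trip to disk.

```python
import copy

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR


def build():
    # Hypothetical setup; a real recipe would build these from its config.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=100)
    return model, optimizer, lr_scheduler


# Original run: take a few steps.
model, optimizer, lr_scheduler = build()
for _ in range(5):
    optimizer.step()
    lr_scheduler.step()

# Save: include the scheduler state next to the optimizer state.
# The key names here are assumptions, not the recipes' actual keys.
checkpoint = copy.deepcopy(
    {
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    }
)

# Resume: rebuild the objects, then restore both state dicts.
_, new_optimizer, new_lr_scheduler = build()
new_optimizer.load_state_dict(checkpoint["optimizer"])
new_lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

# The scheduler state, including _step_count, now matches the original run.
print(new_lr_scheduler.state_dict() == lr_scheduler.state_dict())
```

With the scheduler state restored this way, _step_count comes back along with last_epoch, so schedulers that read it (such as CosineAnnealingLR) see the same state as before the restart.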