@@ -7,20 +7,16 @@
 from functools import partial
 
 from torch.utils import data
+from torch.cuda.amp import autocast, GradScaler
+
 from pathlib import Path
 from torch.optim import Adam
 from torchvision import transforms, utils
 
 import numpy as np
 from tqdm import tqdm
 from einops import rearrange
 
-try:
-    from apex import amp
-    APEX_AVAILABLE = True
-except:
-    APEX_AVAILABLE = False
-
 # helpers functions
 
 def exists(x):
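Note on the hunk above: unlike Apex, which is a separate install and hence needed the try/except availability guard, autocast and GradScaler ship with PyTorch itself from version 1.6 onward, so the guard and the APEX_AVAILABLE flag can simply go away. If older PyTorch releases still had to be supported, a similar guard could be kept; a sketch, where NATIVE_AMP_AVAILABLE is a hypothetical flag name:

# hypothetical fallback guard, only needed to support PyTorch < 1.6
try:
    from torch.cuda.amp import autocast, GradScaler
    NATIVE_AMP_AVAILABLE = True
except ImportError:
    NATIVE_AMP_AVAILABLE = False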
@@ -44,13 +40,6 @@ def num_to_groups(num, divisor):
         arr.append(remainder)
     return arr
 
-def loss_backwards(fp16, loss, optimizer, **kwargs):
-    if fp16:
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward(**kwargs)
-    else:
-        loss.backward(**kwargs)
-
 # small helper modules
 
 class EMA():
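The removed loss_backwards helper existed only to branch between Apex's amp.scale_loss context manager and a plain backward(). With the native API the branch collapses to one line, because a GradScaler constructed with enabled = False passes the loss through untouched. A minimal sketch of such a replacement helper, with loss_backwards_native being a hypothetical name:

from torch.cuda.amp import GradScaler

def loss_backwards_native(scaler: GradScaler, loss, **kwargs):
    # scale() multiplies the loss by the current loss-scale factor;
    # with GradScaler(enabled = False) it returns the loss unchanged,
    # so the same line serves both fp16 and fp32 training
    scaler.scale(loss).backward(**kwargs)

The commit does not keep a helper at all; it inlines this call into the training loop below.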
@@ -502,7 +491,7 @@ def __init__(
         train_lr = 2e-5,
         train_num_steps = 100000,
         gradient_accumulate_every = 2,
-        fp16 = False,
+        amp = False,
         step_start_ema = 2000,
         update_ema_every = 10,
         save_and_sample_every = 1000,
@@ -528,11 +517,8 @@ def __init__(
 
         self.step = 0
 
-        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed in order for mixed precision training to be turned on'
-
-        self.fp16 = fp16
-        if fp16:
-            (self.model, self.ema_model), self.opt = amp.initialize([self.model, self.ema_model], self.opt, opt_level = 'O1')
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
 
         self.results_folder = Path(results_folder)
         self.results_folder.mkdir(exist_ok = True)
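Constructing the scaler unconditionally with enabled = amp is what keeps the rest of the trainer branch-free: when amp is False, scale(), step() and update() all degrade to plain pass-throughs. A quick check of the disabled behavior, runnable on CPU:

import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = False)
loss = torch.tensor(3.0, requires_grad = True)
# a disabled scaler returns its input unmodified
assert scaler.scale(loss) is loss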
@@ -552,7 +538,8 @@ def save(self, milestone):
         data = {
             'step': self.step,
             'model': self.model.state_dict(),
-            'ema': self.ema_model.state_dict()
+            'ema': self.ema_model.state_dict(),
+            'scaler': self.scaler.state_dict()
         }
         torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
 
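Persisting scaler.state_dict() alongside the model weights matters because GradScaler adapts its loss-scale factor dynamically during training; restoring it on resume avoids re-warming the scale from its default. A sketch of the round trip, assuming a CUDA device and with checkpoint.pt a hypothetical path:

import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = True)
torch.save({'scaler': scaler.state_dict()}, 'checkpoint.pt')

data = torch.load('checkpoint.pt')
scaler.load_state_dict(data['scaler'])  # restores scale factor and growth tracker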
@@ -562,18 +549,21 @@ def load(self, milestone):
         self.step = data['step']
         self.model.load_state_dict(data['model'])
         self.ema_model.load_state_dict(data['ema'])
+        self.scaler.load_state_dict(data['scaler'])
 
     def train(self):
-        backwards = partial(loss_backwards, self.fp16)
-
         while self.step < self.train_num_steps:
             for i in range(self.gradient_accumulate_every):
                 data = next(self.dl).cuda()
-                loss = self.model(data)
+
+                with autocast(enabled = self.amp):
+                    loss = self.model(data)
+                    self.scaler.scale(loss / self.gradient_accumulate_every).backward()
+
                 print(f'{self.step}: {loss.item()}')
-                backwards(loss / self.gradient_accumulate_every, self.opt)
 
-            self.opt.step()
+            self.scaler.step(self.opt)
+            self.scaler.update()
             self.opt.zero_grad()
 
             if self.step % self.update_ema_every == 0:
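Taken together, the new loop is the canonical native-AMP recipe: forward under autocast, backward on the scaled loss once per accumulation step, then a single scaler.step / scaler.update per optimizer update; step() skips the update when inf/nan gradients are detected, and update() adjusts the loss scale accordingly. A condensed, self-contained version of the same pattern, assuming a CUDA device and a toy model:

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(16, 1).cuda()
opt = torch.optim.Adam(model.parameters(), lr = 2e-5)
scaler = GradScaler(enabled = True)
gradient_accumulate_every = 2

for step in range(4):
    for _ in range(gradient_accumulate_every):
        x = torch.randn(8, 16, device = 'cuda')
        with autocast(enabled = True):
            loss = model(x).pow(2).mean()
        # scale before backward so fp16 gradients do not underflow
        scaler.scale(loss / gradient_accumulate_every).backward()

    scaler.step(opt)   # unscales, checks for inf/nan, then steps
    scaler.update()    # grows or shrinks the loss scale for the next step
    opt.zero_grad()

One difference from the hunk above: the PyTorch docs recommend calling backward() outside the autocast block, as in this sketch; the commit leaves it inside, which still works but is not the documented pattern.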