
Commit 0b8cdb4

remove outdated apex in favor of native pytorch AMP

1 parent e504e0e

3 files changed (+17, -27 lines)

README.md

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ trainer = Trainer(
     train_num_steps = 700000,           # total training steps
     gradient_accumulate_every = 2,      # gradient accumulation steps
     ema_decay = 0.995,                  # exponential moving average decay
-    fp16 = True                         # turn on mixed precision training with apex
+    amp = True                          # turn on mixed precision
 )
 
 trainer.train()
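For reference, the README example reads roughly as below once the flag is renamed; the first two Trainer arguments (the diffusion model and the image folder) are placeholders here, since this hunk only shows the keyword arguments:

trainer = Trainer(
    diffusion,                          # placeholder for the GaussianDiffusion instance built earlier in the README
    'path/to/your/images',              # placeholder for the folder of training images
    train_num_steps = 700000,           # total training steps
    gradient_accumulate_every = 2,      # gradient accumulation steps
    ema_decay = 0.995,                  # exponential moving average decay
    amp = True                          # turn on mixed precision via native torch.cuda.amp
)

trainer.train()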

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 15 additions & 25 deletions

@@ -7,6 +7,8 @@
 from functools import partial
 
 from torch.utils import data
+from torch.cuda.amp import autocast, GradScaler
+
 from pathlib import Path
 from torch.optim import Adam
 from torchvision import transforms, utils
@@ -15,12 +17,6 @@
 from tqdm import tqdm
 from einops import rearrange
 
-try:
-    from apex import amp
-    APEX_AVAILABLE = True
-except:
-    APEX_AVAILABLE = False
-
 # helpers functions
 
 def exists(x):
@@ -44,13 +40,6 @@ def num_to_groups(num, divisor):
     arr.append(remainder)
     return arr
 
-def loss_backwards(fp16, loss, optimizer, **kwargs):
-    if fp16:
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward(**kwargs)
-    else:
-        loss.backward(**kwargs)
-
 # small helper modules
 
 class EMA():
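The deleted loss_backwards helper wrapped apex's amp.scale_loss; with native AMP that responsibility moves to a GradScaler owned by the Trainer. For comparison only (a sketch, not code from this commit), a rough native-AMP equivalent of the old helper would be:

from torch.cuda.amp import GradScaler

def loss_backwards_native(scaler: GradScaler, loss, **kwargs):
    # scaler.scale multiplies the loss by the current loss scale before backward;
    # with GradScaler(enabled = False) it returns the loss unchanged, so the
    # same call works whether mixed precision is on or off.
    scaler.scale(loss).backward(**kwargs)

The commit inlines this pattern directly in Trainer.train instead of keeping a standalone helper.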
@@ -502,7 +491,7 @@ def __init__(
         train_lr = 2e-5,
         train_num_steps = 100000,
         gradient_accumulate_every = 2,
-        fp16 = False,
+        amp = False,
         step_start_ema = 2000,
         update_ema_every = 10,
         save_and_sample_every = 1000,
@@ -528,11 +517,8 @@ def __init__(
 
         self.step = 0
 
-        assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed in order for mixed precision training to be turned on'
-
-        self.fp16 = fp16
-        if fp16:
-            (self.model, self.ema_model), self.opt = amp.initialize([self.model, self.ema_model], self.opt, opt_level='O1')
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
 
         self.results_folder = Path(results_folder)
         self.results_folder.mkdir(exist_ok = True)
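Constructing the scaler with enabled = amp is what lets one code path serve both modes: a disabled GradScaler passes the loss through unscaled and its step() simply calls optimizer.step(). A minimal illustration of the passthrough behaviour (standalone sketch, not repository code):

import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = False)   # disabled scaler acts as a no-op passthrough
loss = torch.tensor(2.0)
assert scaler.scale(loss) is loss      # scale() returns the loss untouched when disabled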
@@ -552,7 +538,8 @@ def save(self, milestone):
         data = {
             'step': self.step,
             'model': self.model.state_dict(),
-            'ema': self.ema_model.state_dict()
+            'ema': self.ema_model.state_dict(),
+            'scaler': self.scaler.state_dict()
         }
         torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
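Persisting the scaler matters because GradScaler carries runtime state, notably the current dynamic loss scale and its growth tracker; resuming without it would restart the scale search from its default. A quick way to see what gets saved (an illustrative sketch, not repository code):

from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = True)
# On a CUDA machine this prints the default dynamic scale (65536.0) along with the
# growth/backoff settings; if AMP is unavailable or the scaler is disabled, it is an empty dict.
print(scaler.state_dict())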

@@ -562,18 +549,21 @@ def load(self, milestone):
         self.step = data['step']
         self.model.load_state_dict(data['model'])
         self.ema_model.load_state_dict(data['ema'])
+        self.scaler.load_state_dict(data['scaler'])
 
     def train(self):
-        backwards = partial(loss_backwards, self.fp16)
-
         while self.step < self.train_num_steps:
             for i in range(self.gradient_accumulate_every):
                 data = next(self.dl).cuda()
-                loss = self.model(data)
+
+                with autocast(enabled = self.amp):
+                    loss = self.model(data)
+                    self.scaler.scale(loss / self.gradient_accumulate_every).backward()
+
                 print(f'{self.step}: {loss.item()}')
-                backwards(loss / self.gradient_accumulate_every, self.opt)
 
-            self.opt.step()
+            self.scaler.step(self.opt)
+            self.scaler.update()
             self.opt.zero_grad()
 
             if self.step % self.update_ema_every == 0:
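Taken together, the new loop follows the standard native-AMP recipe: run the forward pass under autocast, call backward on a scaled (and here accumulation-averaged) loss for each micro-batch, then perform one scaler.step/scaler.update per optimizer step. A self-contained sketch of that pattern on a toy model (the model, shapes, and hyperparameters below are illustrative, not taken from the repository):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = device == 'cuda'                       # autocast/GradScaler degrade to no-ops when disabled

model = nn.Linear(32, 1).to(device)              # toy stand-in for the denoising model
opt = torch.optim.Adam(model.parameters(), lr = 2e-5)
scaler = GradScaler(enabled = use_amp)
gradient_accumulate_every = 2

for step in range(100):
    for _ in range(gradient_accumulate_every):
        x = torch.randn(16, 32, device = device)
        y = torch.randn(16, 1, device = device)
        with autocast(enabled = use_amp):        # forward pass runs in mixed precision
            loss = nn.functional.mse_loss(model(x), y)
        # scale the averaged loss so small fp16 gradients do not underflow
        scaler.scale(loss / gradient_accumulate_every).backward()
    scaler.step(opt)                             # unscales gradients, skips the step if inf/nan appeared
    scaler.update()                              # adjusts the loss scale for the next iteration
    opt.zero_grad()

One difference from the hunk above: this sketch calls backward() just outside the autocast block, the placement PyTorch's AMP examples use, whereas the commit keeps it inside the block; both arrangements work, since backward runs in the dtypes chosen during the forward pass. Note also that checkpoints saved before this commit contain no 'scaler' entry, so loading them with the new load() will raise a KeyError unless that key is guarded.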

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.8.1',
+  version = '0.9.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',
