from math import sqrt
import torch
from torch import nn, einsum
import torch.nn.functional as F

from tqdm import tqdm
from einops import rearrange, repeat, reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# tensor helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# normalization functions

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# main class

class ElucidatedDiffusion(nn.Module):
    def __init__(
        self,
        net,
        *,
        image_size,
        channels = 3,
        num_sample_steps = 32, # number of sampling steps
        sigma_min = 0.002,     # min noise level
        sigma_max = 80,        # max noise level
        sigma_data = 0.5,      # standard deviation of data distribution
        rho = 7,               # controls the sampling schedule
        P_mean = -1.2,         # mean of log-normal distribution from which noise is drawn for training
        P_std = 1.2,           # standard deviation of log-normal distribution from which noise is drawn for training
        S_churn = 80,          # parameters for stochastic sampling - depends on dataset, Table 5 in paper
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
    ):
        super().__init__()
        assert net.learned_sinusoidal_cond

        self.net = net

        # image dimensions

        self.channels = channels
        self.image_size = image_size

        # parameters

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

        self.rho = rho

        self.P_mean = P_mean
        self.P_std = P_std

        self.num_sample_steps = num_sample_steps # otherwise known as N in the paper

        self.S_churn = S_churn
        self.S_tmin = S_tmin
        self.S_tmax = S_tmax
        self.S_noise = S_noise

    @property
    def device(self):
        return next(self.net.parameters()).device

    # derived preconditioning params - Table 1

    def c_skip(self, sigma):
        return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)

    def c_out(self, sigma):
        return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5

    def c_in(self, sigma):
        return 1 * (sigma ** 2 + self.sigma_data ** 2) ** -0.5

    def c_noise(self, sigma):
        return log(sigma) * 0.25

    # noise distribution

    def noise_distribution(self, batch_size):
        return (self.P_mean + self.P_std * torch.randn((batch_size,), device = self.device)).exp()

    def loss_weight(self, sigma):
        return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2
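
    # note: loss_weight above is the training weighting from the paper,
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, which equals 1 / c_out(sigma)^2,
    # so every sampled noise level contributes with roughly equal effective weight to the loss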

    # sample schedule
    # equation (5) in the paper

    def sample_schedule(self, num_sample_steps = None):
        num_sample_steps = default(num_sample_steps, self.num_sample_steps)

        N = num_sample_steps
        inv_rho = 1 / self.rho

        steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
        sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho

        sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
        return sigmas
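
    # note: this is equation (5), sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    # for i = 0 ... N-1, a monotonically decreasing sequence from sigma_max down to sigma_min,
    # with a final sigma of 0 appended so the sampler ends on clean images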

    # preconditioned network output
    # equation (7) in the paper

    def preconditioned_network_forward(self, noised_images, sigma):
        batch, device = noised_images.shape[0], noised_images.device

        if isinstance(sigma, float):
            sigma = torch.full((batch,), sigma, device = device)

        padded_sigma = rearrange(sigma, 'b -> b 1 1 1')

        net_out = self.net(
            self.c_in(padded_sigma) * noised_images,
            self.c_noise(sigma)
        )

        return self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out
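
    # note: the return value implements the preconditioned denoiser of equation (7),
    # D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma)),
    # where F is the raw network, so the network only needs to predict a well-scaled residual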

    # sampling

    @torch.no_grad()
    def sample(self, batch_size = 16):
        shape = (batch_size, self.channels, self.image_size, self.image_size)

        # get the noise schedule, compute the per-step gammas for stochastic churn, and pair each sigma with the next sigma and its gamma

        sigmas = self.sample_schedule()

        gammas = torch.where(
            (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
            min(self.S_churn / self.num_sample_steps, sqrt(2) - 1),
            0.
        )

        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

        # images is noise at the beginning

        init_sigma = sigmas[0]

        images = init_sigma * torch.randn(shape, device = self.device)

        # gradually denoise

        for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc = 'sampling time step'):
            sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

            eps = self.S_noise * torch.randn(shape, device = self.device) # noise for the stochastic churn step, scaled by S_noise

            sigma_hat = sigma + gamma * sigma
            images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps

            model_output = self.preconditioned_network_forward(images_hat, sigma_hat)
            denoised_over_sigma = (images_hat - model_output) / sigma_hat

            images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

            # second order correction, if not the last timestep

            if sigma_next != 0:
                model_output_next = self.preconditioned_network_forward(images_next, sigma_next)
                denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

            images = images_next

        images = images.clamp(-1., 1.)
        return unnormalize_to_zero_to_one(images)
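
    # note: the loop above follows the stochastic sampler of Algorithm 2 in the paper -
    # per-step noise injection ("churn") controlled by S_churn / S_tmin / S_tmax / S_noise,
    # followed by a Heun (second order) correction whenever the next sigma is nonzero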

    # training

    def forward(self, images):
        batch_size, c, h, w, device, image_size, channels = *images.shape, images.device, self.image_size, self.channels

        assert h == image_size and w == image_size, f'height and width of image must be {image_size}'
        assert c == channels, 'mismatch of image channels'

        images = normalize_to_neg_one_to_one(images)

        sigmas = self.noise_distribution(batch_size)
        padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1')

        noise = torch.randn_like(images)

        noised_images = images + padded_sigmas * noise # alphas are 1. in the paper

        denoised = self.preconditioned_network_forward(noised_images, sigmas)

        losses = F.mse_loss(denoised, images, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        losses = losses * self.loss_weight(sigmas)

        return losses.mean()
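
# example usage (a minimal sketch, not part of the class above): the `Unet` import and its
# `learned_sinusoidal_cond` flag are assumed to be the denoiser provided elsewhere in this
# repository; the names and hyperparameters below are illustrative only

if __name__ == '__main__':
    from denoising_diffusion_pytorch import Unet # assumed to exist alongside this module

    net = Unet(
        dim = 64,
        dim_mults = (1, 2, 4, 8),
        learned_sinusoidal_cond = True # required by the assert in ElucidatedDiffusion.__init__
    )

    diffusion = ElucidatedDiffusion(
        net,
        image_size = 128,
        num_sample_steps = 32
    )

    training_images = torch.rand(8, 3, 128, 128) # images are expected to be in [0, 1]
    loss = diffusion(training_images)
    loss.backward()

    sampled_images = diffusion.sample(batch_size = 4) # (4, 3, 128, 128), values in [0, 1]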