From 1545155743e1e6053ac9e71bdcfca8dca5acfb53 Mon Sep 17 00:00:00 2001 From: qcloud Date: Sun, 26 Nov 2023 17:27:07 +0800 Subject: [PATCH] Fix a bug for DDIM sampling --- ldm/modules/diffusionmodules/util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 637363dfe3..4cfe01a798 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -63,7 +63,9 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): # select alphas for computing the variance schedule alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # alphas_prev should start with 1. + alphas_prev = np.append(1., alphas[:-1]) # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) @@ -267,4 +269,4 @@ def forward(self, c_concat, c_crossattn): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise()