@@ -59,7 +59,6 @@ def __init__(
5959 * ,
6060 image_size ,
6161 channels = 3 ,
62- cond_scale = 500 ,
6362 loss_type = 'l1' ,
6463 noise_schedule = 'linear' ,
6564 num_sample_steps = 500
@@ -76,7 +75,6 @@ def __init__(
7675
7776 # continuous noise schedule related stuff
7877
79- self .cond_scale = cond_scale # the log(snr) will be scaled by this value
8078 self .loss_type = loss_type
8179
8280 if noise_schedule == 'linear' :
@@ -108,18 +106,16 @@ def p_mean_variance(self, x, time, time_next):
108106 # todo - derive x_start from the posterior mean and do dynamic thresholding
109107 # assumed that is what is going on in Imagen
110108
111- batch = x .shape [0 ]
112- batch_time = repeat (time , ' -> b' , b = batch )
113-
114- pred_noise = self .denoise_fn (x , batch_time * self .cond_scale )
115-
116109 log_snr = self .log_snr (time )
117110 log_snr_next = self .log_snr (time_next )
118111 c = - expm1 (log_snr - log_snr_next )
119112
120113 squared_alpha , squared_alpha_next = log_snr .sigmoid (), log_snr_next .sigmoid ()
121114 squared_sigma , squared_sigma_next = (- log_snr ).sigmoid (), (- log_snr_next ).sigmoid ()
122115
116+ batch_log_snr = repeat (log_snr , ' -> b' , b = x .shape [0 ])
117+ pred_noise = self .denoise_fn (x , batch_log_snr )
118+
123119 model_mean = sqrt (squared_alpha_next / squared_alpha ) * (x - c * sqrt (squared_sigma ) * pred_noise )
124120 posterior_variance = squared_sigma_next * c
125121
@@ -151,6 +147,7 @@ def p_sample_loop(self, shape):
151147 times_next = steps [i + 1 ]
152148 img = self .p_sample (img , times , times_next )
153149
150+ img .clamp_ (- 1. , 1. )
154151 img = unnormalize_to_zero_to_one (img )
155152 return img
156153
@@ -180,7 +177,7 @@ def p_losses(self, x_start, times, noise = None):
180177
181178 x , log_snr = self .q_sample (x_start = x_start , times = times , noise = noise )
182179
183- model_out = self .denoise_fn (x , log_snr * self . cond_scale )
180+ model_out = self .denoise_fn (x , log_snr )
184181 return self .loss_fn (model_out , noise )
185182
186183 def forward (self , img , * args , ** kwargs ):
0 commit comments