import dataclasses
from collections import deque
+ from typing import NamedTuple

import numpy as np
import torch
- from torch import Generator, Tensor, device as Device, dtype as Dtype

from refiners.foundationals.latent_diffusion.solvers.solver import (
    BaseSolverParams,
    ModelPredictionType,
+     NoiseSchedule,
    Solver,
    TimestepSpacing,
)


+ def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
+     """Compute the log of a tensor with a lower bound."""
+     return torch.log(torch.maximum(x, torch.tensor(lower_bound)))
+
+
+ def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
+     """Compute the square root of a tensor ensuring that the input is non-negative."""
+     return torch.sqrt(torch.maximum(x, torch.tensor(0)))
+
+
+ class SolverTensors(NamedTuple):
+     cumulative_scale_factors: torch.Tensor
+     noise_std: torch.Tensor
+     signal_to_noise_ratios: torch.Tensor
+
+
class DPMSolver(Solver):
    """Diffusion probabilistic models (DPMs) solver.

@@ -37,9 +54,9 @@ def __init__(
        first_inference_step: int = 0,
        params: BaseSolverParams | None = None,
        last_step_first_order: bool = False,
-         device: Device | str = "cpu",
-         dtype: Dtype = torch.float32,
-     ):
+         device: torch.device | str = "cpu",
+         dtype: torch.dtype = torch.float32,
+     ) -> None:
        """Initializes a new DPM solver.

        Args:
@@ -64,6 +81,14 @@ def __init__(
        )
        self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
        self.last_step_first_order = last_step_first_order
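+         # Work in sigma space: rescale the sigmas according to the configured sigma schedule,
+         # derive the per-step scale factors, stds and SNRs from them, and map each rescaled
+         # sigma back to a training timestep.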
+         sigmas = self.noise_std / self.cumulative_scale_factors
+         self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
+         sigma_min = sigmas[0:1]  # corresponds to `final_sigmas_type="sigma_min"` in diffusers
+         self.sigmas = torch.cat([self.sigmas, sigma_min])
+         self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
+             self.sigmas
+         )
+         self.timesteps = self._timesteps_from_sigmas(sigmas)

    def rebuild(
        self: "DPMSolver",
@@ -83,7 +108,7 @@ def rebuild(
        r.last_step_first_order = self.last_step_first_order
        return r

-     def _generate_timesteps(self) -> Tensor:
+     def _generate_timesteps(self) -> torch.Tensor:
        if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
            return super()._generate_timesteps()

@@ -96,9 +121,75 @@ def _generate_timesteps(self) -> Tensor:
        np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
        return torch.tensor(np_space).flip(0)

+     def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
+         """Generate the sigmas used by the solver."""
+         assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
+         sigmas = self.noise_std / self.cumulative_scale_factors
+         sigmas = sigmas.flip(0)
+         rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
+         rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
+         return sigmas, rescaled_sigmas
+
+     def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
+         """Rescale the sigmas according to the sigma schedule."""
+         match sigma_schedule:
+             case NoiseSchedule.UNIFORM:
+                 rho = 1
+             case NoiseSchedule.QUADRATIC:
+                 rho = 2
+             case NoiseSchedule.KARRAS:
+                 rho = 7
+             case None:
+                 return torch.tensor(
+                     np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
+                     device=self.device,
+                 )
+
+         linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
+         first_sigma = sigmas[0]
+         last_sigma = sigmas[-1]
+         rescaled_sigmas = (
+             first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
+         ) ** rho
+         return rescaled_sigmas.flip(0)
+
+     def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
+         """Generate the timesteps from the sigmas."""
+         log_sigmas = safe_log(sigmas)
+         timesteps: list[torch.Tensor] = []
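+         # Map each rescaled sigma back to a (fractional) training timestep by interpolating
+         # between the two nearest training sigmas in log space; the result is rounded below.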
+         for sigma in self.sigmas[:-1]:
+             log_sigma = safe_log(sigma)
+             distance_matrix = log_sigma - log_sigmas.unsqueeze(1)
+
+             # Determine the range of sigma indices
+             low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
+             high_indices = low_indices + 1
+
+             low_log_sigma = log_sigmas[low_indices]
+             high_log_sigma = log_sigmas[high_indices]
+
+             # Interpolate sigma values
+             interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
+             interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
+             timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
+             timesteps.append(timestep)
+
+         return torch.cat(timesteps).round()
+
+     def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
+         """Generate the tensors from the sigmas."""
+         cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
+         noise_std = sigmas * cumulative_scale_factors
+         signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
+         return SolverTensors(
+             cumulative_scale_factors=cumulative_scale_factors,
+             noise_std=noise_std,
+             signal_to_noise_ratios=signal_to_noise_ratios,
+         )
+
    def dpm_solver_first_order_update(
-         self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
-     ) -> Tensor:
+         self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
+     ) -> torch.Tensor:
        """Applies a first-order backward Euler update to the input data `x`.

        Args:
@@ -109,32 +200,29 @@ def dpm_solver_first_order_update(
        Returns:
            The denoised version of the input data `x`.
        """
-         current_timestep = self.timesteps[step]
-         previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
+         current_ratio = self.signal_to_noise_ratios[step]
+         next_ratio = self.signal_to_noise_ratios[step + 1]

-         previous_ratio = self.signal_to_noise_ratios[previous_timestep]
-         current_ratio = self.signal_to_noise_ratios[current_timestep]
+         next_scale_factor = self.cumulative_scale_factors[step + 1]

-         previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
+         next_noise_std = self.noise_std[step + 1]
+         current_noise_std = self.noise_std[step]

-         previous_noise_std = self.noise_std[previous_timestep]
-         current_noise_std = self.noise_std[current_timestep]
-
-         ratio_delta = current_ratio - previous_ratio
+         ratio_delta = current_ratio - next_ratio

        if sde_noise is None:
-             return (previous_noise_std / current_noise_std) * x + (
-                 1.0 - torch.exp(ratio_delta)
-             ) * previous_scale_factor * noise
+             return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise

        factor = 1.0 - torch.exp(2.0 * ratio_delta)
        return (
-             (previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
-             + previous_scale_factor * factor * noise
-             + previous_noise_std * torch.sqrt(factor) * sde_noise
+             (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+             + next_scale_factor * factor * noise
+             + next_noise_std * safe_sqrt(factor) * sde_noise
        )

-     def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
+     def multistep_dpm_solver_second_order_update(
+         self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
+     ) -> torch.Tensor:
        """Applies a second-order backward Euler update to the input data `x`.

        Args:
@@ -144,43 +232,41 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noi
        Returns:
            The denoised version of the input data `x`.
        """
-         previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
-         current_timestep = self.timesteps[step]
-         next_timestep = self.timesteps[step - 1]
-
        current_data_estimation = self.estimated_data[-1]
-         next_data_estimation = self.estimated_data[-2]
+         previous_data_estimation = self.estimated_data[-2]

-         previous_ratio = self.signal_to_noise_ratios[previous_timestep]
-         current_ratio = self.signal_to_noise_ratios[current_timestep]
-         next_ratio = self.signal_to_noise_ratios[next_timestep]
+         next_ratio = self.signal_to_noise_ratios[step + 1]
+         current_ratio = self.signal_to_noise_ratios[step]
+         previous_ratio = self.signal_to_noise_ratios[step - 1]

-         previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
-         previous_noise_std = self.noise_std[previous_timestep]
-         current_noise_std = self.noise_std[current_timestep]
+         next_scale_factor = self.cumulative_scale_factors[step + 1]
+         next_noise_std = self.noise_std[step + 1]
+         current_noise_std = self.noise_std[step]

-         estimation_delta = (current_data_estimation - next_data_estimation) / (
-             (current_ratio - next_ratio) / (previous_ratio - current_ratio)
+         estimation_delta = (current_data_estimation - previous_data_estimation) / (
+             (current_ratio - previous_ratio) / (next_ratio - current_ratio)
        )
-         ratio_delta = current_ratio - previous_ratio
+         ratio_delta = current_ratio - next_ratio

        if sde_noise is None:
            factor = 1.0 - torch.exp(ratio_delta)
            return (
-                 (previous_noise_std / current_noise_std) * x
-                 + previous_scale_factor * factor * current_data_estimation
-                 + 0.5 * previous_scale_factor * factor * estimation_delta
+                 (next_noise_std / current_noise_std) * x
+                 + next_scale_factor * factor * current_data_estimation
+                 + 0.5 * next_scale_factor * factor * estimation_delta
            )

        factor = 1.0 - torch.exp(2.0 * ratio_delta)
        return (
-             (previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
-             + previous_scale_factor * factor * current_data_estimation
-             + 0.5 * previous_scale_factor * factor * estimation_delta
-             + previous_noise_std * torch.sqrt(factor) * sde_noise
+             (next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
+             + next_scale_factor * factor * current_data_estimation
+             + 0.5 * next_scale_factor * factor * estimation_delta
+             + next_noise_std * safe_sqrt(factor) * sde_noise
        )

-     def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
+     def __call__(
+         self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
+     ) -> torch.Tensor:
        """Apply one step of the backward diffusion process.

        Note:
@@ -199,9 +285,8 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
        """
        assert self.first_inference_step <= step < self.num_inference_steps, f"invalid step {step}"

-         current_timestep = self.timesteps[step]
-         scale_factor = self.cumulative_scale_factors[current_timestep]
-         noise_ratio = self.noise_std[current_timestep]
+         scale_factor = self.cumulative_scale_factors[step]
+         noise_ratio = self.noise_std[step]
        estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
        self.estimated_data.append(estimated_denoised_data)
        variance = self.params.sde_variance