Skip to content

Commit f5404e4

Browse files
add karras sigmas to dpm solver
1 parent 5aef140 commit f5404e4

File tree

6 files changed

+215
-88
lines changed

6 files changed

+215
-88
lines changed

src/refiners/foundationals/latent_diffusion/solvers/dpm.py

Lines changed: 133 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
11
import dataclasses
22
from collections import deque
3+
from typing import NamedTuple
34

45
import numpy as np
56
import torch
6-
from torch import Generator, Tensor, device as Device, dtype as Dtype
77

88
from refiners.foundationals.latent_diffusion.solvers.solver import (
99
BaseSolverParams,
1010
ModelPredictionType,
11+
NoiseSchedule,
1112
Solver,
1213
TimestepSpacing,
1314
)
1415

1516

17+
def safe_log(x: torch.Tensor, lower_bound: float = 1e-6) -> torch.Tensor:
18+
"""Compute the log of a tensor with a lower bound."""
19+
return torch.log(torch.maximum(x, torch.tensor(lower_bound)))
20+
21+
22+
def safe_sqrt(x: torch.Tensor) -> torch.Tensor:
23+
"""Compute the square root of a tensor ensuring that the input is non-negative"""
24+
return torch.sqrt(torch.maximum(x, torch.tensor(0)))
25+
26+
27+
class SolverTensors(NamedTuple):
28+
cumulative_scale_factors: torch.Tensor
29+
noise_std: torch.Tensor
30+
signal_to_noise_ratios: torch.Tensor
31+
32+
1633
class DPMSolver(Solver):
1734
"""Diffusion probabilistic models (DPMs) solver.
1835
@@ -37,9 +54,9 @@ def __init__(
3754
first_inference_step: int = 0,
3855
params: BaseSolverParams | None = None,
3956
last_step_first_order: bool = False,
40-
device: Device | str = "cpu",
41-
dtype: Dtype = torch.float32,
42-
):
57+
device: torch.device | str = "cpu",
58+
dtype: torch.dtype = torch.float32,
59+
) -> None:
4360
"""Initializes a new DPM solver.
4461
4562
Args:
@@ -64,6 +81,14 @@ def __init__(
6481
)
6582
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
6683
self.last_step_first_order = last_step_first_order
84+
sigmas = self.noise_std / self.cumulative_scale_factors
85+
self.sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
86+
sigma_min = sigmas[0:1] # corresponds to `final_sigmas_type="sigma_min" in diffusers`
87+
self.sigmas = torch.cat([self.sigmas, sigma_min])
88+
self.cumulative_scale_factors, self.noise_std, self.signal_to_noise_ratios = self._solver_tensors_from_sigmas(
89+
self.sigmas
90+
)
91+
self.timesteps = self._timesteps_from_sigmas(sigmas)
6792

6893
def rebuild(
6994
self: "DPMSolver",
@@ -83,7 +108,7 @@ def rebuild(
83108
r.last_step_first_order = self.last_step_first_order
84109
return r
85110

86-
def _generate_timesteps(self) -> Tensor:
111+
def _generate_timesteps(self) -> torch.Tensor:
87112
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
88113
return super()._generate_timesteps()
89114

@@ -96,9 +121,75 @@ def _generate_timesteps(self) -> Tensor:
96121
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
97122
return torch.tensor(np_space).flip(0)
98123

124+
def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
125+
"""Generate the sigmas used by the solver."""
126+
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
127+
sigmas = self.noise_std / self.cumulative_scale_factors
128+
sigmas = sigmas.flip(0)
129+
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
130+
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
131+
return sigmas, rescaled_sigmas
132+
133+
def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
134+
"""Rescale the sigmas according to the sigma schedule."""
135+
match sigma_schedule:
136+
case NoiseSchedule.UNIFORM:
137+
rho = 1
138+
case NoiseSchedule.QUADRATIC:
139+
rho = 2
140+
case NoiseSchedule.KARRAS:
141+
rho = 7
142+
case None:
143+
return torch.tensor(
144+
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
145+
device=self.device,
146+
)
147+
148+
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
149+
first_sigma = sigmas[0]
150+
last_sigma = sigmas[-1]
151+
rescaled_sigmas = (
152+
first_sigma ** (1 / rho) + linear_schedule * (last_sigma ** (1 / rho) - first_sigma ** (1 / rho))
153+
) ** rho
154+
return rescaled_sigmas.flip(0)
155+
156+
def _timesteps_from_sigmas(self, sigmas: torch.Tensor) -> torch.Tensor:
157+
"""Generate the timesteps from the sigmas."""
158+
log_sigmas = safe_log(sigmas)
159+
timesteps: list[torch.Tensor] = []
160+
for sigma in self.sigmas[:-1]:
161+
log_sigma = safe_log(sigma)
162+
distance_matrix = log_sigma - log_sigmas.unsqueeze(1)
163+
164+
# Determine the range of sigma indices
165+
low_indices = (distance_matrix >= 0).cumsum(dim=0).argmax(dim=0).clip(max=sigmas.size(0) - 2)
166+
high_indices = low_indices + 1
167+
168+
low_log_sigma = log_sigmas[low_indices]
169+
high_log_sigma = log_sigmas[high_indices]
170+
171+
# Interpolate sigma values
172+
interpolation_weights = (low_log_sigma - log_sigma) / (low_log_sigma - high_log_sigma)
173+
interpolation_weights = torch.clamp(interpolation_weights, 0, 1)
174+
timestep = (1 - interpolation_weights) * low_indices + interpolation_weights * high_indices
175+
timesteps.append(timestep)
176+
177+
return torch.cat(timesteps).round()
178+
179+
def _solver_tensors_from_sigmas(self, sigmas: torch.Tensor) -> SolverTensors:
180+
"""Generate the tensors from the sigmas."""
181+
cumulative_scale_factors = 1 / torch.sqrt(sigmas**2 + 1)
182+
noise_std = sigmas * cumulative_scale_factors
183+
signal_to_noise_ratios = safe_log(cumulative_scale_factors) - safe_log(noise_std)
184+
return SolverTensors(
185+
cumulative_scale_factors=cumulative_scale_factors,
186+
noise_std=noise_std,
187+
signal_to_noise_ratios=signal_to_noise_ratios,
188+
)
189+
99190
def dpm_solver_first_order_update(
100-
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
101-
) -> Tensor:
191+
self, x: torch.Tensor, noise: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
192+
) -> torch.Tensor:
102193
"""Applies a first-order backward Euler update to the input data `x`.
103194
104195
Args:
@@ -109,32 +200,29 @@ def dpm_solver_first_order_update(
109200
Returns:
110201
The denoised version of the input data `x`.
111202
"""
112-
current_timestep = self.timesteps[step]
113-
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
203+
current_ratio = self.signal_to_noise_ratios[step]
204+
next_ratio = self.signal_to_noise_ratios[step + 1]
114205

115-
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
116-
current_ratio = self.signal_to_noise_ratios[current_timestep]
206+
next_scale_factor = self.cumulative_scale_factors[step + 1]
117207

118-
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
208+
next_noise_std = self.noise_std[step + 1]
209+
current_noise_std = self.noise_std[step]
119210

120-
previous_noise_std = self.noise_std[previous_timestep]
121-
current_noise_std = self.noise_std[current_timestep]
122-
123-
ratio_delta = current_ratio - previous_ratio
211+
ratio_delta = current_ratio - next_ratio
124212

125213
if sde_noise is None:
126-
return (previous_noise_std / current_noise_std) * x + (
127-
1.0 - torch.exp(ratio_delta)
128-
) * previous_scale_factor * noise
214+
return (next_noise_std / current_noise_std) * x + (1.0 - torch.exp(ratio_delta)) * next_scale_factor * noise
129215

130216
factor = 1.0 - torch.exp(2.0 * ratio_delta)
131217
return (
132-
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
133-
+ previous_scale_factor * factor * noise
134-
+ previous_noise_std * torch.sqrt(factor) * sde_noise
218+
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
219+
+ next_scale_factor * factor * noise
220+
+ next_noise_std * safe_sqrt(factor) * sde_noise
135221
)
136222

137-
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
223+
def multistep_dpm_solver_second_order_update(
224+
self, x: torch.Tensor, step: int, sde_noise: torch.Tensor | None = None
225+
) -> torch.Tensor:
138226
"""Applies a second-order backward Euler update to the input data `x`.
139227
140228
Args:
@@ -144,43 +232,41 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noi
144232
Returns:
145233
The denoised version of the input data `x`.
146234
"""
147-
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else torch.tensor([0])
148-
current_timestep = self.timesteps[step]
149-
next_timestep = self.timesteps[step - 1]
150-
151235
current_data_estimation = self.estimated_data[-1]
152-
next_data_estimation = self.estimated_data[-2]
236+
previous_data_estimation = self.estimated_data[-2]
153237

154-
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
155-
current_ratio = self.signal_to_noise_ratios[current_timestep]
156-
next_ratio = self.signal_to_noise_ratios[next_timestep]
238+
next_ratio = self.signal_to_noise_ratios[step + 1]
239+
current_ratio = self.signal_to_noise_ratios[step]
240+
previous_ratio = self.signal_to_noise_ratios[step - 1]
157241

158-
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
159-
previous_noise_std = self.noise_std[previous_timestep]
160-
current_noise_std = self.noise_std[current_timestep]
242+
next_scale_factor = self.cumulative_scale_factors[step + 1]
243+
next_noise_std = self.noise_std[step + 1]
244+
current_noise_std = self.noise_std[step]
161245

162-
estimation_delta = (current_data_estimation - next_data_estimation) / (
163-
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
246+
estimation_delta = (current_data_estimation - previous_data_estimation) / (
247+
(current_ratio - previous_ratio) / (next_ratio - current_ratio)
164248
)
165-
ratio_delta = current_ratio - previous_ratio
249+
ratio_delta = current_ratio - next_ratio
166250

167251
if sde_noise is None:
168252
factor = 1.0 - torch.exp(ratio_delta)
169253
return (
170-
(previous_noise_std / current_noise_std) * x
171-
+ previous_scale_factor * factor * current_data_estimation
172-
+ 0.5 * previous_scale_factor * factor * estimation_delta
254+
(next_noise_std / current_noise_std) * x
255+
+ next_scale_factor * factor * current_data_estimation
256+
+ 0.5 * next_scale_factor * factor * estimation_delta
173257
)
174258

175259
factor = 1.0 - torch.exp(2.0 * ratio_delta)
176260
return (
177-
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
178-
+ previous_scale_factor * factor * current_data_estimation
179-
+ 0.5 * previous_scale_factor * factor * estimation_delta
180-
+ previous_noise_std * torch.sqrt(factor) * sde_noise
261+
(next_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
262+
+ next_scale_factor * factor * current_data_estimation
263+
+ 0.5 * next_scale_factor * factor * estimation_delta
264+
+ next_noise_std * safe_sqrt(factor) * sde_noise
181265
)
182266

183-
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
267+
def __call__(
268+
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
269+
) -> torch.Tensor:
184270
"""Apply one step of the backward diffusion process.
185271
186272
Note:
@@ -199,9 +285,8 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
199285
"""
200286
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
201287

202-
current_timestep = self.timesteps[step]
203-
scale_factor = self.cumulative_scale_factors[current_timestep]
204-
noise_ratio = self.noise_std[current_timestep]
288+
scale_factor = self.cumulative_scale_factors[step]
289+
noise_ratio = self.noise_std[step]
205290
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
206291
self.estimated_data.append(estimated_denoised_data)
207292
variance = self.params.sde_variance

src/refiners/foundationals/latent_diffusion/solvers/solver.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class BaseSolverParams:
6767
initial_diffusion_rate: float | None
6868
final_diffusion_rate: float | None
6969
noise_schedule: NoiseSchedule | None
70+
sigma_schedule: NoiseSchedule | None
7071
model_prediction_type: ModelPredictionType | None
7172
sde_variance: float
7273

@@ -91,6 +92,7 @@ class SolverParams(BaseSolverParams):
9192
initial_diffusion_rate: float | None = None
9293
final_diffusion_rate: float | None = None
9394
noise_schedule: NoiseSchedule | None = None
95+
sigma_schedule: NoiseSchedule | None = None
9496
model_prediction_type: ModelPredictionType | None = None
9597
sde_variance: float = 0.0
9698

@@ -103,6 +105,7 @@ class ResolvedSolverParams(BaseSolverParams):
103105
initial_diffusion_rate: float
104106
final_diffusion_rate: float
105107
noise_schedule: NoiseSchedule
108+
sigma_schedule: NoiseSchedule | None
106109
model_prediction_type: ModelPredictionType
107110
sde_variance: float
108111

@@ -140,6 +143,7 @@ class Solver(fl.Module, ABC):
140143
initial_diffusion_rate=8.5e-4,
141144
final_diffusion_rate=1.2e-2,
142145
noise_schedule=NoiseSchedule.QUADRATIC,
146+
sigma_schedule=None,
143147
model_prediction_type=ModelPredictionType.NOISE,
144148
sde_variance=0.0,
145149
)
@@ -404,14 +408,12 @@ def sample_noise_schedule(self) -> Tensor:
404408
A tensor representing the noise schedule.
405409
"""
406410
match self.params.noise_schedule:
407-
case "uniform":
411+
case NoiseSchedule.UNIFORM:
408412
return 1 - self.sample_power_distribution(1)
409-
case "quadratic":
413+
case NoiseSchedule.QUADRATIC:
410414
return 1 - self.sample_power_distribution(2)
411-
case "karras":
415+
case NoiseSchedule.KARRAS:
412416
return 1 - self.sample_power_distribution(7)
413-
case _:
414-
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
415417

416418
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
417419
"""Move the solver to the specified device and data type.

tests/e2e/test_diffusion.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
9797
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
9898

9999

100+
@pytest.fixture
101+
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
102+
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
103+
104+
100105
@pytest.fixture
101106
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
102107
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@@ -913,6 +918,39 @@ def test_diffusion_std_sde_random_init(
913918
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
914919

915920

921+
@no_grad()
922+
def test_diffusion_std_sde_karras_random_init(
923+
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_karras_random_init: Image.Image, test_device: torch.device
924+
):
925+
sd15 = sd15_std_sde
926+
927+
prompt = "a cute cat, detailed high-quality professional image"
928+
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
929+
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
930+
931+
sd15.solver = DPMSolver(
932+
num_inference_steps=18,
933+
last_step_first_order=True,
934+
params=SolverParams(sde_variance=1.0, sigma_schedule=NoiseSchedule.KARRAS),
935+
device=test_device,
936+
)
937+
938+
manual_seed(2)
939+
x = sd15.init_latents((512, 512))
940+
941+
for step in sd15.steps:
942+
x = sd15(
943+
x,
944+
step=step,
945+
clip_text_embedding=clip_text_embedding,
946+
condition_scale=7.5,
947+
)
948+
949+
predicted_image = sd15.lda.latents_to_image(x)
950+
951+
ensure_similar_images(predicted_image, expected_image_std_sde_karras_random_init)
952+
953+
916954
@no_grad()
917955
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
918956
sd15 = sd15_std

0 commit comments

Comments
 (0)