Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 09a9dfd

Browse files
Add stochastic sampling to DPM solver (SDE)
1 parent daee772 commit 09a9dfd

File tree

8 files changed: 188 additions (+188) and 15 deletions (-15).

src/refiners/foundationals/latent_diffusion/solvers/ddim.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def __init__(
4141
"""
4242
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
4343
raise NotImplementedError
44+
if params and params.sde_variance != 0.0:
45+
raise NotImplementedError("DDIM does not support sde_variance != 0.0 yet")
4446

4547
super().__init__(
4648
num_inference_steps=num_inference_steps,

src/refiners/foundationals/latent_diffusion/solvers/dpm.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from collections import deque
33

44
import numpy as np
5-
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
5+
import torch
6+
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
67

78
from refiners.foundationals.latent_diffusion.solvers.solver import (
89
BaseSolverParams,
@@ -51,6 +52,8 @@ def __init__(
5152
"""
5253
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
5354
raise NotImplementedError
55+
if params and params.sde_variance not in (0.0, 1.0):
56+
raise NotImplementedError("DPMSolver only supports sde_variance=0.0 or 1.0")
5457

5558
super().__init__(
5659
num_inference_steps=num_inference_steps,
@@ -93,7 +96,9 @@ def _generate_timesteps(self) -> Tensor:
9396
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
9497
return tensor(np_space).flip(0)
9598

96-
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
99+
def dpm_solver_first_order_update(
100+
self, x: Tensor, noise: Tensor, step: int, sde_noise: Tensor | None = None
101+
) -> Tensor:
97102
"""Applies a first-order backward Euler update to the input data `x`.
98103
99104
Args:
@@ -115,11 +120,21 @@ def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) ->
115120
previous_noise_std = self.noise_std[previous_timestep]
116121
current_noise_std = self.noise_std[current_timestep]
117122

118-
factor = exp(-(previous_ratio - current_ratio)) - 1.0
119-
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
120-
return denoised_x
123+
ratio_delta = current_ratio - previous_ratio
121124

122-
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
125+
if sde_noise is None:
126+
return (previous_noise_std / current_noise_std) * x + (
127+
1.0 - torch.exp(ratio_delta)
128+
) * previous_scale_factor * noise
129+
130+
factor = 1.0 - torch.exp(2.0 * ratio_delta)
131+
return (
132+
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
133+
+ previous_scale_factor * factor * noise
134+
+ previous_noise_std * torch.sqrt(factor) * sde_noise
135+
)
136+
137+
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int, sde_noise: Tensor | None = None) -> Tensor:
123138
"""Applies a second-order backward Euler update to the input data `x`.
124139
125140
Args:
@@ -147,13 +162,23 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens
147162
estimation_delta = (current_data_estimation - next_data_estimation) / (
148163
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
149164
)
150-
factor = exp(-(previous_ratio - current_ratio)) - 1.0
151-
denoised_x = (
152-
(previous_noise_std / current_noise_std) * x
153-
- (factor * previous_scale_factor) * current_data_estimation
154-
- 0.5 * (factor * previous_scale_factor) * estimation_delta
165+
ratio_delta = current_ratio - previous_ratio
166+
167+
if sde_noise is None:
168+
factor = 1.0 - torch.exp(ratio_delta)
169+
return (
170+
(previous_noise_std / current_noise_std) * x
171+
+ previous_scale_factor * factor * current_data_estimation
172+
+ 0.5 * previous_scale_factor * factor * estimation_delta
173+
)
174+
175+
factor = 1.0 - torch.exp(2.0 * ratio_delta)
176+
return (
177+
(previous_noise_std / current_noise_std) * torch.exp(ratio_delta) * x
178+
+ previous_scale_factor * factor * current_data_estimation
179+
+ 0.5 * previous_scale_factor * factor * estimation_delta
180+
+ previous_noise_std * torch.sqrt(factor) * sde_noise
155181
)
156-
return denoised_x
157182

158183
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
159184
"""Apply one step of the backward diffusion process.
@@ -175,11 +200,20 @@ def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Gen
175200
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
176201

177202
current_timestep = self.timesteps[step]
178-
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
203+
scale_factor = self.cumulative_scale_factors[current_timestep]
204+
noise_ratio = self.noise_std[current_timestep]
179205
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
180206
self.estimated_data.append(estimated_denoised_data)
207+
variance = self.params.sde_variance
208+
sde_noise = (
209+
torch.randn(x.shape, generator=generator, device=x.device, dtype=x.dtype) * variance
210+
if variance > 0.0
211+
else None
212+
)
181213

182214
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
183-
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
215+
return self.dpm_solver_first_order_update(
216+
x=x, noise=estimated_denoised_data, step=step, sde_noise=sde_noise
217+
)
184218

185-
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
219+
return self.multistep_dpm_solver_second_order_update(x=x, step=step, sde_noise=sde_noise)

src/refiners/foundationals/latent_diffusion/solvers/euler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def __init__(
3636
"""
3737
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
3838
raise NotImplementedError
39+
if params and params.sde_variance != 0.0:
40+
raise NotImplementedError("Euler does not support sde_variance != 0.0 yet")
3941

4042
super().__init__(
4143
num_inference_steps=num_inference_steps,

src/refiners/foundationals/latent_diffusion/solvers/solver.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class BaseSolverParams:
7979
final_diffusion_rate: float | None
8080
noise_schedule: NoiseSchedule | None
8181
model_prediction_type: ModelPredictionType | None
82+
sde_variance: float
8283

8384

8485
@dataclasses.dataclass(kw_only=True, frozen=True)
@@ -102,6 +103,7 @@ class SolverParams(BaseSolverParams):
102103
final_diffusion_rate: float | None = None
103104
noise_schedule: NoiseSchedule | None = None
104105
model_prediction_type: ModelPredictionType | None = None
106+
sde_variance: float = 0.0
105107

106108

107109
@dataclasses.dataclass(kw_only=True, frozen=True)
@@ -113,6 +115,7 @@ class ResolvedSolverParams(BaseSolverParams):
113115
final_diffusion_rate: float
114116
noise_schedule: NoiseSchedule
115117
model_prediction_type: ModelPredictionType
118+
sde_variance: float
116119

117120

118121
class Solver(fl.Module, ABC):
@@ -123,6 +126,19 @@ class Solver(fl.Module, ABC):
123126
124127
This process is described using several parameters such as initial and final diffusion rates,
125128
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
129+
130+
Attributes:
131+
params: The common parameters for solvers. See `SolverParams`.
132+
num_inference_steps: The number of inference steps to perform.
133+
first_inference_step: The step to start the inference process from.
134+
scale_factors: The scale factors used to denoise the input. These are called "betas" in other implementations,
135+
and `1 - scale_factors` is called "alphas".
136+
cumulative_scale_factors: The cumulative scale factors used to denoise the input. These are called "alpha_t" in
137+
other implementations.
138+
noise_std: The standard deviation of the noise used to denoise the input. This is called "sigma_t" in other
139+
implementations.
140+
signal_to_noise_ratios: The signal-to-noise ratios used to denoise the input. These are called "lambda_t" in other
141+
implementations.
126142
"""
127143

128144
timesteps: Tensor
@@ -136,6 +152,7 @@ class Solver(fl.Module, ABC):
136152
final_diffusion_rate=1.2e-2,
137153
noise_schedule=NoiseSchedule.QUADRATIC,
138154
model_prediction_type=ModelPredictionType.NOISE,
155+
sde_variance=0.0,
139156
)
140157

141158
def __init__(

tests/e2e/test_diffusion.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
8888
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
8989

9090

91+
@pytest.fixture
92+
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
93+
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
94+
95+
9196
@pytest.fixture
9297
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
9398
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
@@ -560,6 +565,24 @@ def sd15_std(
560565
return sd15
561566

562567

568+
@pytest.fixture
569+
def sd15_std_sde(
570+
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
571+
) -> StableDiffusion_1:
572+
if test_device.type == "cpu":
573+
warn("not running on CPU, skipping")
574+
pytest.skip()
575+
576+
sde_solver = DPMSolver(num_inference_steps=30, last_step_first_order=True, params=SolverParams(sde_variance=1.0))
577+
sd15 = StableDiffusion_1(device=test_device, solver=sde_solver)
578+
579+
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
580+
sd15.lda.load_from_safetensors(lda_weights)
581+
sd15.unet.load_from_safetensors(unet_weights_std)
582+
583+
return sd15
584+
585+
563586
@pytest.fixture
564587
def sd15_std_float16(
565588
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
@@ -831,6 +854,33 @@ def test_diffusion_std_random_init(
831854
ensure_similar_images(predicted_image, expected_image_std_random_init)
832855

833856

857+
@no_grad()
858+
def test_diffusion_std_sde_random_init(
859+
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
860+
):
861+
sd15 = sd15_std_sde
862+
863+
prompt = "a cute cat, detailed high-quality professional image"
864+
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
865+
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
866+
867+
sd15.set_inference_steps(50)
868+
869+
manual_seed(2)
870+
x = torch.randn(1, 4, 64, 64, device=test_device)
871+
872+
for step in sd15.steps:
873+
x = sd15(
874+
x,
875+
step=step,
876+
clip_text_embedding=clip_text_embedding,
877+
condition_scale=7.5,
878+
)
879+
predicted_image = sd15.lda.latents_to_image(x)
880+
881+
ensure_similar_images(predicted_image, expected_image_std_sde_random_init)
882+
883+
834884
@no_grad()
835885
def test_diffusion_batch2(sd15_std: StableDiffusion_1):
836886
sd15 = sd15_std

tests/e2e/test_diffusion_ref/README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,35 @@ Special cases:
6767

6868
- `kitchen_dog.png` is generated with the same Diffusers script and negative prompt, seed 12, positive prompt "a small brown dog, detailed high-quality professional image, sitting on a chair, in a kitchen".
6969

70+
- `expected_std_sde_random_init.png` is generated with the following code:
71+
72+
```python
73+
import torch
74+
from diffusers import StableDiffusionPipeline
75+
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
76+
77+
from refiners.fluxion.utils import manual_seed
78+
79+
diffusers_solver = DPMSolverMultistepScheduler.from_config( # type: ignore
80+
{
81+
"beta_end": 0.012,
82+
"beta_schedule": "scaled_linear",
83+
"beta_start": 0.00085,
84+
"algorithm_type": "sde-dpmsolver++",
85+
"use_karras_sigmas": False,
86+
"final_sigmas_type": "sigma_min",
87+
"euler_at_final": True,
88+
}
89+
)
90+
model_id = "runwayml/stable-diffusion-v1-5"
91+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, scheduler=diffusers_solver)
92+
pipe = pipe.to("cuda")
93+
prompt = "a cute cat, detailed high-quality professional image"
94+
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
95+
manual_seed(2)
96+
image = pipe(prompt, negative_prompt=negative_prompt, guidance_scale=7.5).images[0]
97+
```
98+
7099
- `kitchen_mask.png` is made manually.
71100

72101
- Controlnet guides have been manually generated (x) using open source software and models, namely:
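Binary file (349 KB), preview not loaded. This is presumably the new reference image `expected_std_sde_random_init.png` described in the README above.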

tests/foundationals/latent_diffusion/test_solvers.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,45 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
5959
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
6060

6161

62+
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
63+
def test_dpm_solver_sde_diffusers(n_steps: int, last_step_first_order: bool):
64+
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
65+
66+
manual_seed(0)
67+
68+
diffusers_scheduler = DiffuserScheduler(
69+
beta_schedule="scaled_linear",
70+
beta_start=0.00085,
71+
beta_end=0.012,
72+
lower_order_final=False,
73+
euler_at_final=last_step_first_order,
74+
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
75+
algorithm_type="sde-dpmsolver++",
76+
)
77+
diffusers_scheduler.set_timesteps(n_steps)
78+
solver = DPMSolver(
79+
num_inference_steps=n_steps,
80+
last_step_first_order=last_step_first_order,
81+
params=SolverParams(sde_variance=1.0),
82+
)
83+
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
84+
85+
sample = randn(1, 3, 32, 32)
86+
predicted_noise = randn(1, 3, 32, 32)
87+
88+
manual_seed(37)
89+
diffusers_outputs: list[Tensor] = [
90+
cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
91+
for timestep in diffusers_scheduler.timesteps
92+
]
93+
94+
manual_seed(37)
95+
refiners_outputs = [solver(x=sample, predicted_noise=predicted_noise, step=step) for step in range(n_steps)]
96+
97+
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
98+
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=1e-6), f"outputs differ at step {step}"
99+
100+
62101
def test_ddim_diffusers():
63102
from diffusers import DDIMScheduler # type: ignore
64103

0 commit comments

Comments
 (0)