Skip to content

Commit 83b9312

Browse files
committed
fix DPM-Solver with bfloat16
1 parent 283bf45 commit 83b9312

File tree

2 files changed

+50
-48
lines changed
  • src/refiners/foundationals/latent_diffusion/solvers/dpm.py
  • tests/foundationals/latent_diffusion/test_solvers.py

2 files changed

+50
-48
lines changed

src/refiners/foundationals/latent_diffusion/solvers/dpm.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,6 @@ def _generate_timesteps(self) -> torch.Tensor:
121121
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
122122
return torch.tensor(np_space).flip(0)
123123

124-
def _generate_sigmas(self) -> tuple[torch.Tensor, torch.Tensor]:
125-
"""Generate the sigmas used by the solver."""
126-
assert self.params.sigma_schedule is not None, "sigma_schedule must be set for the DPM solver"
127-
sigmas = self.noise_std / self.cumulative_scale_factors
128-
sigmas = sigmas.flip(0)
129-
rescaled_sigmas = self._rescale_sigmas(sigmas, self.params.sigma_schedule)
130-
rescaled_sigmas = torch.cat([rescaled_sigmas, torch.tensor([0.0])])
131-
return sigmas, rescaled_sigmas
132-
133124
def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule | None) -> torch.Tensor:
134125
"""Rescale the sigmas according to the sigma schedule."""
135126
match sigma_schedule:
@@ -140,9 +131,12 @@ def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule |
140131
case NoiseSchedule.KARRAS:
141132
rho = 7
142133
case None:
134+
if sigmas.dtype == torch.bfloat16:
135+
sigmas = sigmas.to(torch.float32)
143136
return torch.tensor(
144137
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
145138
device=self.device,
139+
dtype=self.dtype,
146140
)
147141

148142
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)

tests/foundationals/latent_diffusion/test_solvers.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from warnings import warn
44

55
import pytest
6-
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor
6+
import torch
7+
from torch import Tensor, device as Device
78

89
from refiners.fluxion import manual_seed
910
from refiners.foundationals.latent_diffusion.solvers import (
@@ -27,7 +28,7 @@ def test_ddpm_diffusers():
2728
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
2829
diffusers_scheduler.set_timesteps(1000)
2930
solver = DDPM(num_inference_steps=1000)
30-
assert equal(diffusers_scheduler.timesteps, solver.timesteps)
31+
assert torch.equal(diffusers_scheduler.timesteps, solver.timesteps)
3132

3233

3334
@pytest.mark.parametrize(
@@ -58,10 +59,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_var
5859
sigma_schedule=NoiseSchedule.KARRAS if use_karras_sigmas else None,
5960
),
6061
)
61-
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
62+
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)
6263

63-
sample = randn(1, 3, 32, 32)
64-
predicted_noise = randn(1, 3, 32, 32)
64+
sample = torch.randn(1, 3, 32, 32)
65+
predicted_noise = torch.randn(1, 3, 32, 32)
6566

6667
manual_seed(37)
6768
diffusers_outputs: list[Tensor] = [
@@ -74,7 +75,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool, sde_var
7475

7576
atol = 1e-4 if use_karras_sigmas else 1e-6
7677
for step, (diffusers_output, refiners_output) in enumerate(zip(diffusers_outputs, refiners_outputs)):
77-
assert allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
78+
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01, atol=atol), f"outputs differ at step {step}"
7879

7980

8081
def test_ddim_diffusers():
@@ -92,16 +93,16 @@ def test_ddim_diffusers():
9293
)
9394
diffusers_scheduler.set_timesteps(30)
9495
solver = DDIM(num_inference_steps=30)
95-
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
96+
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)
9697

97-
sample = randn(1, 4, 32, 32)
98-
predicted_noise = randn(1, 4, 32, 32)
98+
sample = torch.randn(1, 4, 32, 32)
99+
predicted_noise = torch.randn(1, 4, 32, 32)
99100

100101
for step, timestep in enumerate(diffusers_scheduler.timesteps):
101102
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
102103
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
103104

104-
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
105+
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
105106

106107

107108
@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
@@ -122,20 +123,20 @@ def test_euler_diffusers(model_prediction_type: ModelPredictionType):
122123
)
123124
diffusers_scheduler.set_timesteps(30)
124125
solver = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
125-
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
126+
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)
126127

127-
sample = randn(1, 4, 32, 32)
128-
predicted_noise = randn(1, 4, 32, 32)
128+
sample = torch.randn(1, 4, 32, 32)
129+
predicted_noise = torch.randn(1, 4, 32, 32)
129130

130131
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
131132
assert isinstance(ref_init_noise_sigma, Tensor)
132-
assert isclose(ref_init_noise_sigma, solver.init_noise_sigma), "init_noise_sigma differ"
133+
assert torch.isclose(ref_init_noise_sigma, solver.init_noise_sigma), "init_noise_sigma differ"
133134

134135
for step, timestep in enumerate(diffusers_scheduler.timesteps):
135136
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
136137
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
137138

138-
assert allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
139+
assert torch.allclose(diffusers_output, refiners_output, rtol=0.02), f"outputs differ at step {step}"
139140

140141

141142
def test_franken_diffusers():
@@ -157,21 +158,21 @@ def test_franken_diffusers():
157158

158159
diffusers_scheduler_2 = EulerDiscreteScheduler(**params) # type: ignore
159160
solver = FrankenSolver(lambda: diffusers_scheduler_2, num_inference_steps=30)
160-
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
161+
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)
161162

162-
sample = randn(1, 4, 32, 32)
163-
predicted_noise = randn(1, 4, 32, 32)
163+
sample = torch.randn(1, 4, 32, 32)
164+
predicted_noise = torch.randn(1, 4, 32, 32)
164165

165166
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
166167
assert isinstance(ref_init_noise_sigma, Tensor)
167-
init_noise_sigma = solver.scale_model_input(tensor(1), step=-1)
168-
assert equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
168+
init_noise_sigma = solver.scale_model_input(torch.tensor(1), step=-1)
169+
assert torch.equal(ref_init_noise_sigma, init_noise_sigma), "init_noise_sigma differ"
169170

170171
for step, timestep in enumerate(diffusers_scheduler.timesteps):
171172
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
172173
refiners_output = solver(x=sample, predicted_noise=predicted_noise, step=step)
173174

174-
assert equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
175+
assert torch.equal(diffusers_output, refiners_output), f"outputs differ at step {step}"
175176

176177

177178
def test_lcm_diffusers():
@@ -180,16 +181,16 @@ def test_lcm_diffusers():
180181
manual_seed(0)
181182

182183
# LCMScheduler is stochastic, make sure we use identical generators
183-
diffusers_generator = Generator().manual_seed(42)
184-
refiners_generator = Generator().manual_seed(42)
184+
diffusers_generator = torch.Generator().manual_seed(42)
185+
refiners_generator = torch.Generator().manual_seed(42)
185186

186187
diffusers_scheduler = LCMScheduler()
187188
diffusers_scheduler.set_timesteps(4)
188189
solver = LCMSolver(num_inference_steps=4)
189-
assert equal(solver.timesteps, diffusers_scheduler.timesteps)
190+
assert torch.equal(solver.timesteps, diffusers_scheduler.timesteps)
190191

191-
sample = randn(1, 4, 32, 32)
192-
predicted_noise = randn(1, 4, 32, 32)
192+
sample = torch.randn(1, 4, 32, 32)
193+
predicted_noise = torch.randn(1, 4, 32, 32)
193194

194195
for step, timestep in enumerate(diffusers_scheduler.timesteps):
195196
alpha_prod_t = diffusers_scheduler.alphas_cumprod[timestep]
@@ -212,7 +213,7 @@ def test_lcm_diffusers():
212213
generator=refiners_generator,
213214
)
214215

215-
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
216+
assert torch.allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
216217

217218

218219
def test_solver_remove_noise():
@@ -231,14 +232,14 @@ def test_solver_remove_noise():
231232
diffusers_scheduler.set_timesteps(30)
232233
solver = DDIM(num_inference_steps=30)
233234

234-
sample = randn(1, 4, 32, 32)
235-
noise = randn(1, 4, 32, 32)
235+
sample = torch.randn(1, 4, 32, 32)
236+
noise = torch.randn(1, 4, 32, 32)
236237

237238
for step, timestep in enumerate(diffusers_scheduler.timesteps):
238239
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
239240
refiners_output = solver.remove_noise(x=sample, noise=noise, step=step)
240241

241-
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
242+
assert torch.allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
242243

243244

244245
def test_solver_device(test_device: Device):
@@ -247,16 +248,16 @@ def test_solver_device(test_device: Device):
247248
pytest.skip()
248249

249250
scheduler = DDIM(num_inference_steps=30, device=test_device)
250-
x = randn(1, 4, 32, 32, device=test_device)
251-
noise = randn(1, 4, 32, 32, device=test_device)
251+
x = torch.randn(1, 4, 32, 32, device=test_device)
252+
noise = torch.randn(1, 4, 32, 32, device=test_device)
252253
noised = scheduler.add_noise(x, noise, scheduler.first_inference_step)
253254
assert noised.device == test_device
254255

255256

256257
def test_solver_add_noise(test_device: Device):
257258
scheduler = DDIM(num_inference_steps=30, device=test_device)
258-
latent = randn(1, 4, 32, 32, device=test_device)
259-
noise = randn(1, 4, 32, 32, device=test_device)
259+
latent = torch.randn(1, 4, 32, 32, device=test_device)
260+
noise = torch.randn(1, 4, 32, 32, device=test_device)
260261
noised = scheduler.add_noise(
261262
x=latent,
262263
noise=noise,
@@ -267,8 +268,8 @@ def test_solver_add_noise(test_device: Device):
267268
noise=noise.repeat(2, 1, 1, 1),
268269
step=[0, 0],
269270
)
270-
assert allclose(noised, noised_double[0])
271-
assert allclose(noised, noised_double[1])
271+
assert torch.allclose(noised, noised_double[0])
272+
assert torch.allclose(noised, noised_double[1])
272273

273274

274275
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
@@ -291,20 +292,27 @@ def test_solver_timestep_spacing():
291292
num_train_timesteps=1000,
292293
offset=1,
293294
)
294-
assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
295+
assert torch.equal(linspace_int, torch.tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
295296

296297
leading = Solver.generate_timesteps(
297298
spacing=TimestepSpacing.LEADING,
298299
num_inference_steps=10,
299300
num_train_timesteps=1000,
300301
offset=1,
301302
)
302-
assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
303+
assert torch.equal(leading, torch.tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
303304

304305
trailing = Solver.generate_timesteps(
305306
spacing=TimestepSpacing.TRAILING,
306307
num_inference_steps=10,
307308
num_train_timesteps=1000,
308309
offset=1,
309310
)
310-
assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))
311+
assert torch.equal(trailing, torch.tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))
312+
313+
314+
def test_dpm_bfloat16(test_device: Device):
315+
if test_device.type == "cpu":
316+
warn("not running on CPU, skipping")
317+
pytest.skip()
318+
DPMSolver(num_inference_steps=5, dtype=torch.bfloat16) # should not raise

0 commit comments

Comments (0)