Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit 883a212

Browse files
committed
fix precision of DPM solver in bfloat16
1 parent ed7e2e5 commit 883a212

File tree

1 file changed

+2
-4
lines changed
  • src/refiners/foundationals/latent_diffusion/solvers

1 file changed

+2
-4
lines changed

src/refiners/foundationals/latent_diffusion/solvers/dpm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
first_inference_step=first_inference_step,
7878
params=params,
7979
device=device,
80-
dtype=dtype,
80+
dtype=torch.float64, # compute constants precisely
8181
)
8282
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
8383
self.last_step_first_order = last_step_first_order
@@ -89,6 +89,7 @@ def __init__(
8989
self.sigmas
9090
)
9191
self.timesteps = self._timesteps_from_sigmas(sigmas)
92+
self.to(dtype=dtype)
9293

9394
def rebuild(
9495
self: "DPMSolver",
@@ -131,12 +132,9 @@ def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule |
131132
case NoiseSchedule.KARRAS:
132133
rho = 7
133134
case None:
134-
if sigmas.dtype == torch.bfloat16:
135-
sigmas = sigmas.to(torch.float32)
136135
return torch.tensor(
137136
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
138137
device=self.device,
139-
dtype=self.dtype,
140138
)
141139

142140
linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)

0 commit comments

Comments
 (0)