This repository was archived by the owner on Sep 26, 2025. It is now read-only.

Commit a5c7422
Author: Laurent

add torch.Generator to MultiUpscaler.upscale + make MultiUpscaler.diffuse_targets "stateless"

1 parent 883a212
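
In short, reproducibility now flows through an explicit torch.Generator instead of the global RNG: upscale no longer calls manual_seed, diffuse_targets receives its noise from the caller, and the noise is sampled once inside upscale via torch.randn(..., generator=generator). A minimal sketch of the difference, outside of refiners (the tensor shape and device below are illustrative assumptions, not values taken from this code):

import torch

# Before this commit: reproducibility relied on seeding torch's *global* RNG,
# which also affects every other random op in the process.
torch.manual_seed(37)
noise_old = torch.randn(1, 4, 64, 64)

# After this commit: the caller supplies a dedicated torch.Generator, so
# sampling the noise neither reads nor mutates global RNG state, and
# diffuse_targets no longer needs to draw randomness internally.
generator = torch.Generator(device="cpu").manual_seed(37)
noise_new = torch.randn(1, 4, 64, 64, generator=generator)

# A freshly seeded generator reproduces the exact same noise tensor.
noise_again = torch.randn(1, 4, 64, 64, generator=torch.Generator(device="cpu").manual_seed(37))
assert torch.equal(noise_new, noise_again)

The updated e2e test below follows the same pattern, seeding a generator with 37 (the old default seed) before calling upscale.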

File tree (2 files changed, +16 -9):
  src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_upscaler.py
  tests/e2e/test_diffusion.py

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_upscaler.py
13 additions, 8 deletions

@@ -8,7 +8,7 @@
 from torch import Tensor
 from typing_extensions import TypeVar

-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion.lora import SDLoraManager
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
@@ -217,13 +217,12 @@ def compute_upscaler_targets(

     def diffuse_targets(
         self,
+        noise: torch.Tensor,
         targets: Sequence[T],
         image: Image.Image,
-        latent_size: Size,
         first_step: int,
         autoencoder_tile_length: int,
     ) -> Image.Image:
-        noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
         with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
             latents = self.sd.lda.tiled_image_to_latents(image)
             x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
@@ -249,7 +248,7 @@ def upscale(
         solver_type: type[Solver] = DPMSolver,
         num_inference_steps: int = 18,
         autoencoder_tile_length: int = 1024,
-        seed: int = 37,
+        generator: torch.Generator | None = None,
     ) -> Image.Image:
         """
         Upscale an image using the multi upscaler.
@@ -280,10 +279,8 @@ def upscale(
                 between quality and speed.
             autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
                 lowering it can reduce GPU memory usage (but increase computation time).
-            seed: The seed to use for the random number generator.
+            generator: The random number generator to use for sampling noise.
         """
-        manual_seed(seed)
-
         # update controlnet scale
         self.controlnet.scale = controlnet_scale
         self.controlnet.scale_decay = controlnet_scale_decay
@@ -323,11 +320,19 @@ def upscale(
             clip_text_embedding=clip_text_embedding,
         )

+        # initialize the noise
+        noise = torch.randn(
+            size=(1, 4, *latent_size),
+            device=self.device,
+            dtype=self.dtype,
+            generator=generator,
+        )
+
         # diffuse the tiles
         return self.diffuse_targets(
+            noise=noise,
             targets=targets,
             image=image,
-            latent_size=latent_size,
             first_step=first_step,
             autoencoder_tile_length=autoencoder_tile_length,
         )

tests/e2e/test_diffusion.py
3 additions, 1 deletion

@@ -2669,7 +2669,9 @@ def test_multi_upscaler(
     clarity_example: Image.Image,
     expected_multi_upscaler: Image.Image,
 ) -> None:
-    predicted_image = multi_upscaler.upscale(clarity_example)
+    generator = torch.Generator(device=multi_upscaler.device)
+    generator.manual_seed(37)
+    predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
     ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)