Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import Tensor
from typing_extensions import TypeVar

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
Expand Down Expand Up @@ -217,13 +217,12 @@ def compute_upscaler_targets(

def diffuse_targets(
self,
noise: torch.Tensor,
targets: Sequence[T],
image: Image.Image,
latent_size: Size,
first_step: int,
autoencoder_tile_length: int,
) -> Image.Image:
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
latents = self.sd.lda.tiled_image_to_latents(image)
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
Expand All @@ -249,7 +248,7 @@ def upscale(
solver_type: type[Solver] = DPMSolver,
num_inference_steps: int = 18,
autoencoder_tile_length: int = 1024,
seed: int = 37,
generator: torch.Generator | None = None,
) -> Image.Image:
"""
Upscale an image using the multi upscaler.
Expand Down Expand Up @@ -280,10 +279,8 @@ def upscale(
between quality and speed.
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
lowering it can reduce GPU memory usage (but increase computation time).
seed: The seed to use for the random number generator.
generator: The random number generator to use for sampling noise.
"""
manual_seed(seed)

# update controlnet scale
self.controlnet.scale = controlnet_scale
self.controlnet.scale_decay = controlnet_scale_decay
Expand Down Expand Up @@ -323,11 +320,19 @@ def upscale(
clip_text_embedding=clip_text_embedding,
)

# initialize the noise
noise = torch.randn(
size=(1, 4, *latent_size),
device=self.device,
dtype=self.dtype,
generator=generator,
)

# diffuse the tiles
return self.diffuse_targets(
noise=noise,
targets=targets,
image=image,
latent_size=latent_size,
first_step=first_step,
autoencoder_tile_length=autoencoder_tile_length,
)
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2669,7 +2669,9 @@ def test_multi_upscaler(
clarity_example: Image.Image,
expected_multi_upscaler: Image.Image,
) -> None:
predicted_image = multi_upscaler.upscale(clarity_example)
generator = torch.Generator(device=multi_upscaler.device)
generator.manual_seed(37)
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)


Expand Down