 from torch import Tensor
 from typing_extensions import TypeVar
 
-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion.lora import SDLoraManager
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
@@ -217,13 +217,12 @@ def compute_upscaler_targets(
 
     def diffuse_targets(
         self,
+        noise: torch.Tensor,
         targets: Sequence[T],
         image: Image.Image,
-        latent_size: Size,
         first_step: int,
         autoencoder_tile_length: int,
    ) -> Image.Image:
-        noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
         with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
             latents = self.sd.lda.tiled_image_to_latents(image)
             x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
@@ -249,7 +248,7 @@ def upscale(
         solver_type: type[Solver] = DPMSolver,
         num_inference_steps: int = 18,
         autoencoder_tile_length: int = 1024,
-        seed: int = 37,
+        generator: torch.Generator | None = None,
    ) -> Image.Image:
         """
         Upscale an image using the multi upscaler.
@@ -280,10 +279,8 @@ def upscale(
                 between quality and speed.
             autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
                 lowering it can reduce GPU memory usage (but increase computation time).
-            seed: The seed to use for the random number generator.
+            generator: The random number generator to use for sampling noise.
         """
-        manual_seed(seed)
-
         # update controlnet scale
         self.controlnet.scale = controlnet_scale
         self.controlnet.scale_decay = controlnet_scale_decay
@@ -323,11 +320,19 @@ def upscale(
             clip_text_embedding=clip_text_embedding,
         )
 
+        # initialize the noise
+        noise = torch.randn(
+            size=(1, 4, *latent_size),
+            device=self.device,
+            dtype=self.dtype,
+            generator=generator,
+        )
+
         # diffuse the tiles
         return self.diffuse_targets(
+            noise=noise,
             targets=targets,
             image=image,
-            latent_size=latent_size,
             first_step=first_step,
             autoencoder_tile_length=autoencoder_tile_length,
         )
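
After this change, callers control reproducibility by passing an explicit `torch.Generator` to `upscale` rather than relying on the removed `seed` argument and the global `manual_seed` call. A minimal, self-contained sketch of the new pattern (the `latent_size` value below is a placeholder; in the diff it is derived inside `upscale`):

```python
import torch

# The caller owns the RNG state and passes it down; torch.randn no longer touches the global RNG.
generator = torch.Generator().manual_seed(37)

latent_size = (64, 64)  # placeholder shape for illustration
noise = torch.randn(size=(1, 4, *latent_size), generator=generator)

# Re-seeding an identically configured generator reproduces the same noise tensor,
# which is what a call like `upscaler.upscale(image, generator=generator)` relies on.
noise_again = torch.randn(size=(1, 4, *latent_size), generator=torch.Generator().manual_seed(37))
assert torch.equal(noise, noise_again)
```

Note: when the upscaler runs on CUDA, the generator should be created on the same device (e.g. `torch.Generator(device="cuda")`), since `torch.randn` expects the generator's device to match the output device.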