# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
+ import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
...
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+ from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
    ControlNetData,
    StableDiffusionGeneratorPipeline,
    ...
    TextConditioningData,
    TextConditioningRegions,
)
+ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
+ from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+ from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+ from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
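(Note: `custom_atttention`, with three t's, appears to be the module's actual spelling in the InvokeAI source tree, so the import above is reproduced as-is rather than "corrected".)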
@@ -314,9 +321,10 @@ def get_conditioning_data(
        context: InvocationContext,
        positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
        negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-         unet: UNet2DConditionModel,
        latent_height: int,
        latent_width: int,
+         device: torch.device,
+         dtype: torch.dtype,
        cfg_scale: float | list[float],
        steps: int,
        cfg_rescale_multiplier: float,
@@ -330,25 +338,25 @@ def get_conditioning_data(
            uncond_list = [uncond_list]

        cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-             cond_list, context, unet.device, unet.dtype
+             cond_list, context, device, dtype
        )
        uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-             uncond_list, context, unet.device, unet.dtype
+             uncond_list, context, device, dtype
        )

        cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
            text_conditionings=cond_text_embeddings,
            masks=cond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
-             dtype=unet.dtype,
+             dtype=dtype,
        )
        uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
            text_conditionings=uncond_text_embeddings,
            masks=uncond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
-             dtype=unet.dtype,
+             dtype=dtype,
        )

        if isinstance(cfg_scale, list):
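The two hunks above decouple `get_conditioning_data` from the UNet handle: embeddings are built on an explicitly supplied device/dtype pair, so conditioning can be prepared before the UNet is loaded onto the execution device. A minimal sketch of the new call pattern, using the `TorchDevice` helpers imported above (the `self.*` fields are illustrative, taken from the invocation shown later in this diff):

    from invokeai.backend.util.devices import TorchDevice

    # Choose device/dtype up front instead of reading them off a loaded UNet.
    device = TorchDevice.choose_torch_device()
    dtype = TorchDevice.choose_torch_dtype()

    conditioning_data = self.get_conditioning_data(
        context=context,
        positive_conditioning_field=self.positive_conditioning,
        negative_conditioning_field=self.negative_conditioning,
        latent_height=latent_height,
        latent_width=latent_width,
        device=device,
        dtype=dtype,
        cfg_scale=self.cfg_scale,
        steps=self.steps,
        cfg_rescale_multiplier=self.cfg_rescale_multiplier,
    )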
@@ -707,9 +715,108 @@ def prepare_noise_and_latents(

        return seed, noise, latents

+     def invoke(self, context: InvocationContext) -> LatentsOutput:
+         if os.environ.get("USE_MODULAR_DENOISE", False):
+             return self._new_invoke(context)
+         else:
+             return self._old_invoke(context)
+
    @torch.no_grad()
    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
-     def invoke(self, context: InvocationContext) -> LatentsOutput:
+     def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
+         ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
+
+         device = TorchDevice.choose_torch_device()
+         dtype = TorchDevice.choose_torch_dtype()
+
+         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+         latents = latents.to(device=device, dtype=dtype)
+         if noise is not None:
+             noise = noise.to(device=device, dtype=dtype)
+
+         _, _, latent_height, latent_width = latents.shape
+
+         conditioning_data = self.get_conditioning_data(
+             context=context,
+             positive_conditioning_field=self.positive_conditioning,
+             negative_conditioning_field=self.negative_conditioning,
+             cfg_scale=self.cfg_scale,
+             steps=self.steps,
+             latent_height=latent_height,
+             latent_width=latent_width,
+             device=device,
+             dtype=dtype,
+             # TODO: old backend, remove
+             cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+         )
+
+         scheduler = get_scheduler(
+             context=context,
+             scheduler_info=self.unet.scheduler,
+             scheduler_name=self.scheduler,
+             seed=seed,
+         )
+
+         timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+             scheduler,
+             seed=seed,
+             device=device,
+             steps=self.steps,
+             denoising_start=self.denoising_start,
+             denoising_end=self.denoising_end,
+         )
+
+         denoise_ctx = DenoiseContext(
+             inputs=DenoiseInputs(
+                 orig_latents=latents,
+                 timesteps=timesteps,
+                 init_timestep=init_timestep,
+                 noise=noise,
+                 seed=seed,
+                 scheduler_step_kwargs=scheduler_step_kwargs,
+                 conditioning_data=conditioning_data,
+                 attention_processor_cls=CustomAttnProcessor2_0,
+             ),
+             unet=None,
+             scheduler=scheduler,
+         )
+
+         # get the unet's config so that we can pass the base to sd_step_callback()
+         unet_config = context.models.get_config(self.unet.unet.key)
+
+         ### preview
+         def step_callback(state: PipelineIntermediateState) -> None:
+             context.util.sd_step_callback(state, unet_config.base)
+
+         ext_manager.add_extension(PreviewExt(step_callback))
+
+         # ext: t2i/ip adapter
+         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+         unet_info = context.models.load(self.unet.unet)
+         assert isinstance(unet_info.model, UNet2DConditionModel)
+         with (
+             unet_info.model_on_device() as (model_state_dict, unet),
+             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+             # ext: controlnet
+             ext_manager.patch_extensions(unet),
+             # ext: freeu, seamless, ip adapter, lora
+             ext_manager.patch_unet(model_state_dict, unet),
+         ):
+             sd_backend = StableDiffusionBackend(unet, scheduler)
+             denoise_ctx.unet = unet
+             result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
+         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+         result_latents = result_latents.detach().to("cpu")
+         TorchDevice.empty_cache()
+
+         name = context.tensors.save(tensor=result_latents)
+         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
+
+     @torch.no_grad()
+     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
+     def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
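A caveat on the dispatch at the top of this hunk: `os.environ.get("USE_MODULAR_DENOISE", False)` returns the variable's raw string value whenever it is set, so any non-empty value (including "0" or "false") routes through `_new_invoke`. If stricter parsing were wanted, a sketch of one option (not what this diff does):

    import os

    def use_modular_denoise() -> bool:
        # Only explicit opt-in values enable the new backend; "0", "false",
        # and an unset variable all fall through to the old path.
        return os.environ.get("USE_MODULAR_DENOISE", "").strip().lower() in {"1", "true", "yes"}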
@@ -788,7 +895,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            context=context,
            positive_conditioning_field=self.positive_conditioning,
            negative_conditioning_field=self.negative_conditioning,
-             unet=unet,
+             device=unet.device,
+             dtype=unet.dtype,
            latent_height=latent_height,
            latent_width=latent_width,
            cfg_scale=self.cfg_scale,
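For reference, the extension hooks used in `_new_invoke` work by registering objects with the `ExtensionsManager` and firing typed callbacks such as `ExtensionCallbackType.SETUP` against the `DenoiseContext`. Below is a hypothetical sketch of a custom extension; the `ExtensionBase` class and `callback` decorator are assumed from the extensions framework this diff imports and may not match the real API exactly:

    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback  # assumed module path

    class LogTimestepsExt(ExtensionBase):
        """Hypothetical extension: logs the timestep schedule when SETUP fires."""

        @callback(ExtensionCallbackType.SETUP)
        def log_timesteps(self, ctx: DenoiseContext) -> None:
            # DenoiseInputs carries the timesteps computed by init_scheduler().
            print(f"denoising over {len(ctx.inputs.timesteps)} timesteps")

    # Registered the same way PreviewExt is registered in the hunk above:
    # ext_manager.add_extension(LogTimestepsExt())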