Commit 473f4cc

Base of modular backend (#6606)
## Summary

Base code of the new modular backend from #6577. Contains normal generation and regional prompt support. A preview extension is also included, to verify that the extension logic works.

## Related Issues / Discussions

https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set. Currently only normal and regional conditioning are supported, so just generate some images and compare them with `main`'s output.

## Merge Plan

Discuss the injection point names a bit more? If, for example, the unet call becomes overridable in the future, the current `pre_unet`/`post_unet` names imply that the override itself would be named `unet`, which feels a bit odd. Similarly for `apply_cfg`: a future implementation could ignore or not use CFG, in which case `combine_noise_predictions`/`combine_noise` seems more suitable. (A sketch of how such callbacks might look appears just below, after the checklist.)

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2 parents e16faa6 + 78d2b1b commit 473f4cc
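For illustration, here is a minimal sketch of what callbacks at these injection points could lookike under stated assumptions: only `ExtensionCallbackType.SETUP` appears verbatim in this diff, so the `pre_unet`/`apply_cfg` hook points below are inferred from the callback names referenced in `DenoiseContext`'s field comments, and the registration mechanism is elided.

```python
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


def on_pre_unet(ctx: DenoiseContext) -> None:
    # Hypothetical `pre_unet` callback: runs just before the UNet forward pass,
    # when `ctx.unet_kwargs` is populated and can still be modified.
    assert ctx.unet_kwargs is not None


def on_apply_cfg(ctx: DenoiseContext) -> None:
    # Hypothetical `apply_cfg` callback: combines the per-conditioning noise
    # predictions. A future implementation might skip CFG entirely, which is
    # why `combine_noise_predictions` may be the better name.
    guidance_scale = 7.5  # illustrative fixed scale, not part of the API
    ctx.noise_pred = ctx.negative_noise_pred + guidance_scale * (
        ctx.positive_noise_pred - ctx.negative_noise_pred
    )
```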

File tree

13 files changed: +908 -25 lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 115 additions & 7 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import inspect
+import os
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
@@ -39,6 +40,7 @@
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
     StableDiffusionGeneratorPipeline,
@@ -53,6 +55,11 @@
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
 from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +321,10 @@ def get_conditioning_data(
         context: InvocationContext,
         positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
         negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-        unet: UNet2DConditionModel,
         latent_height: int,
         latent_width: int,
+        device: torch.device,
+        dtype: torch.dtype,
         cfg_scale: float | list[float],
         steps: int,
         cfg_rescale_multiplier: float,
@@ -330,25 +338,25 @@ def get_conditioning_data(
             uncond_list = [uncond_list]
 
         cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            cond_list, context, unet.device, unet.dtype
+            cond_list, context, device, dtype
         )
         uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            uncond_list, context, unet.device, unet.dtype
+            uncond_list, context, device, dtype
         )
 
         cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=cond_text_embeddings,
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
         uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
 
         if isinstance(cfg_scale, list):
@@ -707,9 +715,108 @@ def prepare_noise_and_latents(
 
         return seed, noise, latents
 
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        if os.environ.get("USE_MODULAR_DENOISE", False):
+            return self._new_invoke(context)
+        else:
+            return self._old_invoke(context)
+
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
+    def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
+        ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
+
+        device = TorchDevice.choose_torch_device()
+        dtype = TorchDevice.choose_torch_dtype()
+
+        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+        latents = latents.to(device=device, dtype=dtype)
+        if noise is not None:
+            noise = noise.to(device=device, dtype=dtype)
+
+        _, _, latent_height, latent_width = latents.shape
+
+        conditioning_data = self.get_conditioning_data(
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            device=device,
+            dtype=dtype,
+            # TODO: old backend, remove
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+        )
+
+        scheduler = get_scheduler(
+            context=context,
+            scheduler_info=self.unet.scheduler,
+            scheduler_name=self.scheduler,
+            seed=seed,
+        )
+
+        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+            scheduler,
+            seed=seed,
+            device=device,
+            steps=self.steps,
+            denoising_start=self.denoising_start,
+            denoising_end=self.denoising_end,
+        )
+
+        denoise_ctx = DenoiseContext(
+            inputs=DenoiseInputs(
+                orig_latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
+            ),
+            unet=None,
+            scheduler=scheduler,
+        )
+
+        # get the unet's config so that we can pass the base to sd_step_callback()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        ### preview
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        ext_manager.add_extension(PreviewExt(step_callback))
+
+        # ext: t2i/ip adapter
+        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+            # ext: controlnet
+            ext_manager.patch_extensions(unet),
+            # ext: freeu, seamless, ip adapter, lora
+            ext_manager.patch_unet(model_state_dict, unet),
+        ):
+            sd_backend = StableDiffusionBackend(unet, scheduler)
+            denoise_ctx.unet = unet
+            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.detach().to("cpu")
+        TorchDevice.empty_cache()
+
+        name = context.tensors.save(tensor=result_latents)
+        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
+
+    @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
+    def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
         seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
 
         mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +895,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             context=context,
             positive_conditioning_field=self.positive_conditioning,
             negative_conditioning_field=self.negative_conditioning,
-            unet=unet,
+            device=unet.device,
+            dtype=unet.dtype,
             latent_height=latent_height,
             latent_width=latent_width,
             cfg_scale=self.cfg_scale,
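The modular path above is opt-in via `USE_MODULAR_DENOISE`. Note that `invoke()` only checks the truthiness of the raw environment string, so any non-empty value (even `"0"` or `"false"`) enables it. A minimal way to enable it from Python before invocation:

```python
import os

# Any non-empty string enables the modular backend, because `invoke()` checks
# `os.environ.get("USE_MODULAR_DENOISE", False)` for truthiness only.
os.environ["USE_MODULAR_DENOISE"] = "1"
```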

invokeai/backend/model_patcher.py

Lines changed: 21 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -32,8 +32,27 @@
 """
 
 
-# TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+    @staticmethod
+    @contextmanager
+    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
+        """A context manager that patches `unet` with the provided attention processor.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch.
+            processor_cls (Type[Any]): Class which will be instantiated for each key and passed to set_attn_processor(...).
+        """
+        unet_orig_processors = unet.attn_processors
+
+        # Create a separate instance for each attention block, so that each one can be modified independently.
+        unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+        try:
+            unet.set_attn_processor(unet_new_processors)
+            yield None
+
+        finally:
+            unet.set_attn_processor(unet_orig_processors)
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
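A minimal usage sketch of the new context manager, assuming an already-loaded `UNet2DConditionModel` (the assertions are illustrative, not part of the API):

```python
from diffusers import UNet2DConditionModel

from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0


def demo_patch(unet: UNet2DConditionModel) -> None:
    orig_processors = unet.attn_processors  # snapshot of the original processors
    with ModelPatcher.patch_unet_attention_processor(unet, CustomAttnProcessor2_0):
        # Inside the context, every attention block has its own fresh
        # CustomAttnProcessor2_0 instance.
        assert all(isinstance(p, CustomAttnProcessor2_0) for p in unet.attn_processors.values())
    # On exit, the original processors are restored.
    assert unet.attn_processors.keys() == orig_processors.keys()
```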
invokeai/backend/stable_diffusion/denoise_context.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
+
+import torch
+from diffusers import UNet2DConditionModel
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
+
+
+@dataclass
+class UNetKwargs:
+    sample: torch.Tensor
+    timestep: Union[torch.Tensor, float, int]
+    encoder_hidden_states: torch.Tensor
+
+    class_labels: Optional[torch.Tensor] = None
+    timestep_cond: Optional[torch.Tensor] = None
+    attention_mask: Optional[torch.Tensor] = None
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None
+    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
+    mid_block_additional_residual: Optional[torch.Tensor] = None
+    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
+    encoder_attention_mask: Optional[torch.Tensor] = None
+    # return_dict: bool = True
+
+
+@dataclass
+class DenoiseInputs:
+    """Initial variables passed to denoise. Supposed to be unchanged."""
+
+    # The latent-space image to denoise.
+    # Shape: [batch, channels, latent_height, latent_width]
+    # - If we are inpainting, this is the initial latent image before noise has been added.
+    # - If we are generating a new image, this should be initialized to zeros.
+    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
+    orig_latents: torch.Tensor
+
+    # kwargs forwarded to the scheduler.step() method.
+    scheduler_step_kwargs: dict[str, Any]
+
+    # Text conditioning data.
+    conditioning_data: TextConditioningData
+
+    # Noise used for two purposes:
+    # 1. Used by the scheduler to noise the initial `latents` before denoising.
+    # 2. Used to noise the `masked_latents` when inpainting.
+    # `noise` should be None if the `latents` tensor has already been noised.
+    # Shape: [1 or batch, channels, latent_height, latent_width]
+    noise: Optional[torch.Tensor]
+
+    # The seed used to generate the noise for the denoising process.
+    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
+    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
+    seed: int
+
+    # The timestep schedule for the denoising process.
+    timesteps: torch.Tensor
+
+    # The first timestep in the schedule. This is used to determine the initial noise level, so
+    # should be populated if you want noise applied *even* if timesteps is empty.
+    init_timestep: torch.Tensor
+
+    # Class of attention processor that is used.
+    attention_processor_cls: Type[Any]
+
+
+@dataclass
+class DenoiseContext:
+    """Context with all variables in denoise"""
+
+    # Initial variables passed to denoise. Supposed to be unchanged.
+    inputs: DenoiseInputs
+
+    # Scheduler used to apply noise predictions.
+    scheduler: SchedulerMixin
+
+    # UNet model.
+    unet: Optional[UNet2DConditionModel] = None
+
+    # Current state of the latent-space image in the denoising process.
+    # None until the `pre_denoise_loop` callback.
+    # Shape: [batch, channels, latent_height, latent_width]
+    latents: Optional[torch.Tensor] = None
+
+    # Current denoising step index.
+    # None until the `pre_step` callback.
+    step_index: Optional[int] = None
+
+    # Current denoising step timestep.
+    # None until the `pre_step` callback.
+    timestep: Optional[torch.Tensor] = None
+
+    # Arguments which will be passed to the UNet model.
+    # Available in the `pre_unet`/`post_unet` callbacks; otherwise None.
+    unet_kwargs: Optional[UNetKwargs] = None
+
+    # SchedulerOutput returned from the step function (normally, generated by the scheduler).
+    # Supposed to be used only in the `post_step` callback; otherwise can be None.
+    step_output: Optional[SchedulerOutput] = None
+
+    # Scaled version of `latents`, which will be passed to unet_kwargs initialization.
+    # Available in events inside step (between `pre_step` and `post_step`).
+    # Shape: [batch, channels, latent_height, latent_width]
+    latent_model_input: Optional[torch.Tensor] = None
+
+    # [TMP] Defines which conditionings the current unet call will be run on.
+    # Available in the `pre_unet`/`post_unet` callbacks; otherwise None.
+    conditioning_mode: Optional[ConditioningMode] = None
+
+    # [TMP] Noise predictions from negative conditioning.
+    # Available in the `apply_cfg` and `post_apply_cfg` callbacks; otherwise None.
+    # Shape: [batch, channels, latent_height, latent_width]
+    negative_noise_pred: Optional[torch.Tensor] = None
+
+    # [TMP] Noise predictions from positive conditioning.
+    # Available in the `apply_cfg` and `post_apply_cfg` callbacks; otherwise None.
+    # Shape: [batch, channels, latent_height, latent_width]
+    positive_noise_pred: Optional[torch.Tensor] = None
+
+    # Combined noise prediction from passed conditionings.
+    # Available in the `apply_cfg` and `post_apply_cfg` callbacks; otherwise None.
+    # Shape: [batch, channels, latent_height, latent_width]
+    noise_pred: Optional[torch.Tensor] = None
+
+    # Dictionary for extensions to pass extra info about the denoise process to other extensions.
+    extra: dict = field(default_factory=dict)
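As a small illustration of the `extra` field, two hypothetical cooperating callbacks could pass state to each other within a step (the extension and key names here are invented for the example):

```python
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


def producer_callback(ctx: DenoiseContext) -> None:
    # Hypothetical earlier callback: stash a value under a namespaced key to
    # avoid collisions with other extensions.
    ctx.extra["my_ext/last_timestep"] = ctx.timestep


def consumer_callback(ctx: DenoiseContext) -> None:
    # Hypothetical later callback in another extension: read the value back.
    last_timestep = ctx.extra.get("my_ext/last_timestep")
    if last_timestep is not None:
        ...  # act on the shared value
```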

invokeai/backend/stable_diffusion/diffusers_pipeline.py

Lines changed: 1 addition & 10 deletions
@@ -23,21 +23,12 @@
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
+from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
 from invokeai.backend.util.attention import auto_detect_slice_size
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.hotfixes import ControlNetModel
 
 
-@dataclass
-class PipelineIntermediateState:
-    step: int
-    order: int
-    total_steps: int
-    timestep: int
-    latents: torch.Tensor
-    predicted_original: Optional[torch.Tensor] = None
-
-
 @dataclass
 class AddsMaskGuidance:
     mask: torch.Tensor
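`PipelineIntermediateState` now lives in `invokeai.backend.stable_diffusion.extensions.preview` and is re-imported here, and its fields are unchanged, so existing preview callbacks keep working. A minimal consumer sketch, based on the relocated dataclass shown above:

```python
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState


def step_callback(state: PipelineIntermediateState) -> None:
    # Fields match the relocated dataclass: step, order, total_steps, timestep,
    # latents, predicted_original.
    print(f"step {state.step + 1}/{state.total_steps} (timestep {state.timestep})")
```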
