Commit 7c975f0
Modular backend - add ControlNet (#6642)
## Summary

ControlNet code from #6577.

## Related Issues / Discussions

- #6606
- https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d

## QA Instructions

Run with and without the `USE_MODULAR_DENOISE` environment variable set.

## Merge Plan

Merge #6641 first, to be able to see the output difference properly. If you think there should be some kind of tests, feel free to add them.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
Merged from 2 parents: 7b8e25f + e2e47fd.
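The QA instructions above hinge on the `USE_MODULAR_DENOISE` environment variable. As a minimal, hypothetical sketch of what that toggle amounts to (the function names below are illustrative stand-ins, not InvokeAI's actual API), the flag routes the denoise invocation between the new modular backend and the legacy pipeline:

```python
import os

def choose_backend(modular_invoke, legacy_invoke):
    # Route to the modular denoise path when USE_MODULAR_DENOISE is set;
    # otherwise fall back to the legacy pipeline. Both arguments are
    # illustrative stand-ins for the real invocation methods.
    if os.environ.get("USE_MODULAR_DENOISE"):
        return modular_invoke
    return legacy_invoke

run = choose_backend(lambda: "modular backend", lambda: "legacy pipeline")
print(run())  # result depends on the environment variable
```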

File tree

4 files changed: +218 −19 lines

- invokeai/app/invocations/denoise_latents.py
- invokeai/backend/stable_diffusion/extensions/base.py
- invokeai/backend/stable_diffusion/extensions/controlnet.py
- invokeai/backend/stable_diffusion/extensions_manager.py

invokeai/app/invocations/denoise_latents.py

Lines changed: 57 additions & 16 deletions

```diff
@@ -58,6 +58,7 @@
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
```
```diff
@@ -465,6 +466,38 @@ def prep_control_data(

         return controlnet_data

+    @staticmethod
+    def parse_controlnet_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        control_input: ControlField | list[ControlField] | None,
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        # Normalize control_input to a list.
+        control_list: list[ControlField]
+        if isinstance(control_input, ControlField):
+            control_list = [control_input]
+        elif isinstance(control_input, list):
+            control_list = control_input
+        elif control_input is None:
+            control_list = []
+        else:
+            raise ValueError(f"Unexpected control_input type: {type(control_input)}")
+
+        for control_info in control_list:
+            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            ext_manager.add_extension(
+                ControlNetExt(
+                    model=model,
+                    image=context.images.get_pil(control_info.image.image_name),
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,
+                    resize_mode=control_info.resize_mode,
+                )
+            )
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
```
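The `control_input` handling above is a small reusable pattern: accept a single value, a list, or `None`, and always work with a list afterwards. A standalone sketch of the same idea (the generic names here are mine, not the codebase's):

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")

def normalize_to_list(value: T | list[T] | None) -> list[T]:
    # None -> empty list; single value -> one-element list; list passes through.
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]

assert normalize_to_list(None) == []
assert normalize_to_list(3) == [3]
assert normalize_to_list([1, 2]) == [1, 2]
```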
```diff
@@ -800,22 +833,30 @@ def step_callback(state: PipelineIntermediateState) -> None:
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

-        # ext: t2i/ip adapter
-        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
-
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
-        with (
-            unet_info.model_on_device() as (cached_weights, unet),
-            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-            # ext: controlnet
-            ext_manager.patch_extensions(unet),
-            # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(unet, cached_weights),
-        ):
-            sd_backend = StableDiffusionBackend(unet, scheduler)
-            denoise_ctx.unet = unet
-            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+        # context for loading additional models
+        with ExitStack() as exit_stack:
+            # later should be something like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context, ext_manager)
+            #     ext_manager.add_extension(ext)
+            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (cached_weights, unet),
+                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+                # ext: controlnet
+                ext_manager.patch_extensions(denoise_ctx),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(unet, cached_weights),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
```
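The `with ExitStack() as exit_stack:` wrapper exists so that every model loaded via `exit_stack.enter_context(...)` in `parse_controlnet_field` stays resident for the whole denoise run and is released in one place afterwards. A toy illustration of that lifetime behavior (the loader below is a made-up stand-in for `context.models.load`):

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def load_model(name: str):
    # Stand-in for a model load: acquire on enter, release on exit.
    print(f"load {name}")
    try:
        yield f"<{name} weights>"
    finally:
        print(f"unload {name}")

with ExitStack() as stack:
    models = [stack.enter_context(load_model(n)) for n in ("controlnet-a", "controlnet-b")]
    print("denoising with", models)
# Leaving the with-block unloads both models, in reverse order of loading.
```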

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -52,7 +52,7 @@ def get_callbacks(self):
         return self._callbacks

     @contextmanager
-    def patch_extension(self, context: DenoiseContext):
+    def patch_extension(self, ctx: DenoiseContext):
         yield None

     @contextmanager
```
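The renamed `patch_extension(self, ctx)` default is a no-op context manager: extensions that have nothing to patch inherit a harmless enter/exit, while extensions like `ControlNetExt` below override it to swap state in and restore it afterwards. A minimal sketch of the base-class behavior (class name is mine):

```python
from contextlib import contextmanager

class ExtensionSketch:
    # Mirrors the ExtensionBase default: entering the context does nothing,
    # so subclasses only override patch_extension when they have work to do.
    @contextmanager
    def patch_extension(self, ctx):
        yield None

with ExtensionSketch().patch_extension(ctx=None):
    print("model runs unpatched")
```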
invokeai/backend/stable_diffusion/extensions/controlnet.py

Lines changed: 158 additions & 0 deletions

```diff
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+import math
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+from PIL.Image import Image
+
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
+from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
+from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+    from invokeai.backend.util.hotfixes import ControlNetModel
+
+
+class ControlNetExt(ExtensionBase):
+    def __init__(
+        self,
+        model: ControlNetModel,
+        image: Image,
+        weight: Union[float, List[float]],
+        begin_step_percent: float,
+        end_step_percent: float,
+        control_mode: CONTROLNET_MODE_VALUES,
+        resize_mode: CONTROLNET_RESIZE_VALUES,
+    ):
+        super().__init__()
+        self._model = model
+        self._image = image
+        self._weight = weight
+        self._begin_step_percent = begin_step_percent
+        self._end_step_percent = end_step_percent
+        self._control_mode = control_mode
+        self._resize_mode = resize_mode
+
+        self._image_tensor: Optional[torch.Tensor] = None
+
+    @contextmanager
+    def patch_extension(self, ctx: DenoiseContext):
+        original_processors = self._model.attn_processors
+        try:
+            self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
+
+            yield None
+        finally:
+            self._model.set_attn_processor(original_processors)
+
+    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
+    def resize_image(self, ctx: DenoiseContext):
+        _, _, latent_height, latent_width = ctx.latents.shape
+        image_height = latent_height * LATENT_SCALE_FACTOR
+        image_width = latent_width * LATENT_SCALE_FACTOR
+
+        self._image_tensor = prepare_control_image(
+            image=self._image,
+            do_classifier_free_guidance=False,
+            width=image_width,
+            height=image_height,
+            device=ctx.latents.device,
+            dtype=ctx.latents.dtype,
+            control_mode=self._control_mode,
+            resize_mode=self._resize_mode,
+        )
+
+    @callback(ExtensionCallbackType.PRE_UNET)
+    def pre_unet_step(self, ctx: DenoiseContext):
+        # skip if model not active in current step
+        total_steps = len(ctx.inputs.timesteps)
+        first_step = math.floor(self._begin_step_percent * total_steps)
+        last_step = math.ceil(self._end_step_percent * total_steps)
+        if ctx.step_index < first_step or ctx.step_index > last_step:
+            return
+
+        # convert mode to internal flags
+        soft_injection = self._control_mode in ["more_prompt", "more_control"]
+        cfg_injection = self._control_mode in ["more_control", "unbalanced"]
+
+        # no negative conditioning in cfg_injection mode
+        if cfg_injection:
+            if ctx.conditioning_mode == ConditioningMode.Negative:
+                return
+            down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)
+
+            if ctx.conditioning_mode == ConditioningMode.Both:
+                # add zeros as samples for negative conditioning
+                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
+
+        else:
+            down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)
+
+        if (
+            ctx.unet_kwargs.down_block_additional_residuals is None
+            and ctx.unet_kwargs.mid_block_additional_residual is None
+        ):
+            ctx.unet_kwargs.down_block_additional_residuals = down_samples
+            ctx.unet_kwargs.mid_block_additional_residual = mid_sample
+        else:
+            # add controlnet outputs together if have multiple controlnets
+            ctx.unet_kwargs.down_block_additional_residuals = [
+                samples_prev + samples_curr
+                for samples_prev, samples_curr in zip(
+                    ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
+                )
+            ]
+            ctx.unet_kwargs.mid_block_additional_residual += mid_sample
+
+    def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
+        total_steps = len(ctx.inputs.timesteps)
+
+        model_input = ctx.latent_model_input
+        image_tensor = self._image_tensor
+        if conditioning_mode == ConditioningMode.Both:
+            model_input = torch.cat([model_input] * 2)
+            image_tensor = torch.cat([image_tensor] * 2)
+
+        cn_unet_kwargs = UNetKwargs(
+            sample=model_input,
+            timestep=ctx.timestep,
+            encoder_hidden_states=None,  # set later by conditioning
+            cross_attention_kwargs=dict(  # noqa: C408
+                percent_through=ctx.step_index / total_steps,
+            ),
+        )
+
+        ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
+
+        # get static weight, or weight corresponding to current step
+        weight = self._weight
+        if isinstance(weight, list):
+            weight = weight[ctx.step_index]
+
+        tmp_kwargs = vars(cn_unet_kwargs)
+
+        # Remove kwargs not related to ControlNet unet
+        # ControlNet guidance fields
+        del tmp_kwargs["down_block_additional_residuals"]
+        del tmp_kwargs["mid_block_additional_residual"]
+
+        # T2i Adapter guidance fields
+        del tmp_kwargs["down_intrablock_additional_residuals"]
+
+        # controlnet(s) inference
+        down_samples, mid_sample = self._model(
+            controlnet_cond=image_tensor,
+            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+            return_dict=False,
+            **vars(cn_unet_kwargs),
+        )
+
+        return down_samples, mid_sample
```
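One subtle piece of `pre_unet_step` is the `cfg_injection` branch: when the UNet processes the negative and positive conditioning halves in a single batch but the ControlNet only produced residuals for the positive half, zeros are concatenated in front so the negative half passes through unmodified. A toy check of that shape logic (tensor sizes below are arbitrary):

```python
import torch

# Positive-only residuals from a hypothetical ControlNet call.
down_samples = [torch.ones(1, 4, 8, 8)]
mid_sample = torch.ones(1, 8, 4, 4)

# Prepend zeros so the batch matches the [negative, positive] UNet ordering.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

assert down_samples[0].shape[0] == 2       # batch now holds both halves
assert torch.all(down_samples[0][0] == 0)  # negative half receives no control
```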

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -52,13 +52,13 @@ def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext
             cb.function(ctx)

     @contextmanager
-    def patch_extensions(self, context: DenoiseContext):
+    def patch_extensions(self, ctx: DenoiseContext):
         if self._is_canceled and self._is_canceled():
             raise CanceledException

         with ExitStack() as exit_stack:
             for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_extension(context))
+                exit_stack.enter_context(ext.patch_extension(ctx))

             yield None
```