-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Modular backend - add ControlNet #6642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
RyanJDick
merged 5 commits into
invoke-ai:main
from
StAlKeR7779:stalker7779/modular_controlnet
Jul 23, 2024
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
42356ec
Add ControlNet support to denoise
StAlKeR7779 3cb13d6
Rename as suggested in other PRs
StAlKeR7779 4e8dcb7
Suggested changes
StAlKeR7779 39e804d
Use consistent param names in patch_extension(...) functions: context…
RyanJDick e2e47fd
Merge branch 'main' into stalker-modular_controlnet
RyanJDick File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
158 changes: 158 additions & 0 deletions
158
invokeai/backend/stable_diffusion/extensions/controlnet.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
from __future__ import annotations | ||
|
||
import math | ||
from contextlib import contextmanager | ||
from typing import TYPE_CHECKING, List, Optional, Union | ||
|
||
import torch | ||
from PIL.Image import Image | ||
|
||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR | ||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image | ||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs | ||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode | ||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback | ||
|
||
if TYPE_CHECKING: | ||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
from invokeai.backend.util.hotfixes import ControlNetModel | ||
|
||
|
||
class ControlNetExt(ExtensionBase):
    """Denoise extension that injects the residuals of a single ControlNet model into the UNet call.

    The control image is converted to a tensor once, before the denoise loop
    starts. On every UNet step inside the configured step window the ControlNet
    is run and its down-block / mid-block outputs are stored on
    ``ctx.unet_kwargs`` (summed with any residuals already stored there).
    """

    def __init__(
        self,
        model: ControlNetModel,
        image: Image,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
        control_mode: CONTROLNET_MODE_VALUES,
        resize_mode: CONTROLNET_RESIZE_VALUES,
    ):
        """Store the ControlNet configuration.

        Args:
            model: The ControlNet model to run on each active step.
            image: Control image; turned into a tensor by ``resize_image``.
            weight: Conditioning scale passed to the model. Either one value used
                for every step, or a list indexed by ``ctx.step_index``.
            begin_step_percent: Fraction of the total step count at which the
                ControlNet becomes active.
            end_step_percent: Fraction of the total step count after which the
                ControlNet is no longer applied.
            control_mode: Control mode; compared against "more_prompt",
                "more_control" and "unbalanced" in ``pre_unet_step``.
            resize_mode: Resize mode forwarded to ``prepare_control_image``.
        """
        super().__init__()
        self._model = model
        self._image = image
        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent
        self._control_mode = control_mode
        self._resize_mode = resize_mode

        # Filled in by resize_image() before the denoise loop starts.
        self._image_tensor: Optional[torch.Tensor] = None

    @contextmanager
    def patch_extension(self, ctx: DenoiseContext):
        """Temporarily install the attention processor from ``ctx.inputs`` on the ControlNet model.

        The processors that were set on the model beforehand are restored on
        exit, including when the wrapped code raises.
        """
        original_processors = self._model.attn_processors
        try:
            self._model.set_attn_processor(ctx.inputs.attention_processor_cls())

            yield None
        finally:
            self._model.set_attn_processor(original_processors)

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def resize_image(self, ctx: DenoiseContext) -> None:
        """Prepare the control image tensor at the pixel resolution matching ``ctx.latents``."""
        _, _, latent_height, latent_width = ctx.latents.shape
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        self._image_tensor = prepare_control_image(
            image=self._image,
            # Batch duplication for the "Both" conditioning mode is done in _run().
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=ctx.latents.device,
            dtype=ctx.latents.dtype,
            control_mode=self._control_mode,
            resize_mode=self._resize_mode,
        )

    @callback(ExtensionCallbackType.PRE_UNET)
    def pre_unet_step(self, ctx: DenoiseContext) -> None:
        """Run the ControlNet for the current step and store its residuals on ``ctx.unet_kwargs``.

        Does nothing when the current step is outside the configured step
        window, or when the control mode uses cfg injection and the current
        UNet call is for the negative conditioning only.
        """
        # skip if model not active in current step
        total_steps = len(ctx.inputs.timesteps)
        first_step = math.floor(self._begin_step_percent * total_steps)
        last_step = math.ceil(self._end_step_percent * total_steps)
        if ctx.step_index < first_step or ctx.step_index > last_step:
            return

        # convert mode to internal flags
        soft_injection = self._control_mode in ["more_prompt", "more_control"]
        cfg_injection = self._control_mode in ["more_control", "unbalanced"]

        # no negative conditioning in cfg_injection mode
        if cfg_injection:
            if ctx.conditioning_mode == ConditioningMode.Negative:
                return
            down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)

            if ctx.conditioning_mode == ConditioningMode.Both:
                # add zeros as samples for negative conditioning
                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

        else:
            down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)

        if (
            ctx.unet_kwargs.down_block_additional_residuals is None
            and ctx.unet_kwargs.mid_block_additional_residual is None
        ):
            ctx.unet_kwargs.down_block_additional_residuals = down_samples
            ctx.unet_kwargs.mid_block_additional_residual = mid_sample
        else:
            # add controlnet outputs together if have multiple controlnets
            ctx.unet_kwargs.down_block_additional_residuals = [
                samples_prev + samples_curr
                for samples_prev, samples_curr in zip(
                    ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
                )
            ]
            # Out-of-place add (like the down-block residuals above) so the tensor
            # stored by a previous extension is not modified in place.
            ctx.unet_kwargs.mid_block_additional_residual = (
                ctx.unet_kwargs.mid_block_additional_residual + mid_sample
            )

    def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
        """Run the ControlNet model once and return ``(down_samples, mid_sample)``.

        Args:
            ctx: Current denoise context (latent model input, timestep, step index,
                conditioning data).
            soft_injection: Forwarded to the model as ``guess_mode``.
            conditioning_mode: Which conditioning to apply. For
                ``ConditioningMode.Both`` the model input and the control image
                are duplicated along the first dimension.

        Raises:
            RuntimeError: If the control image tensor has not been prepared yet.
        """
        if self._image_tensor is None:
            raise RuntimeError("ControlNet image tensor is not prepared; resize_image() must run first.")

        total_steps = len(ctx.inputs.timesteps)

        model_input = ctx.latent_model_input
        image_tensor = self._image_tensor
        if conditioning_mode == ConditioningMode.Both:
            model_input = torch.cat([model_input] * 2)
            image_tensor = torch.cat([image_tensor] * 2)

        cn_unet_kwargs = UNetKwargs(
            sample=model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / total_steps,
            ),
        )

        ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)

        # get static weight, or weight corresponding to current step
        weight = self._weight
        if isinstance(weight, list):
            weight = weight[ctx.step_index]

        # Remove kwargs not related to ControlNet unet. Build a filtered copy
        # instead of deleting attributes from the UNetKwargs instance.
        excluded_kwargs = (
            # ControlNet guidance fields
            "down_block_additional_residuals",
            "mid_block_additional_residual",
            # T2i Adapter guidance fields
            "down_intrablock_additional_residuals",
        )
        model_kwargs = {name: value for name, value in vars(cn_unet_kwargs).items() if name not in excluded_kwargs}

        # controlnet(s) inference
        down_samples, mid_sample = self._model(
            controlnet_cond=image_tensor,
            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
            return_dict=False,
            **model_kwargs,
        )

        return down_samples, mid_sample
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.