Skip to content

Commit de39c5e

Browse files
authored
Modular backend - add rescale cfg (#6640)
## Summary Rescale CFG code from #6577. ## Related Issues / Discussions #6606 https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d ## QA Instructions Run with and without set `USE_MODULAR_DENOISE` environment. ~~Note: for some reasons there slightly different output from run to run, but I able sometimes to get same output on main and this branch.~~ Fix presented in #6641. ## Merge Plan ~~Nope.~~ Merge #6641 firstly, to be able see output difference properly. If you think that there should be some kind of tests - feel free to add. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
2 parents 154e8f6 + d014dc9 commit de39c5e

File tree

5 files changed

+56
-15
lines changed

5 files changed

+56
-15
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
6060
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
6161
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
62+
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
6263
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6364
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
6465
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -790,6 +791,10 @@ def step_callback(state: PipelineIntermediateState) -> None:
790791

791792
ext_manager.add_extension(PreviewExt(step_callback))
792793

794+
### cfg rescale
795+
if self.cfg_rescale_multiplier > 0:
796+
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
797+
793798
# ext: t2i/ip adapter
794799
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
795800

invokeai/backend/stable_diffusion/denoise_context.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,47 +83,47 @@ class DenoiseContext:
8383
unet: Optional[UNet2DConditionModel] = None
8484

8585
# Current state of latent-space image in denoising process.
86-
# None until `pre_denoise_loop` callback.
86+
# None until `PRE_DENOISE_LOOP` callback.
8787
# Shape: [batch, channels, latent_height, latent_width]
8888
latents: Optional[torch.Tensor] = None
8989

9090
# Current denoising step index.
91-
# None until `pre_step` callback.
91+
# None until `PRE_STEP` callback.
9292
step_index: Optional[int] = None
9393

9494
# Current denoising step timestep.
95-
# None until `pre_step` callback.
95+
# None until `PRE_STEP` callback.
9696
timestep: Optional[torch.Tensor] = None
9797

9898
# Arguments which will be passed to UNet model.
99-
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
99+
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
100100
unet_kwargs: Optional[UNetKwargs] = None
101101

102102
# SchedulerOutput class returned from step function(normally, generated by scheduler).
103-
# Supposed to be used only in `post_step` callback, otherwise can be None.
103+
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
104104
step_output: Optional[SchedulerOutput] = None
105105

106106
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
107-
# Available in events inside step(between `pre_step` and `post_stop`).
107+
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
108108
# Shape: [batch, channels, latent_height, latent_width]
109109
latent_model_input: Optional[torch.Tensor] = None
110110

111111
# [TMP] Defines on which conditionings current unet call will be runned.
112-
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
112+
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
113113
conditioning_mode: Optional[ConditioningMode] = None
114114

115115
# [TMP] Noise predictions from negative conditioning.
116-
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
116+
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
117117
# Shape: [batch, channels, latent_height, latent_width]
118118
negative_noise_pred: Optional[torch.Tensor] = None
119119

120120
# [TMP] Noise predictions from positive conditioning.
121-
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
121+
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
122122
# Shape: [batch, channels, latent_height, latent_width]
123123
positive_noise_pred: Optional[torch.Tensor] = None
124124

125125
# Combined noise prediction from passed conditionings.
126-
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
126+
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
127127
# Shape: [batch, channels, latent_height, latent_width]
128128
noise_pred: Optional[torch.Tensor] = None
129129

invokeai/backend/stable_diffusion/diffusion_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
7676
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
7777
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
7878

79-
# ext: override apply_cfg
80-
ctx.noise_pred = self.apply_cfg(ctx)
79+
# ext: override combine_noise_preds
80+
ctx.noise_pred = self.combine_noise_preds(ctx)
8181

8282
# ext: cfg_rescale [modify_noise_prediction]
8383
# TODO: rename
84-
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
84+
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
8585

8686
# compute the previous noisy sample x_t -> x_t-1
8787
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
@@ -95,7 +95,7 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
9595
return step_output
9696

9797
@staticmethod
98-
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
98+
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
9999
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
100100
if isinstance(guidance_scale, list):
101101
guidance_scale = guidance_scale[ctx.step_index]

invokeai/backend/stable_diffusion/extension_callback_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
99
POST_STEP = "post_step"
1010
PRE_UNET = "pre_unet"
1111
POST_UNET = "post_unet"
12-
POST_APPLY_CFG = "post_apply_cfg"
12+
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import torch
6+
7+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
8+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
9+
10+
if TYPE_CHECKING:
11+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
12+
13+
14+
class RescaleCFGExt(ExtensionBase):
15+
def __init__(self, rescale_multiplier: float):
16+
super().__init__()
17+
self._rescale_multiplier = rescale_multiplier
18+
19+
@staticmethod
20+
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
21+
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
22+
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
23+
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
24+
25+
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
26+
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
27+
return x_final
28+
29+
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
30+
def rescale_noise_pred(self, ctx: DenoiseContext):
31+
if self._rescale_multiplier > 0:
32+
ctx.noise_pred = self._rescale_cfg(
33+
ctx.noise_pred,
34+
ctx.positive_noise_pred,
35+
self._rescale_multiplier,
36+
)

0 commit comments

Comments
 (0)