Skip to content

Wrap Flux's call with diffusion_model wrappers similar to the Unet model #7382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions comfy/ldm/flux/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
)
self.flipped_img_txt = flipped_img_txt

def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):

img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)

Expand Down Expand Up @@ -244,7 +245,7 @@ def __init__(
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

Expand Down
18 changes: 13 additions & 5 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension

from .layers import (
DoubleStreamBlock,
Expand Down Expand Up @@ -130,7 +131,7 @@ def block_wrap(args):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),transformer_options=transformer_options)
return out

out = blocks_replace[("double_block", i)]({"img": img,
Expand All @@ -146,7 +147,7 @@ def block_wrap(args):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask, transformer_options=transformer_options)

if control is not None: # Controlnet
control_i = control.get("input")
Expand All @@ -164,7 +165,7 @@ def block_wrap(args):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),transformer_options=transformer_options)
return out

out = blocks_replace[("single_block", i)]({"img": img,
Expand All @@ -174,7 +175,7 @@ def block_wrap(args):
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)

if control is not None: # Controlnet
control_o = control.get("output")
Expand All @@ -188,7 +189,7 @@ def block_wrap(args):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
Expand All @@ -205,3 +206,10 @@ def forward(self, x, timestep, context, y, guidance=None, control=None, transfor
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance=guidance, control=control, transformer_options=transformer_options, **kwargs)