Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/conversion/convert_diffusers_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Args(argparse.Namespace):
half: bool
verbose: bool
skip_init_check: bool
override_weights: str | None


def setup_converter(args: Args) -> ModelConverter:
Expand Down
89 changes: 89 additions & 0 deletions scripts/conversion/convert_ic_light.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import argparse
from pathlib import Path

from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter
from huggingface_hub import hf_hub_download # type: ignore

from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors


class Args(argparse.Namespace):
source_path: str
output_path: str | None
subfolder: str
half: bool
verbose: bool
reference_unet_path: str


def main() -> None:
    """Convert IC-Light patch weights into a Refiners-compatible safetensors file.

    The patch is read from a local `.safetensors` file when `--from` ends with
    `.safetensors`; otherwise `--from` is treated as a Hugging Face Hub repository id
    and `--filename` is downloaded from it. The key mapping is obtained from a UNet
    converter set up on the reference UNet (`--reference-unet-path`), then applied to
    the patch weights, and the result is saved to `--to`.
    """
    parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="lllyasviel/ic-light",
        help=(
            "Can be a path to a local .safetensors file or a model name from the Hugging Face Hub. Default:"
            " lllyasviel/ic-light"
        ),
    )
    parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.")
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Output path (.safetensors) for converted model. If not provided, the output path will be derived from"
            " --filename as <stem>-refiners.safetensors."
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    parser.add_argument(
        "--reference-unet-path",
        type=str,
        dest="reference_unet_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Path to the reference UNet weights.",
    )
    args = parser.parse_args(namespace=Args())
    if args.output_path is None:
        # The default name is built from `--filename`, not from `--from`.
        args.output_path = f"{Path(args.filename).stem}-refiners.safetensors"

    # Anything that is not an explicit .safetensors path is resolved through the Hub.
    if args.source_path.endswith(".safetensors"):
        patch_file = Path(args.source_path)
    else:
        patch_file = Path(
            hf_hub_download(
                repo_id=args.source_path,
                filename=args.filename,
            )
        )
    patch_weights = load_from_safetensors(patch_file)

    unet_args = UNetArgs(
        source_path=args.reference_unet_path,
        subfolder="unet",
        half=False,
        # Forward the user's choice instead of silently discarding `--verbose`.
        verbose=args.verbose,
        skip_init_check=True,
        override_weights=None,
    )
    converter = setup_unet_converter(args=unet_args)
    # Only the converter's key mapping is reused here: it is applied to the patch
    # weights rather than to the reference UNet's own weights.
    result = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=patch_weights,
        target_state_dict=converter.target_model.state_dict(),
        state_dict_mapping=converter.get_mapping(),
    )
    save_to_safetensors(path=args.output_path, tensors=result)


if __name__ == "__main__":
main()
20 changes: 20 additions & 0 deletions scripts/prepare_test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,14 @@ def download_sdxl_lightning_lora():
)


def download_ic_light():
    """Fetch the original IC-Light `iclight_sd15_fc` patch weights into the test weights folder."""
    url = "https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors"
    download_file(url, dest_folder=test_weights_dir, expected_hash="bce70123")


def printg(msg: str):
    """Print `msg` wrapped in the ANSI escape codes for green text."""
    print(f"\033[92m{msg}\033[0m")
Expand Down Expand Up @@ -790,6 +798,16 @@ def convert_sdxl_lightning_base():
)


def convert_ic_light():
    """Run the IC-Light conversion script on the downloaded test weights."""
    source = "tests/weights/iclight_sd15_fc.safetensors"
    destination = "tests/weights/iclight_sd15_fc-refiners.safetensors"
    run_conversion_script("convert_ic_light.py", source, destination, half=False, expected_hash="be315c1f")


def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
Expand All @@ -811,6 +829,7 @@ def download_all():
download_lcm_lora()
download_sdxl_lightning_base()
download_sdxl_lightning_lora()
download_ic_light()


def convert_all():
Expand All @@ -830,6 +849,7 @@ def convert_all():
convert_control_lora_fooocus()
convert_lcm_base()
convert_sdxl_lightning_base()
convert_ic_light()


def main():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import torch
from PIL import Image
from torch.nn.init import zeros_ as zero_init

from refiners.fluxion import layers as fl
from refiners.fluxion.utils import image_to_tensor, no_grad
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks, SD1UNet


class ICLight(StableDiffusion_1):
    """
    IC-Light is a Stable Diffusion model that can be used to relight a reference image.

    At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned
    relighting model is supported for now.

    ```example
    import torch
    from huggingface_hub import hf_hub_download
    from PIL import Image

    from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
    from refiners.foundationals.clip import CLIPTextEncoderL
    from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    no_grad().__enter__()
    manual_seed(42)

    sd = ICLight(
        patch_weights=load_from_safetensors(
            path=hf_hub_download(
                repo_id="refiners/ic_light.sd1_5.fc",
                filename="model.safetensors",
            ),
            device=device,
        ),
        unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
                filename="model.safetensors",
            )
        ),
        clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
                filename="model.safetensors",
            )
        ),
        lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
                filename="model.safetensors",
            )
        ),
        device=device,
        dtype=dtype,
    )

    prompt = "soft lighting, high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    image = Image.open("reference-image.png").resize((512, 512))
    sd.set_ic_light_condition(image)

    x = torch.randn(
        size=(1, 4, 64, 64),
        device=device,
        dtype=dtype,
    )

    for step in sd.steps:
        x = sd(
            x=x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=1.5,
        )
    predicted_image = sd.lda.latents_to_image(x)

    predicted_image.save("ic-light-output.png")
    ```
    """

    def __init__(
        self,
        patch_weights: dict[str, torch.Tensor],
        unet: SD1UNet,
        lda: SD1Autoencoder | None = None,
        clip_text_encoder: CLIPTextEncoderL | None = None,
        solver: Solver | None = None,
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """
        Build the model, widen the UNet input convolution and apply the IC-Light patch.

        Args:
            patch_weights: Tensors added to the UNet state dict, keyed by the UNet state dict keys.
            unet: The base UNet; its first convolution is replaced and its weights are modified.
            lda: Optional autoencoder, forwarded to the parent class.
            clip_text_encoder: Optional text encoder, forwarded to the parent class.
            solver: Optional solver, forwarded to the parent class.
            device: Device forwarded to the parent class.
            dtype: Data type forwarded to the parent class.
        """
        super().__init__(
            unet=unet,
            lda=lda,
            clip_text_encoder=clip_text_encoder,
            solver=solver,
            device=device,
            dtype=dtype,
        )
        # Defined up front so that calling the model before `set_ic_light_condition`
        # trips the explicit assertion in `__call__` rather than an AttributeError.
        self._ic_light_condition: torch.Tensor | None = None
        self._extend_conv_in()
        self._apply_patch(weights=patch_weights)

    @no_grad()
    def _extend_conv_in(self) -> None:
        """
        Add four zero-initialized input channels to the first convolutional layer of the UNet.

        The original weights are copied into the leading input channels and the bias is reused.
        """
        down_blocks = self.unet.ensure_find(DownBlocks)
        first_block = down_blocks.layer(0, fl.Chain)
        conv_in = first_block.ensure_find(fl.Conv2d)
        new_conv_in = fl.Conv2d(
            in_channels=conv_in.in_channels + 4,
            out_channels=conv_in.out_channels,
            kernel_size=(conv_in.kernel_size[0], conv_in.kernel_size[1]),
            padding=(int(conv_in.padding[0]), int(conv_in.padding[1])),
            device=conv_in.device,
            dtype=conv_in.dtype,
        )
        # The extra channels start at zero; the existing ones keep the original weights.
        zero_init(new_conv_in.weight)
        new_conv_in.bias = conv_in.bias
        new_conv_in.weight[:, : conv_in.in_channels, :, :] = conv_in.weight
        first_block.replace(old_module=conv_in, new_module=new_conv_in)

    def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
        """
        Add the patch weights to the UNet weights and load the result back into the UNet.

        `weights` must contain an entry for every key of the UNet state dict; each patch tensor is moved to the
        device and dtype of the tensor it is added to.
        """
        current_state_dict = self.unet.state_dict()
        new_state_dict = {
            key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
        }
        self.unet.load_state_dict(new_state_dict)

    @staticmethod
    def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Composite an image over a uniform (127, 127, 127) gray background using a mask.

        Args:
            image: The image to composite; must have the same size as `mask`.
            mask: The compositing mask, in mode "L".

        Returns:
            The result of `Image.composite(image, background, mask)`.
        """
        assert mask.mode == "L", "Mask must be a grayscale image"
        assert image.size == mask.size, "Image and mask must have the same size"
        background = Image.new("RGB", image.size, (127, 127, 127))
        return Image.composite(image, background, mask)

    def set_ic_light_condition(
        self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
    ) -> None:
        """
        Set the IC light condition.

        If a mask is provided, it will be used to compute a gray composite of the image and the mask; otherwise,
        the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.

        `use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
        Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
        see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
        """
        if mask is not None:
            image = self.compute_gray_composite(image=image, mask=mask)
        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
        if use_rescaled_image:
            image_tensor = 2 * image_tensor - 1
        latents = self.lda.encode(image_tensor)
        self._ic_light_condition = latents

    def __call__(
        self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
    ) -> torch.Tensor:
        """
        Run one step of the parent model on `x` concatenated with the stored condition latents.

        `set_ic_light_condition` must have been called beforehand.
        """
        assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
        # The condition latents are appended along the channel dimension.
        x = torch.cat((x, self._ic_light_condition), dim=1)
        return super().__call__(
            x,
            step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )
Loading