Skip to content

Commit a5d3c29

Browse files
LaurentLaurent2916
authored andcommitted
update ic_light adapter, bugfix + improve docstrings
1 parent 2cb0f06 commit a5d3c29

File tree

2 files changed

+103
-91
lines changed

2 files changed

+103
-91
lines changed

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
2+
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
23
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
34
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
45
SD1Autoencoder,
@@ -16,4 +17,5 @@
1617
"SD1ControlnetAdapter",
1718
"SD1IPAdapter",
1819
"SD1T2IAdapter",
20+
"ICLight",
1921
]

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ic_light.py

Lines changed: 101 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,87 +3,90 @@
33
from torch.nn.init import zeros_ as zero_init
44

55
from refiners.fluxion import layers as fl
6-
from refiners.fluxion.utils import image_to_tensor, no_grad
6+
from refiners.fluxion.utils import no_grad
77
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
88
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
99
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
1010
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks, SD1UNet
1111

1212

1313
class ICLight(StableDiffusion_1):
14-
"""
15-
IC-Light is a Stable Diffusion model that can be used to relight a reference image.
16-
17-
At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
18-
19-
```example
20-
import torch
21-
from huggingface_hub import hf_hub_download
22-
from PIL import Image
23-
24-
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
25-
from refiners.foundationals.clip import CLIPTextEncoderL
26-
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
27-
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
28-
29-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30-
dtype = torch.float32
31-
no_grad().__enter__()
32-
manual_seed(42)
33-
34-
sd = ICLight(
35-
patch_weights=load_from_safetensors(
36-
path=hf_hub_download(
37-
repo_id="refiners/ic_light.sd1_5.fc",
38-
filename="model.safetensors",
14+
"""IC-Light is a Stable Diffusion model that can be used to relight a reference image.
15+
16+
At initialization, the UNet will be patched to accept four additional input channels.
17+
Only the text-conditioned relighting model is supported for now.
18+
19+
20+
Example:
21+
```py
22+
import torch
23+
from huggingface_hub import hf_hub_download
24+
from PIL import Image
25+
26+
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
27+
from refiners.foundationals.clip import CLIPTextEncoderL
28+
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
29+
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
30+
31+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32+
dtype = torch.float32
33+
no_grad().__enter__()
34+
manual_seed(42)
35+
36+
sd = ICLight(
37+
patch_weights=load_from_safetensors(
38+
path=hf_hub_download(
39+
repo_id="refiners/ic_light.sd1_5.fc",
40+
filename="model.safetensors",
41+
),
42+
device=device,
43+
),
44+
unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
45+
tensors_path=hf_hub_download(
46+
repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
47+
filename="model.safetensors",
48+
)
49+
),
50+
clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
51+
tensors_path=hf_hub_download(
52+
repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
53+
filename="model.safetensors",
54+
)
55+
),
56+
lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
57+
tensors_path=hf_hub_download(
58+
repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
59+
filename="model.safetensors",
60+
)
3961
),
4062
device=device,
41-
),
42-
unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
43-
tensors_path=hf_hub_download(
44-
repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
45-
filename="model.safetensors",
46-
)
47-
),
48-
clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
49-
tensors_path=hf_hub_download(
50-
repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
51-
filename="model.safetensors",
52-
)
53-
),
54-
lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
55-
tensors_path=hf_hub_download(
56-
repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
57-
filename="model.safetensors",
58-
)
59-
),
60-
device=device,
61-
dtype=dtype,
62-
)
63-
64-
prompt = "soft lighting, high-quality professional image"
65-
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
66-
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
67-
68-
image = Image.open("reference-image.png").resize((512, 512))
69-
sd.set_ic_light_condition(image)
70-
71-
x = torch.randn(
72-
size=(1, 4, 64, 64),
73-
device=device,
74-
dtype=dtype,
75-
)
76-
77-
for step in sd.steps:
78-
x = sd(
79-
x=x,
80-
step=step,
81-
clip_text_embedding=clip_text_embedding,
82-
condition_scale=1.5,
63+
dtype=dtype,
64+
)
65+
66+
prompt = "soft lighting, high-quality professional image"
67+
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
68+
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
69+
70+
image = Image.open("reference-image.png").resize((512, 512))
71+
sd.set_ic_light_condition(image)
72+
73+
x = torch.randn(
74+
size=(1, 4, 64, 64),
75+
device=device,
76+
dtype=dtype,
8377
)
84-
predicted_image = sd.lda.latents_to_image(x)
8578
86-
predicted_image.save("ic-light-output.png")
79+
for step in sd.steps:
80+
x = sd(
81+
x=x,
82+
step=step,
83+
clip_text_embedding=clip_text_embedding,
84+
condition_scale=1.5,
85+
)
86+
predicted_image = sd.lda.latents_to_image(x)
87+
88+
predicted_image.save("ic-light-output.png")
89+
```
8790
"""
8891

8992
def __init__(
@@ -109,9 +112,7 @@ def __init__(
109112

110113
@no_grad()
111114
def _extend_conv_in(self) -> None:
112-
"""
113-
Extend to 8 the input channels of the first convolutional layer of the UNet.
114-
"""
115+
"""Extend to 8 the input channels of the first convolutional layer of the UNet."""
115116
down_blocks = self.unet.ensure_find(DownBlocks)
116117
first_block = down_blocks.layer(0, fl.Chain)
117118
conv_in = first_block.ensure_find(fl.Conv2d)
@@ -129,48 +130,57 @@ def _extend_conv_in(self) -> None:
129130
first_block.replace(old_module=conv_in, new_module=new_conv_in)
130131

131132
def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
132-
"""
133-
Apply the patch weights to the UNet, modifying inplace the state dict.
134-
"""
133+
"""Apply the weights patch to the UNet, modifying inplace the state dict."""
135134
current_state_dict = self.unet.state_dict()
136135
new_state_dict = {
137136
key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
138137
}
139138
self.unet.load_state_dict(new_state_dict)
140139

141140
@staticmethod
142-
def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
143-
"""
144-
Compute a grayscale composite of an image and a mask.
141+
def compute_gray_composite(
142+
image: Image.Image,
143+
mask: Image.Image,
144+
) -> Image.Image:
145+
"""Compute a grayscale composite of an image and a mask.
146+
147+
IC-Light will recreate the image
148+
149+
Args:
150+
image: The image to composite.
151+
mask: The mask to use for the composite.
145152
"""
146153
assert mask.mode == "L", "Mask must be a grayscale image"
147154
assert image.size == mask.size, "Image and mask must have the same size"
148155
background = Image.new("RGB", image.size, (127, 127, 127))
149156
return Image.composite(image, background, mask)
150157

151158
def set_ic_light_condition(
152-
self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
159+
self,
160+
image: Image.Image,
161+
mask: Image.Image | None = None,
153162
) -> None:
154-
"""
155-
Set the IC light condition.
163+
"""Set the IC light condition.
164+
165+
Args:
166+
image: The reference image.
167+
mask: The mask to use for the reference image.
156168
157169
If a mask is provided, it will be used to compute a grayscale composite of the image and the mask ; otherwise,
158170
the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.
159-
160-
`use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
161-
Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
162-
see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
163171
"""
164172
if mask is not None:
165173
image = self.compute_gray_composite(image=image, mask=mask)
166-
image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
167-
if use_rescaled_image:
168-
image_tensor = 2 * image_tensor - 1
169-
latents = self.lda.encode(image_tensor)
174+
latents = self.lda.image_to_latents(image)
170175
self._ic_light_condition = latents
171176

172177
def __call__(
173-
self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
178+
self,
179+
x: torch.Tensor,
180+
step: int,
181+
*,
182+
clip_text_embedding: torch.Tensor,
183+
condition_scale: float = 2.0,
174184
) -> torch.Tensor:
175185
assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
176186
x = torch.cat((x, self._ic_light_condition), dim=1)

0 commit comments

Comments
 (0)