@@ -3,87 +3,90 @@
 from torch.nn.init import zeros_ as zero_init
 
 from refiners.fluxion import layers as fl
-from refiners.fluxion.utils import image_to_tensor, no_grad
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.solvers.solver import Solver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder, StableDiffusion_1
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks, SD1UNet
 
 
 class ICLight(StableDiffusion_1):
-    """
-    IC-Light is a Stable Diffusion model that can be used to relight a reference image.
-
-    At initialization, the UNet will be patched to accept four additional input channels. Only the text-conditioned relighting model is supported for now.
-
-    ```example
-    import torch
-    from huggingface_hub import hf_hub_download
-    from PIL import Image
-
-    from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
-    from refiners.foundationals.clip import CLIPTextEncoderL
-    from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
-    from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    dtype = torch.float32
-    no_grad().__enter__()
-    manual_seed(42)
-
-    sd = ICLight(
-        patch_weights=load_from_safetensors(
-            path=hf_hub_download(
-                repo_id="refiners/ic_light.sd1_5.fc",
-                filename="model.safetensors",
+    """IC-Light is a Stable Diffusion model that can be used to relight a reference image.
+
+    At initialization, the UNet will be patched to accept four additional input channels.
+    Only the text-conditioned relighting model is supported for now.
+
+
+    Example:
+        ```py
+        import torch
+        from huggingface_hub import hf_hub_download
+        from PIL import Image
+
+        from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
+        from refiners.foundationals.clip import CLIPTextEncoderL
+        from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder, SD1UNet
+        from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        dtype = torch.float32
+        no_grad().__enter__()
+        manual_seed(42)
+
+        sd = ICLight(
+            patch_weights=load_from_safetensors(
+                path=hf_hub_download(
+                    repo_id="refiners/ic_light.sd1_5.fc",
+                    filename="model.safetensors",
+                ),
+                device=device,
+            ),
+            unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
+                tensors_path=hf_hub_download(
+                    repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
+                    filename="model.safetensors",
+                )
+            ),
+            clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
+                tensors_path=hf_hub_download(
+                    repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
+                    filename="model.safetensors",
+                )
+            ),
+            lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
+                tensors_path=hf_hub_download(
+                    repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
+                    filename="model.safetensors",
+                )
             ),
             device=device,
-        ),
-        unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
-            tensors_path=hf_hub_download(
-                repo_id="refiners/realistic_vision.v5_1.sd1_5.unet",
-                filename="model.safetensors",
-            )
-        ),
-        clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
-            tensors_path=hf_hub_download(
-                repo_id="refiners/realistic_vision.v5_1.sd1_5.text_encoder",
-                filename="model.safetensors",
-            )
-        ),
-        lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
-            tensors_path=hf_hub_download(
-                repo_id="refiners/realistic_vision.v5_1.sd1_5.autoencoder",
-                filename="model.safetensors",
-            )
-        ),
-        device=device,
-        dtype=dtype,
-    )
-
-    prompt = "soft lighting, high-quality professional image"
-    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-
-    image = Image.open("reference-image.png").resize((512, 512))
-    sd.set_ic_light_condition(image)
-
-    x = torch.randn(
-        size=(1, 4, 64, 64),
-        device=device,
-        dtype=dtype,
-    )
-
-    for step in sd.steps:
-        x = sd(
-            x=x,
-            step=step,
-            clip_text_embedding=clip_text_embedding,
-            condition_scale=1.5,
+            dtype=dtype,
+        )
+
+        prompt = "soft lighting, high-quality professional image"
+        negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+        clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+        image = Image.open("reference-image.png").resize((512, 512))
+        sd.set_ic_light_condition(image)
+
+        x = torch.randn(
+            size=(1, 4, 64, 64),
+            device=device,
+            dtype=dtype,
         )
-    predicted_image = sd.lda.latents_to_image(x)
 
-    predicted_image.save("ic-light-output.png")
+        for step in sd.steps:
+            x = sd(
+                x=x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=1.5,
+            )
+        predicted_image = sd.lda.latents_to_image(x)
+
+        predicted_image.save("ic-light-output.png")
+        ```
     """
 
     def __init__(
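The docstring above says the UNet is patched to accept four additional input channels (4 to 8). A minimal standalone sketch of what that means for the first convolution's weight tensor; the 320 output channels and 3x3 kernel are assumed from SD 1.5's usual `conv_in` shape, and all names here are illustrative, not refiners API:

```py
import torch

# Conv2d weights have shape (out_channels, in_channels, kH, kW).
# Appending zero-initialized input channels doubles in_channels while
# leaving the layer's output unchanged until the patch weights are applied.
base_weight = torch.randn(320, 4, 3, 3)     # stand-in for the pretrained conv_in weight
extra_channels = torch.zeros(320, 4, 3, 3)  # zero-init, cf. the `zeros_` import above
extended_weight = torch.cat((base_weight, extra_channels), dim=1)
assert extended_weight.shape == (320, 8, 3, 3)
```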
@@ -109,9 +112,7 @@ def __init__(
 
     @no_grad()
    def _extend_conv_in(self) -> None:
-        """
-        Extend to 8 the input channels of the first convolutional layer of the UNet.
-        """
+        """Extend to 8 the input channels of the first convolutional layer of the UNet."""
         down_blocks = self.unet.ensure_find(DownBlocks)
         first_block = down_blocks.layer(0, fl.Chain)
         conv_in = first_block.ensure_find(fl.Conv2d)
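The code that actually builds `new_conv_in` sits in the unchanged lines elided between this hunk and the next (which resumes at `first_block.replace(old_module=conv_in, new_module=new_conv_in)`). Given the `zeros_ as zero_init` import at the top of the file, a plausible sketch of such an extension in plain `torch.nn` terms; this is an assumption, not the actual refiners implementation:

```py
import torch
from torch import nn
from torch.nn.init import zeros_ as zero_init

def extend_conv_in(conv_in: nn.Conv2d) -> nn.Conv2d:
    # New conv with twice the input channels; other hyperparameters kept.
    new_conv_in = nn.Conv2d(
        in_channels=conv_in.in_channels * 2,
        out_channels=conv_in.out_channels,
        kernel_size=conv_in.kernel_size,
        padding=conv_in.padding,
    )
    zero_init(new_conv_in.weight)  # extra channels start at zero and are thus inert
    with torch.no_grad():
        # Copy the pretrained kernel into the first `in_channels` slots.
        new_conv_in.weight[:, : conv_in.in_channels] = conv_in.weight
        if conv_in.bias is not None and new_conv_in.bias is not None:
            new_conv_in.bias.copy_(conv_in.bias)
    return new_conv_in
```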
@@ -129,48 +130,57 @@ def _extend_conv_in(self) -> None:
         first_block.replace(old_module=conv_in, new_module=new_conv_in)
 
     def _apply_patch(self, weights: dict[str, torch.Tensor]) -> None:
-        """
-        Apply the patch weights to the UNet, modifying inplace the state dict.
-        """
+        """Apply the weights patch to the UNet, modifying the state dict in place."""
         current_state_dict = self.unet.state_dict()
         new_state_dict = {
             key: tensor + weights[key].to(tensor.device, tensor.dtype) for key, tensor in current_state_dict.items()
         }
         self.unet.load_state_dict(new_state_dict)
 
     @staticmethod
-    def compute_gray_composite(image: Image.Image, mask: Image.Image) -> Image.Image:
-        """
-        Compute a grayscale composite of an image and a mask.
+    def compute_gray_composite(
+        image: Image.Image,
+        mask: Image.Image,
+    ) -> Image.Image:
+        """Compute a grayscale composite of an image and a mask.
+
+        IC-Light will recreate the image from this composite.
+
+        Args:
+            image: The image to composite.
+            mask: The mask to use for the composite.
         """
         assert mask.mode == "L", "Mask must be a grayscale image"
         assert image.size == mask.size, "Image and mask must have the same size"
         background = Image.new("RGB", image.size, (127, 127, 127))
         return Image.composite(image, background, mask)
 
     def set_ic_light_condition(
-        self, image: Image.Image, mask: Image.Image | None = None, use_rescaled_image: bool = False
+        self,
+        image: Image.Image,
+        mask: Image.Image | None = None,
     ) -> None:
-        """
-        Set the IC light condition.
+        """Set the IC light condition.
+
+        Args:
+            image: The reference image.
+            mask: The mask to use for the reference image.
 
         If a mask is provided, it will be used to compute a grayscale composite of the image and the mask; otherwise,
         the image will be used as is, but note that IC-Light requires a 127-valued gray background to work.
-
-        `use_rescaled_image` is used to rescale the image to [-1, 1] range. This is the expected range when using the
-        Stable Diffusion autoencoder. But in the original code this part is skipped, giving different results.
-        see https://github.com/lllyasviel/IC-Light/blob/788687452a2bad59633a401281c8aee91bdd3750/gradio_demo.py#L262-L265
         """
         if mask is not None:
             image = self.compute_gray_composite(image=image, mask=mask)
-        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
-        if use_rescaled_image:
-            image_tensor = 2 * image_tensor - 1
-        latents = self.lda.encode(image_tensor)
+        latents = self.lda.image_to_latents(image)
         self._ic_light_condition = latents
 
     def __call__(
-        self, x: torch.Tensor, step: int, *, clip_text_embedding: torch.Tensor, condition_scale: float = 2.0
+        self,
+        x: torch.Tensor,
+        step: int,
+        *,
+        clip_text_embedding: torch.Tensor,
+        condition_scale: float = 2.0,
     ) -> torch.Tensor:
         assert self._ic_light_condition is not None, "Reference image not set, use `set_ic_light_condition` first"
         x = torch.cat((x, self._ic_light_condition), dim=1)
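Two mechanisms in the hunk above work as a pair: `_apply_patch` treats the downloaded weights as additive offsets over the base UNet's state dict, and `__call__` injects the condition by concatenating the reference latents with the noisy latents along the channel axis, which is exactly what the four extra input channels consume. A self-contained illustration with made-up shapes and names (not the refiners API):

```py
import torch

# 1) Offset-style patching: patched = base + offset, key by key,
#    mirroring the dict comprehension in _apply_patch.
base_state = {"conv_in.weight": torch.randn(320, 8, 3, 3)}
patch = {"conv_in.weight": torch.randn(320, 8, 3, 3)}
patched_state = {
    key: tensor + patch[key].to(tensor.device, tensor.dtype)
    for key, tensor in base_state.items()
}

# 2) Condition injection: noisy latents (1, 4, 64, 64) concatenated with the
#    reference-image latents (1, 4, 64, 64) give the (1, 8, 64, 64) input
#    the extended conv_in expects.
x = torch.randn(1, 4, 64, 64)
ic_light_condition = torch.randn(1, 4, 64, 64)
unet_input = torch.cat((x, ic_light_condition), dim=1)
assert unet_input.shape == (1, 8, 64, 64)
```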