Commit 5375aa2

make karras sigmas and sde_variance work together
1 parent cf247a1 commit 5375aa2

File tree

12 files changed: +445 -88 lines

diffusers_ddim.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
import torch
from diffusers import StableDiffusionPipeline

# from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from refiners.fluxion.utils import manual_seed

# DDIM scheduler built from the SD 1.5 (PNDM) training config, with sample
# clipping disabled; it is passed to the pipeline below, so it must be defined.
diffusers_solver = DDIMScheduler.from_config(  # type: ignore
    {
        "_class_name": "PNDMScheduler",
        "_diffusers_version": "0.6.0",
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "num_train_timesteps": 1000,
        "set_alpha_to_one": False,
        "skip_prk_steps": True,
        "steps_offset": 1,
        "trained_betas": None,
        "clip_sample": False,
        # "use_karras_sigmas": True,
        # "algorithm": "dpmsolver++",
    }
)

model_id = "botp/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, scheduler=diffusers_solver)
pipe = pipe.to("cuda:1")

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("diffusers-ddim.png")
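
For reference, the update this scheduler applies at each step is the deterministic DDIM rule (eta = 0). A minimal sketch, with `alpha_t` / `alpha_prev` standing for the cumulative alpha products at the current and previous timesteps (names are illustrative, not the diffusers API):

import torch

def ddim_step(x_t: torch.Tensor, eps: torch.Tensor, alpha_t: float, alpha_prev: float) -> torch.Tensor:
    # Predict x0 from the current latents and the model's noise estimate,
    # then re-noise that prediction to the previous timestep.
    pred_x0 = (x_t - (1 - alpha_t) ** 0.5 * eps) / alpha_t**0.5
    return alpha_prev**0.5 * pred_x0 + (1 - alpha_prev) ** 0.5 * eps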

diffusers_karrs.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch
from diffusers import StableDiffusionPipeline
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler

from refiners.fluxion.utils import manual_seed

model_id = "botp/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipe = pipe.to("cuda:1")

# Rebuild the scheduler with Karras sigmas and the SDE variant of DPM-Solver++,
# the combination this commit makes work together.
config = {**pipe.scheduler.config}
config["use_karras_sigmas"] = True
config["algorithm_type"] = "sde-dpmsolver++"
pipe.scheduler = DPMSolverMultistepScheduler.from_config(config)

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
manual_seed(2)
image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=18, guidance_scale=7.5).images[0]

image.save("diffusers-sde-dpm-karras.png")
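
Under the hood, `use_karras_sigmas` replaces the default sigma spacing with the rho-spaced schedule from Karras et al. (2022), "Elucidating the Design Space of Diffusion-Based Generative Models". A minimal sketch of that schedule (illustrative, not the diffusers implementation):

import torch

def karras_sigmas(sigma_min: float, sigma_max: float, n: int, rho: float = 7.0) -> torch.Tensor:
    # Interpolate linearly in sigma^(1/rho) space, then raise back to rho:
    # sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

The "sde-dpmsolver++" algorithm injects fresh noise at each step with a variance tied to the sigma schedule, which is presumably the interaction the commit message refers to.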

old-script.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
import torch
from PIL import Image

from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline

weights_path = "weights"
device = torch.device("cuda:1")
n_steps = 30

pipe = StableDiffusionPipeline.from_pretrained(
    "botp/stable-diffusion-v1-5",
    torch_dtype=torch.float32,
    safety_checker=None,
).to(device)
# Baseline: default DPMSolverMultistepScheduler config (no Karras sigmas, no SDE).
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

torch.manual_seed(2)
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=n_steps,
    guidance_scale=7.5,
)

output.images[0].save("output-diffusers-diffusion.png")
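
For contrast with the Karras spacing above, the default config derives its sigmas directly from the training beta schedule's cumulative alpha products. Schematically (illustrative, not the diffusers source):

import torch

def default_sigmas(betas: torch.Tensor) -> torch.Tensor:
    # sigma_t = sqrt((1 - alphabar_t) / alphabar_t): the noise-to-signal
    # ratio implied by the beta schedule at each timestep.
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5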

run_ic_light.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
from pathlib import Path

import torch
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLightBackground
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet

# Run the whole script without gradient tracking.
no_grad().__enter__()

device = "cuda" if torch.cuda.is_available() else "cpu"
sd = ICLightBackground(
    patch_weights=load_from_safetensors("iclight_sd15_fbc-refiners.safetensors", device=device),
    unet=SD1UNet(in_channels=4, device=device).load_from_safetensors("realistic-vision-v51-unet.safetensors"),
    clip_text_encoder=CLIPTextEncoderL(device=device).load_from_safetensors(
        "realistic-vision-v51-text_encoder.safetensors"
    ),
    lda=SD1Autoencoder(device=device).load_from_safetensors("realistic-vision-v51-autoencoder.safetensors"),
    device=device,
)

prompt = "porcelain mug, 4k high quality, soft lighting, high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

data_test = Path.home() / "data_test"
uid = "ic_out"
foreground_image = Image.open(f"{uid}.png").convert("RGB")
# Binarize the mask: any non-zero pixel becomes foreground.
mask = Image.open(f"{uid}.mask.png").resize(foreground_image.size).convert("L").point(lambda x: 255 if x > 0 else 0)

sd.set_foreground_condition(foreground_image, mask=mask, use_rescaled_image=True)
sd.set_background_condition(foreground_image)

# SD 1.5 latents: 4 channels at 1/8 of the image resolution.
x = torch.randn(1, 4, foreground_image.height // 8, foreground_image.width // 8, device=device)

for step in sd.steps:
    x = sd(
        x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        condition_scale=3,
    )
predicted_image = sd.lda.latents_to_image(x)

predicted_image.save("ic-light-output.png")
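
The condition_scale passed at each step is the classifier-free guidance weight: the model is evaluated on both the negative and positive text embeddings and the two noise predictions are blended. Schematically (a sketch of the standard CFG rule, not the refiners internals):

import torch

def cfg(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, scale: float) -> torch.Tensor:
    # scale = 1 reproduces the conditional prediction; larger values push
    # the sample further toward the prompt and away from the negative prompt.
    return eps_uncond + scale * (eps_cond - eps_uncond)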

src/refiners/foundationals/latent_diffusion/flux/__init__.py

Whitespace-only changes.
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
from typing import NamedTuple, get_type_hints

from refiners.fluxion import layers as fl


class FluxParams(NamedTuple):
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


# Placeholder blocks sketching the Flux architecture, to be fleshed out.
class DoubleStreamBlocks(fl.Chain):
    pass


class SingleStreamBlocks(fl.Chain):
    pass


class PositionalEncoding(fl.Passthrough):
    pass


class TimestepEmbedding(fl.Passthrough):
    pass


class TextEmbedding(fl.Passthrough):
    pass


class DoubleStreamAttention(fl.Chain):
    pass


class SingleStreamAttention(fl.Chain):
    pass


class LastLayer(fl.Chain):
    def __init__(self, params: FluxParams) -> None:
        super().__init__(fl.LayerNorm(params.hidden_size), fl.Linear(params.hidden_size, params.in_channels))


class Flux(fl.Sum):
    def __init__(
        self,
        params: FluxParams,
    ) -> None:
        super().__init__(
            TimestepEmbedding(),
            PositionalEncoding(),
            TextEmbedding(),
            DoubleStreamBlocks(DoubleStreamAttention() for _ in range(params.depth)),
            SingleStreamBlocks(SingleStreamAttention() for _ in range(params.depth_single_blocks)),
        )


if __name__ == "__main__":
    flux = Flux(
        params=FluxParams(
            in_channels=3,
            vec_in_dim=32,
            context_in_dim=64,
            hidden_size=256,
            mlp_ratio=4.0,
            num_heads=8,
            depth=2,
            depth_single_blocks=1,
            axes_dim=[32, 32],
            theta=1024,
            qkv_bias=True,
            guidance_embed=True,
        ),
    )

    print(repr(flux))
    image_in = flux.layer(0)
    print(image_in)
    print(get_type_hints(image_in))
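
The theta and axes_dim parameters are consumed by the (still empty) PositionalEncoding block. In the reference Flux implementation they parameterize a per-axis rotary position embedding: each spatial axis gets its own slice of the head dimension, with channel-pair frequencies theta**(-2k/d). A rough sketch of those per-axis rotation angles (an assumption based on the reference implementation, not this commit's code):

import torch

def rope_angles(positions: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    # positions: (..., n) positions along one axis; dim must be even.
    # Returns rotation angles of shape (..., n, dim // 2), one per channel pair.
    scale = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    omega = 1.0 / theta**scale  # per-pair frequency
    return positions.unsqueeze(-1).float() * omega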
