Skip to content

Commit 277b0fd

Browse files
ily-Rdeltheil
authored andcommitted
ella adapter implementation. tested with sd1.5 model
1 parent a8efe5e commit 277b0fd

File tree

11 files changed

+603
-1
lines changed

11 files changed

+603
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ test = [
5555
# HQ-SAM missing dependency:
5656
# https://github.com/SysCV/sam-hq/pull/59
5757
"timm>=0.5.0",
58+
"sentencepiece>=0.2.0",
5859
]
5960
conversion = [
6061
"diffusers>=0.26.1",
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import argparse
2+
from pathlib import Path
3+
4+
import torch
5+
from huggingface_hub import hf_hub_download # type: ignore
6+
7+
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
8+
9+
10+
class Args(argparse.Namespace):
11+
source_path: str
12+
output_path: str | None
13+
use_half: bool
14+
15+
16+
def convert(args: Args) -> dict[str, torch.Tensor]:
17+
if Path(args.source_path).suffix != ".safetensors":
18+
args.source_path = hf_hub_download(
19+
repo_id=args.source_path, filename="ella-sd1.5-tsc-t5xl.safetensors", local_dir="tests/weights/ELLA-Adapter"
20+
)
21+
weights = load_from_safetensors(args.source_path)
22+
23+
for key in list(weights.keys()):
24+
if "latents" in key:
25+
new_key = "PerceiverResampler.Latents.ParameterInitialized.weight"
26+
weights[new_key] = weights.pop(key)
27+
elif "time_embedding" in key:
28+
new_key = key.replace("time_embedding", "TimestepEncoder.RangeEncoder").replace("linear", "Linear")
29+
weights[new_key] = weights.pop(key)
30+
elif "proj_in" in key:
31+
new_key = f"PerceiverResampler.Linear.{key.split('.')[-1]}"
32+
weights[new_key] = weights.pop(key)
33+
elif "time_aware" in key:
34+
new_key = f"PerceiverResampler.Residual.Linear.{key.split('.')[-1]}"
35+
weights[new_key] = weights.pop(key)
36+
elif "attn.in_proj" in key:
37+
layer_num = int(key.split(".")[2])
38+
query_param, key_param, value_param = weights.pop(key).chunk(3, dim=0)
39+
param_type = "weight" if "weight" in key else "bias"
40+
for i, param in enumerate([query_param, key_param, value_param]):
41+
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Distribute.Linear_{i+1}.{param_type}"
42+
weights[new_key] = param
43+
elif "attn.out_proj" in key:
44+
layer_num = int(key.split(".")[2])
45+
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Attention.Linear.{key.split('.')[-1]}"
46+
weights[new_key] = weights.pop(key)
47+
elif "ln_ff" in key:
48+
layer_num = int(key.split(".")[2])
49+
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.AdaLayerNorm.Parallel.Chain.Linear.{key.split('.')[-1]}"
50+
weights[new_key] = weights.pop(key)
51+
elif "ln_1" in key or "ln_2" in key:
52+
layer_num = int(key.split(".")[2])
53+
n = 1 if int(key.split(".")[3].split("_")[-1]) == 2 else 2
54+
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_1.PerceiverAttention.Distribute.AdaLayerNorm_{n}.Parallel.Chain.Linear.{key.split('.')[-1]}"
55+
weights[new_key] = weights.pop(key)
56+
elif "mlp" in key:
57+
layer_num = int(key.split(".")[2])
58+
n = 1 if "c_fc" in key else 2
59+
new_key = f"PerceiverResampler.Transformer.TransformerLayer_{layer_num+1}.Residual_2.FeedForward.Linear_{n}.{key.split('.')[-1]}"
60+
weights[new_key] = weights.pop(key)
61+
62+
if args.use_half:
63+
weights = {key: value.half() for key, value in weights.items()}
64+
65+
return weights
66+
67+
68+
if __name__ == "__main__":
69+
parser = argparse.ArgumentParser(description="Convert a pretrained Ella Adapter to refiners implementation")
70+
parser.add_argument(
71+
"--from",
72+
type=str,
73+
dest="source_path",
74+
default="QQGYLab/ELLA",
75+
help=(
76+
"A path to a local .safetensors weights. If not provided, a repo from Hugging Face Hub will be used"
77+
"Default to QQGYLab/ELLA"
78+
),
79+
)
80+
81+
parser.add_argument(
82+
"--to",
83+
type=str,
84+
dest="output_path",
85+
default=None,
86+
help=(
87+
"Path to save the converted model (extension will be .safetensors). If not specified, the output path will"
88+
" be the source path with the prefix set to refiners"
89+
),
90+
)
91+
parser.add_argument(
92+
"--half",
93+
action="store_true",
94+
dest="use_half",
95+
default=True,
96+
help="Use this flag to save the output file as half precision (default: full precision).",
97+
)
98+
args = parser.parse_args(namespace=Args())
99+
weights = convert(args)
100+
if args.output_path is None:
101+
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
102+
save_to_safetensors(path=args.output_path, tensors=weights)

scripts/prepare_test_weights.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,29 @@ def download_ip_adapter():
353353
download_files(urls, sdxl_models_folder)
354354

355355

356+
def download_t5xl_fp16():
357+
base_folder = os.path.join(test_weights_dir, "QQGYLab", "T5XLFP16")
358+
urls = [
359+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/config.json",
360+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/model.safetensors",
361+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/special_tokens_map.json",
362+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/spiece.model",
363+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer.json",
364+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/models--google--flan-t5-xl--text_encoder/tokenizer_config.json",
365+
]
366+
download_files(urls, base_folder)
367+
368+
369+
def download_ella_adapter():
370+
download_t5xl_fp16()
371+
base_folder = os.path.join(test_weights_dir, "QQGYLab", "ELLA")
372+
download_file(
373+
"https://huggingface.co/QQGYLab/ELLA/resolve/main/ella-sd1.5-tsc-t5xl.safetensors",
374+
base_folder,
375+
expected_hash="5af7b200",
376+
)
377+
378+
356379
def download_t2i_adapter():
357380
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
358381
urls = [
@@ -689,6 +712,17 @@ def convert_ip_adapter():
689712
)
690713

691714

715+
def convert_ella_adapter():
716+
os.makedirs("tests/weights/ELLA-Adapter", exist_ok=True)
717+
run_conversion_script(
718+
"convert_ella_adapter.py",
719+
"tests/weights/QQGYLab/ELLA/ella-sd1.5-tsc-t5xl.safetensors",
720+
"tests/weights/ELLA-Adapter/ella-sd1.5-tsc-t5xl.safetensors",
721+
half=True,
722+
expected_hash="b8244cb6",
723+
)
724+
725+
692726
def convert_t2i_adapter():
693727
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
694728
run_conversion_script(
@@ -860,6 +894,7 @@ def download_all():
860894
download_unclip()
861895
download_ip_adapter()
862896
download_t2i_adapter()
897+
download_ella_adapter()
863898
download_sam()
864899
download_hq_sam()
865900
download_dinov2()
@@ -884,6 +919,7 @@ def convert_all():
884919
convert_unclip()
885920
convert_ip_adapter()
886921
convert_t2i_adapter()
922+
convert_ella_adapter()
887923
convert_sam()
888924
convert_hq_sam()
889925
convert_dinov2()

src/refiners/foundationals/latent_diffusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, LCMSolver, Solver
99
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
1010
SD1ControlnetAdapter,
11+
SD1ELLAAdapter,
1112
SD1IPAdapter,
1213
SD1T2IAdapter,
1314
SD1UNet,
@@ -32,6 +33,7 @@
3233
"SD1ControlnetAdapter",
3334
"SD1IPAdapter",
3435
"SD1T2IAdapter",
36+
"SD1ELLAAdapter",
3537
"SDXLUNet",
3638
"DoubleTextEncoder",
3739
"SDXLIPAdapter",

0 commit comments

Comments
 (0)