Skip to content

Commit 779d52b

Browse files
SHSH
authored andcommitted
Add the support of SDXL
1 parent 96f872c commit 779d52b

File tree

2 files changed

+222
-2
lines changed

2 files changed

+222
-2
lines changed

src/diffusers_composenW_sdxl.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# Copyright 2022 Adobe Research. All rights reserved.
2+
# To view a copy of the license, visit LICENSE.md.
3+
4+
5+
import sys
6+
import os
7+
import argparse
8+
import torch
9+
from scipy.linalg import lu_factor, lu_solve
10+
11+
sys.path.append('./')
12+
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
13+
from src import diffusers_sample
14+
15+
16+
def gdupdateWexact(K, V, Ktarget1, Vtarget1, W, device='cuda'):
17+
input_ = K
18+
output = V
19+
C = input_.T @ input_
20+
d = []
21+
lu, piv = lu_factor(C.cpu().numpy())
22+
for i in range(Ktarget1.size(0)):
23+
sol = lu_solve((lu, piv), Ktarget1[i].reshape(-1, 1).cpu().numpy())
24+
d.append(torch.from_numpy(sol).to(K.device))
25+
26+
d = torch.cat(d, 1).T
27+
28+
e2 = d @ Ktarget1.T
29+
e1 = (Vtarget1.T - W @ Ktarget1.T)
30+
delta = e1 @ torch.linalg.inv(e2)
31+
32+
Wnew = W + delta @ d
33+
lambda_split1 = Vtarget1.size(0)
34+
35+
input_ = torch.cat([Ktarget1.T, K.T], dim=1)
36+
output = torch.cat([Vtarget1, V], dim=0)
37+
38+
loss = torch.norm((Wnew @ input_).T - output, 2, dim=1)
39+
print(loss[:lambda_split1].mean().item(), loss[lambda_split1:].mean().item())
40+
41+
return Wnew
42+
43+
44+
def compose(paths, category, outpath, pretrained_model_path, regularization_prompt, prompts, save_path, device='cuda'):
45+
model_id = pretrained_model_path
46+
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
47+
48+
layers_modified = []
49+
for name, param in pipe.unet.named_parameters():
50+
if 'attn2.to_k' in name or 'attn2.to_v' in name:
51+
layers_modified.append(name)
52+
53+
tokenizer = pipe.tokenizer
54+
tokenizer_2 = pipe.tokenizer_2
55+
56+
def get_text_embedding(prompts):
57+
with torch.no_grad():
58+
uc = []
59+
for text in prompts:
60+
tokens = tokenizer(text,
61+
truncation=True,
62+
max_length=tokenizer.model_max_length,
63+
return_length=True,
64+
return_overflowing_tokens=False,
65+
padding="do_not_pad",
66+
).input_ids
67+
tokens_2 = tokenizer_2(text,
68+
truncation=True,
69+
max_length=tokenizer.model_max_length,
70+
return_length=True,
71+
return_overflowing_tokens=False,
72+
padding="do_not_pad",
73+
).input_ids
74+
# In sdxl, the text embedding should be 2048 (the concat of 768 and 1280).
75+
if 'photo of a' in text[:15]:
76+
print(text)
77+
text_embedding_1 = pipe.text_encoder(torch.cuda.LongTensor(tokens).reshape(1, -1))[0][:,
78+
4:].reshape(-1, 768)
79+
text_embedding_2 = pipe.text_encoder_2(torch.cuda.LongTensor(tokens).reshape(1, -1))[1][:,
80+
4:].reshape(-1, 1280)
81+
uc.append(torch.concat([text_embedding_1, text_embedding_2], dim=-1))
82+
else:
83+
text_embedding_1 = pipe.text_encoder(torch.cuda.LongTensor(tokens).reshape(1, -1))[0][:,
84+
1:].reshape(-1, 768)
85+
text_embedding_2 = pipe.text_encoder_2(torch.cuda.LongTensor(tokens).reshape(1, -1))[1][:,
86+
1:].reshape(-1, 1280)
87+
uc.append(torch.concat([text_embedding_1, text_embedding_2], dim=-1))
88+
return torch.cat(uc, 0).float()
89+
90+
embeds = {}
91+
count = 1
92+
model2_sts = []
93+
modifier_tokens = []
94+
modifier_token_ids = []
95+
categories = []
96+
for path1, cat1 in zip(paths.split('+'), category.split('+')):
97+
model2_st = torch.load(path1)
98+
if 'modifier_token' in model2_st:
99+
# composition of models with individual concept only
100+
key = list(model2_st['modifier_token'].keys())[0]
101+
_ = tokenizer.add_tokens(f'<new{count}>')
102+
_ = tokenizer_2.add_tokens(f'<new{count}>')
103+
modifier_token_ids.append(tokenizer.convert_tokens_to_ids(f'<new{count}>'))
104+
modifier_tokens.append(True)
105+
embeds[f'<new{count}>'] = model2_st['modifier_token'][key]
106+
else:
107+
modifier_tokens.append(False)
108+
109+
model2_sts.append(model2_st['unet'])
110+
categories.append(cat1)
111+
count += 1
112+
113+
pipe.text_encoder.resize_token_embeddings(len(tokenizer))
114+
pipe.text_encoder_2.resize_token_embeddings(len(tokenizer_2))
115+
116+
token_embeds = pipe.text_encoder.get_input_embeddings().weight.data
117+
for (x, y) in zip(modifier_token_ids, list(embeds.keys())):
118+
token_embeds[x] = embeds[y][0]
119+
print(x, y, "added embeddings")
120+
121+
pipe.text_encoder_2.resize_token_embeddings(len(tokenizer))
122+
token_embeds_2 = pipe.text_encoder_2.get_input_embeddings().weight.data
123+
for (x, y) in zip(modifier_token_ids, list(embeds.keys())):
124+
token_embeds_2[x] = embeds[y][1]
125+
print(x, y, "added embeddings")
126+
127+
f = open(regularization_prompt, 'r')
128+
prompt = [x.strip() for x in f.readlines()][:200]
129+
uc = get_text_embedding(prompt)
130+
131+
uc_targets = []
132+
from collections import defaultdict
133+
uc_values = defaultdict(list)
134+
for composing_model_count in range(len(model2_sts)):
135+
category = categories[composing_model_count]
136+
if modifier_tokens[composing_model_count]:
137+
string1 = f'<new{composing_model_count + 1}> {category}'
138+
else:
139+
string1 = f'{category}'
140+
if 'art' in string1:
141+
prompt = [string1] + [f"painting in the style of {string1}"]
142+
else:
143+
prompt = [string1] + [f"photo of a {string1}"]
144+
uc_targets.append(get_text_embedding(prompt))
145+
for each in layers_modified:
146+
uc_values[each].append((model2_sts[composing_model_count][each].to(device) @ uc_targets[-1].T).T)
147+
148+
uc_targets = torch.cat(uc_targets, 0)
149+
150+
removal_indices = []
151+
for i in range(uc_targets.size(0)):
152+
for j in range(i + 1, uc_targets.size(0)):
153+
if (uc_targets[i] - uc_targets[j]).abs().mean() == 0:
154+
removal_indices.append(j)
155+
156+
removal_indices = list(set(removal_indices))
157+
uc_targets = torch.stack([uc_targets[i] for i in range(uc_targets.size(0)) if i not in removal_indices], 0)
158+
for each in layers_modified:
159+
uc_values[each] = torch.cat(uc_values[each], 0)
160+
uc_values[each] = torch.stack(
161+
[uc_values[each][i] for i in range(uc_values[each].size(0)) if i not in removal_indices], 0)
162+
print(uc_values[each].size(), each)
163+
164+
print("target size:", uc_targets.size())
165+
166+
new_weights = {'unet': {}}
167+
for each in layers_modified:
168+
W = pipe.unet.state_dict()[each].float()
169+
values = (W @ uc.T).T # W(C_reg)
170+
input_target = uc_targets
171+
output_target = uc_values[each]
172+
173+
Wnew = gdupdateWexact(uc[:values.shape[0]],
174+
values,
175+
input_target,
176+
output_target,
177+
W.clone(),
178+
)
179+
180+
new_weights['unet'][each] = Wnew
181+
print(Wnew.size())
182+
183+
new_weights['modifier_token'] = embeds
184+
os.makedirs(f'{save_path}/{outpath}', exist_ok=True)
185+
torch.save(new_weights, f'{save_path}/{outpath}/delta.bin')
186+
187+
if prompts is not None:
188+
if os.path.exists(prompts):
189+
diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', prompts, prompt=None, compress=False,
190+
freeze_model='crossattn_kv', batch_size=1)
191+
else:
192+
diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', from_file=None, prompt=prompts,
193+
compress=False, freeze_model='crossattn_kv', batch_size=1)
194+
195+
196+
def parse_args():
197+
parser = argparse.ArgumentParser('', add_help=False)
198+
parser.add_argument('--paths', help='+ separated list of checkpoints', required=True,
199+
type=str)
200+
parser.add_argument('--save_path', help='folder name to save optimized weights', default='optimized_logs',
201+
type=str)
202+
parser.add_argument('--categories', help='+ separated list of categories of the models', required=True,
203+
type=str)
204+
parser.add_argument('--prompts', help='prompts for composition model (can be a file or string)', default=None,
205+
type=str)
206+
parser.add_argument('--ckpt', required=True,
207+
type=str)
208+
parser.add_argument('--regularization_prompt', default='./data/regularization_captions.txt',
209+
type=str)
210+
return parser.parse_args()
211+
212+
213+
if __name__ == "__main__":
214+
args = parse_args()
215+
paths = args.paths
216+
categories = args.categories
217+
if ' ' in categories:
218+
temp = categories.replace(' ', '_')
219+
else:
220+
temp = categories
221+
outpath = '_'.join(['optimized', temp])
222+
compose(paths, categories, outpath, args.ckpt, args.regularization_prompt, args.prompts, args.save_path)

src/diffusers_sample.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,11 @@ def sample(ckpt, delta_ckpt, from_file, prompt, compress, batch_size, freeze_mod
2020
model_id = ckpt
2121
if sdxl:
2222
pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
23-
print(pipe.components)
2423
# pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
2524
pipe = pipe.to("cuda")
2625
else:
2726
pipe = CustomDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
2827
pipe.load_model(delta_ckpt, compress)
29-
3028
outdir = os.path.dirname(delta_ckpt)
3129
generator = torch.Generator(device='cuda').manual_seed(42)
3230

0 commit comments

Comments
 (0)