|
| 1 | +# Copyright 2022 Adobe Research. All rights reserved. |
| 2 | +# To view a copy of the license, visit LICENSE.md. |
| 3 | + |
| 4 | + |
| 5 | +import sys |
| 6 | +import os |
| 7 | +import argparse |
| 8 | +import torch |
| 9 | +from scipy.linalg import lu_factor, lu_solve |
| 10 | + |
| 11 | +sys.path.append('./') |
| 12 | +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline |
| 13 | +from src import diffusers_sample |
| 14 | + |
| 15 | + |
def gdupdateWexact(K, V, Ktarget1, Vtarget1, W, device='cuda'):
    """Return ``W`` updated in closed form so that it maps ``Ktarget1`` onto ``Vtarget1``.

    With ``C = K.T @ K`` and ``d = (C^-1 @ Ktarget1.T).T`` the result is
    ``W + delta @ d``, where ``delta`` satisfies
    ``delta @ (d @ Ktarget1.T) = Vtarget1.T - W @ Ktarget1.T``.
    Consequently ``Wnew @ Ktarget1.T`` equals ``Vtarget1.T`` up to numerical
    error, while the correction is weighted by the inverse Gram matrix of ``K``.

    Args:
        K: 2-D tensor of regularization inputs, one row per sample
            ``(n_reg, dim)``. ``K.T @ K`` must be invertible.
        V: 2-D tensor of regularization outputs ``(n_reg, out)``; only used
            for the residual that is printed.
        Ktarget1: 2-D tensor of target inputs ``(n_target, dim)``.
        Vtarget1: 2-D tensor of target outputs ``(n_target, out)``.
        W: weight matrix ``(out, dim)`` to update; it is not modified in place.
        device: unused; kept so existing callers that pass it keep working.

    Returns:
        The updated weight matrix, same shape as ``W``.

    Side effects:
        Prints the mean L2 residual over the target rows followed by the mean
        L2 residual over the regularization rows.
    """
    # Factor the Gram matrix once on the CPU (scipy works on numpy arrays).
    gram = K.T @ K
    lu_and_piv = lu_factor(gram.cpu().numpy())

    # A single multi-right-hand-side solve: column i of the result is
    # C^-1 @ Ktarget1[i], so no per-row Python loop is needed.
    solved = lu_solve(lu_and_piv, Ktarget1.T.cpu().numpy())
    d = torch.from_numpy(solved).to(K.device).T

    e2 = d @ Ktarget1.T
    e1 = Vtarget1.T - W @ Ktarget1.T
    # delta @ e2 = e1  <=>  e2.T @ delta.T = e1.T; solving the linear system
    # avoids forming the explicit inverse of e2.
    delta = torch.linalg.solve(e2.T, e1.T).T

    Wnew = W + delta @ d
    lambda_split1 = Vtarget1.size(0)

    # Residuals of the new weights: target rows first, regularization rows after.
    all_inputs = torch.cat([Ktarget1.T, K.T], dim=1)
    all_outputs = torch.cat([Vtarget1, V], dim=0)

    loss = torch.norm((Wnew @ all_inputs).T - all_outputs, 2, dim=1)
    print(loss[:lambda_split1].mean().item(), loss[lambda_split1:].mean().item())

    return Wnew
| 42 | + |
| 43 | + |
def compose(paths, category, outpath, pretrained_model_path, regularization_prompt, prompts, save_path, device='cuda'):
    """Merge several fine-tuned checkpoints into one set of cross-attention weights.

    For every UNet parameter whose name contains ``attn2.to_k`` or
    ``attn2.to_v`` the pretrained weight is updated with ``gdupdateWexact`` so
    that it reproduces each checkpoint's output on that checkpoint's target
    prompts, using the regularization prompts as the reference inputs. The
    result is written to ``{save_path}/{outpath}/delta.bin`` and, when
    ``prompts`` is given, passed to ``diffusers_sample.sample``.

    Args:
        paths: ``+`` separated checkpoint paths. Each checkpoint is a dict with
            a ``'unet'`` entry and optionally a ``'modifier_token'`` entry
            whose first value holds one embedding per text encoder.
        category: ``+`` separated category names, paired with ``paths`` in order.
        outpath: sub-folder of ``save_path`` that receives ``delta.bin``.
        pretrained_model_path: model id/path given to
            ``StableDiffusionXLPipeline.from_pretrained``.
        regularization_prompt: text file with one prompt per line; only the
            first 200 lines are used.
        prompts: ``None`` to skip sampling, otherwise a prompt string or the
            path of an existing prompt file.
        save_path: root folder for the output.
        device: device the pipeline and all intermediate tensors are placed on.
    """
    from collections import defaultdict

    model_id = pretrained_model_path
    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)

    # Only the cross-attention key/value projections are re-solved.
    layers_modified = [
        name for name, _ in pipe.unet.named_parameters()
        if 'attn2.to_k' in name or 'attn2.to_v' in name
    ]

    tokenizer = pipe.tokenizer
    tokenizer_2 = pipe.tokenizer_2

    def get_text_embedding(prompts):
        """Return one float32 row of width 768 + 1280 per kept token of every prompt."""
        tokenizer_kwargs = dict(
            truncation=True,
            # Both tokenizers share one limit so the two token sequences have
            # the same length and their embeddings can be concatenated row-wise.
            max_length=tokenizer.model_max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="do_not_pad",
        )
        rows = []
        with torch.no_grad():
            for text in prompts:
                tokens = tokenizer(text, **tokenizer_kwargs).input_ids
                tokens_2 = tokenizer_2(text, **tokenizer_kwargs).input_ids
                # In sdxl, the text embedding should be 2048 (the concat of 768 and 1280).
                if 'photo of a' in text[:15]:
                    print(text)
                    # Drop the first four positions for "photo of a ..." prompts
                    # (presumably the start token plus the three prefix words —
                    # TODO confirm against the tokenizer).
                    skip = 4
                else:
                    # Drop only the first position.
                    skip = 1
                ids_1 = torch.tensor(tokens, dtype=torch.long, device=device).reshape(1, -1)
                # Each encoder receives the ids produced by its own tokenizer.
                ids_2 = torch.tensor(tokens_2, dtype=torch.long, device=device).reshape(1, -1)
                text_embedding_1 = pipe.text_encoder(ids_1)[0][:, skip:].reshape(-1, 768)
                text_embedding_2 = pipe.text_encoder_2(ids_2)[1][:, skip:].reshape(-1, 1280)
                rows.append(torch.concat([text_embedding_1, text_embedding_2], dim=-1))
        return torch.cat(rows, 0).float()

    embeds = {}
    model2_sts = []
    modifier_tokens = []
    modifier_token_ids = []
    modifier_token_ids_2 = []
    categories = []
    for count, (path1, cat1) in enumerate(zip(paths.split('+'), category.split('+')), start=1):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        model2_st = torch.load(path1)
        if 'modifier_token' in model2_st:
            # composition of models with individual concept only
            key = next(iter(model2_st['modifier_token']))
            new_token = f'<new{count}>'
            tokenizer.add_tokens(new_token)
            tokenizer_2.add_tokens(new_token)
            modifier_token_ids.append(tokenizer.convert_tokens_to_ids(new_token))
            modifier_token_ids_2.append(tokenizer_2.convert_tokens_to_ids(new_token))
            modifier_tokens.append(True)
            embeds[new_token] = model2_st['modifier_token'][key]
        else:
            modifier_tokens.append(False)

        model2_sts.append(model2_st['unet'])
        categories.append(cat1)

    # Grow each encoder's embedding table to its own tokenizer's size, then
    # write the learned modifier embeddings into the newly added rows.
    pipe.text_encoder.resize_token_embeddings(len(tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(tokenizer_2))

    token_embeds = pipe.text_encoder.get_input_embeddings().weight.data
    for x, y in zip(modifier_token_ids, embeds):
        token_embeds[x] = embeds[y][0]
        print(x, y, "added embeddings")

    token_embeds_2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    for x, y in zip(modifier_token_ids_2, embeds):
        token_embeds_2[x] = embeds[y][1]
        print(x, y, "added embeddings")

    # Regularization prompts: the file is closed as soon as it has been read.
    with open(regularization_prompt, 'r') as f:
        reg_prompts = [line.strip() for line in f][:200]
    uc = get_text_embedding(reg_prompts)

    # Target inputs (text embeddings) and, per layer, the outputs each
    # fine-tuned checkpoint produces for them.
    uc_targets = []
    uc_values = defaultdict(list)
    for index, (model_unet, concept_category, has_modifier) in enumerate(
            zip(model2_sts, categories, modifier_tokens)):
        if has_modifier:
            string1 = f'<new{index + 1}> {concept_category}'
        else:
            string1 = f'{concept_category}'
        if 'art' in string1:
            target_prompts = [string1, f"painting in the style of {string1}"]
        else:
            target_prompts = [string1, f"photo of a {string1}"]
        target_embedding = get_text_embedding(target_prompts)
        uc_targets.append(target_embedding)
        for each in layers_modified:
            uc_values[each].append((model_unet[each].to(device) @ target_embedding.T).T)

    uc_targets = torch.cat(uc_targets, 0)

    # Drop every target row that exactly duplicates an earlier one; the same
    # row indices are dropped from each layer's target outputs.
    num_targets = uc_targets.size(0)
    removal_indices = set()
    for i in range(num_targets):
        for j in range(i + 1, num_targets):
            if (uc_targets[i] - uc_targets[j]).abs().mean() == 0:
                removal_indices.add(j)

    keep = [i for i in range(num_targets) if i not in removal_indices]
    uc_targets = uc_targets[keep]
    for each in layers_modified:
        uc_values[each] = torch.cat(uc_values[each], 0)[keep]
        print(uc_values[each].size(), each)

    print("target size:", uc_targets.size())

    unet_state = pipe.unet.state_dict()
    new_weights = {'unet': {}}
    for each in layers_modified:
        W = unet_state[each].float()
        values = (W @ uc.T).T  # W(C_reg)

        Wnew = gdupdateWexact(uc, values, uc_targets, uc_values[each], W.clone())

        new_weights['unet'][each] = Wnew
        print(Wnew.size())

    new_weights['modifier_token'] = embeds
    os.makedirs(f'{save_path}/{outpath}', exist_ok=True)
    torch.save(new_weights, f'{save_path}/{outpath}/delta.bin')

    if prompts is not None:
        if os.path.exists(prompts):
            # ``prompts`` names an existing file: pass it as the prompt file.
            diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', prompts, prompt=None, compress=False,
                                    freeze_model='crossattn_kv', batch_size=1)
        else:
            # Otherwise treat ``prompts`` as the prompt text itself.
            diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', from_file=None, prompt=prompts,
                                    compress=False, freeze_model='crossattn_kv', batch_size=1)
| 194 | + |
| 195 | + |
def parse_args():
    """Parse the command-line options of the composition script.

    Returns:
        The ``argparse.Namespace`` with ``paths``, ``save_path``,
        ``categories``, ``prompts``, ``ckpt`` and ``regularization_prompt``.
        ``--paths``, ``--categories`` and ``--ckpt`` are required.
    """
    # (flag, keyword arguments) in the order the options are registered.
    option_specs = (
        ('--paths', dict(help='+ separated list of checkpoints', required=True)),
        ('--save_path', dict(help='folder name to save optimized weights', default='optimized_logs')),
        ('--categories', dict(help='+ separated list of categories of the models', required=True)),
        ('--prompts', dict(help='prompts for composition model (can be a file or string)', default=None)),
        ('--ckpt', dict(required=True)),
        ('--regularization_prompt', dict(default='./data/regularization_captions.txt')),
    )

    # Empty program name and no automatic -h/--help option.
    cli = argparse.ArgumentParser(prog='', add_help=False)
    for flag, options in option_specs:
        cli.add_argument(flag, type=str, **options)
    return cli.parse_args()
| 211 | + |
| 212 | + |
if __name__ == "__main__":
    # Script entry point: read the CLI options and run the composition.
    cli_args = parse_args()
    # The output folder is named after the category list, with any spaces
    # replaced by underscores.
    outpath = 'optimized_' + cli_args.categories.replace(' ', '_')
    compose(
        cli_args.paths,
        cli_args.categories,
        outpath,
        cli_args.ckpt,
        cli_args.regularization_prompt,
        cli_args.prompts,
        cli_args.save_path,
    )
0 commit comments