Skip to content

Update diffusers_model_pipeline.py #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Custom Diffusion
# Custom Diffusion for SDXL

This repository is forked from Custom Diffusion.

This repository fix some bugs in the original Custom Diffusion to combine with Stable Diffusion XL(SDXL).



# Custom Diffusion (Original)

### [website](https://www.cs.cmu.edu/~custom-diffusion/) | [paper](http://arxiv.org/abs/2212.04488)

Expand Down
2 changes: 2 additions & 0 deletions inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

CUDA_VISIBLE_DEVICES=0 python src/diffusers_sample.py --delta_ckpt logs/wooden_pot/delta.bin --sdxl --ckpt "/data/home/chensh/data/huggingface_model/stable-diffusion-xl-base-1.0" --prompt "<new1> cat playing with a ball"
222 changes: 222 additions & 0 deletions src/diffusers_composenW_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright 2022 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.


import sys
import os
import argparse
import torch
from scipy.linalg import lu_factor, lu_solve

sys.path.append('./')
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from src import diffusers_sample


def gdupdateWexact(K, V, Ktarget1, Vtarget1, W, device='cuda'):
input_ = K
output = V
C = input_.T @ input_
d = []
lu, piv = lu_factor(C.cpu().numpy())
for i in range(Ktarget1.size(0)):
sol = lu_solve((lu, piv), Ktarget1[i].reshape(-1, 1).cpu().numpy())
d.append(torch.from_numpy(sol).to(K.device))

d = torch.cat(d, 1).T

e2 = d @ Ktarget1.T
e1 = (Vtarget1.T - W @ Ktarget1.T)
delta = e1 @ torch.linalg.inv(e2)

Wnew = W + delta @ d
lambda_split1 = Vtarget1.size(0)

input_ = torch.cat([Ktarget1.T, K.T], dim=1)
output = torch.cat([Vtarget1, V], dim=0)

loss = torch.norm((Wnew @ input_).T - output, 2, dim=1)
print(loss[:lambda_split1].mean().item(), loss[lambda_split1:].mean().item())

return Wnew


def compose(paths, category, outpath, pretrained_model_path, regularization_prompt, prompts, save_path, device='cuda'):
model_id = pretrained_model_path
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

layers_modified = []
for name, param in pipe.unet.named_parameters():
if 'attn2.to_k' in name or 'attn2.to_v' in name:
layers_modified.append(name)

tokenizer = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2

def get_text_embedding(prompts):
with torch.no_grad():
uc = []
for text in prompts:
tokens = tokenizer(text,
truncation=True,
max_length=tokenizer.model_max_length,
return_length=True,
return_overflowing_tokens=False,
padding="do_not_pad",
).input_ids
tokens_2 = tokenizer_2(text,
truncation=True,
max_length=tokenizer.model_max_length,
return_length=True,
return_overflowing_tokens=False,
padding="do_not_pad",
).input_ids
# In sdxl, the text embedding should be 2048 (the concat of 768 and 1280).
if 'photo of a' in text[:15]:
print(text)
text_embedding_1 = pipe.text_encoder(torch.cuda.LongTensor(tokens).reshape(1, -1))[0][:,
4:].reshape(-1, 768)
text_embedding_2 = pipe.text_encoder_2(torch.cuda.LongTensor(tokens).reshape(1, -1))[1][:,
4:].reshape(-1, 1280)
uc.append(torch.concat([text_embedding_1, text_embedding_2], dim=-1))
else:
text_embedding_1 = pipe.text_encoder(torch.cuda.LongTensor(tokens).reshape(1, -1))[0][:,
1:].reshape(-1, 768)
text_embedding_2 = pipe.text_encoder_2(torch.cuda.LongTensor(tokens).reshape(1, -1))[1][:,
1:].reshape(-1, 1280)
uc.append(torch.concat([text_embedding_1, text_embedding_2], dim=-1))
return torch.cat(uc, 0).float()

embeds = {}
count = 1
model2_sts = []
modifier_tokens = []
modifier_token_ids = []
categories = []
for path1, cat1 in zip(paths.split('+'), category.split('+')):
model2_st = torch.load(path1)
if 'modifier_token' in model2_st:
# composition of models with individual concept only
key = list(model2_st['modifier_token'].keys())[0]
_ = tokenizer.add_tokens(f'<new{count}>')
_ = tokenizer_2.add_tokens(f'<new{count}>')
modifier_token_ids.append(tokenizer.convert_tokens_to_ids(f'<new{count}>'))
modifier_tokens.append(True)
embeds[f'<new{count}>'] = model2_st['modifier_token'][key]
else:
modifier_tokens.append(False)

model2_sts.append(model2_st['unet'])
categories.append(cat1)
count += 1

pipe.text_encoder.resize_token_embeddings(len(tokenizer))
pipe.text_encoder_2.resize_token_embeddings(len(tokenizer_2))

token_embeds = pipe.text_encoder.get_input_embeddings().weight.data
for (x, y) in zip(modifier_token_ids, list(embeds.keys())):
token_embeds[x] = embeds[y][0]
print(x, y, "added embeddings")

pipe.text_encoder_2.resize_token_embeddings(len(tokenizer))
token_embeds_2 = pipe.text_encoder_2.get_input_embeddings().weight.data
for (x, y) in zip(modifier_token_ids, list(embeds.keys())):
token_embeds_2[x] = embeds[y][1]
print(x, y, "added embeddings")

f = open(regularization_prompt, 'r')
prompt = [x.strip() for x in f.readlines()][:200]
uc = get_text_embedding(prompt)

uc_targets = []
from collections import defaultdict
uc_values = defaultdict(list)
for composing_model_count in range(len(model2_sts)):
category = categories[composing_model_count]
if modifier_tokens[composing_model_count]:
string1 = f'<new{composing_model_count + 1}> {category}'
else:
string1 = f'{category}'
if 'art' in string1:
prompt = [string1] + [f"painting in the style of {string1}"]
else:
prompt = [string1] + [f"photo of a {string1}"]
uc_targets.append(get_text_embedding(prompt))
for each in layers_modified:
uc_values[each].append((model2_sts[composing_model_count][each].to(device) @ uc_targets[-1].T).T)

uc_targets = torch.cat(uc_targets, 0)

removal_indices = []
for i in range(uc_targets.size(0)):
for j in range(i + 1, uc_targets.size(0)):
if (uc_targets[i] - uc_targets[j]).abs().mean() == 0:
removal_indices.append(j)

removal_indices = list(set(removal_indices))
uc_targets = torch.stack([uc_targets[i] for i in range(uc_targets.size(0)) if i not in removal_indices], 0)
for each in layers_modified:
uc_values[each] = torch.cat(uc_values[each], 0)
uc_values[each] = torch.stack(
[uc_values[each][i] for i in range(uc_values[each].size(0)) if i not in removal_indices], 0)
print(uc_values[each].size(), each)

print("target size:", uc_targets.size())

new_weights = {'unet': {}}
for each in layers_modified:
W = pipe.unet.state_dict()[each].float()
values = (W @ uc.T).T # W(C_reg)
input_target = uc_targets
output_target = uc_values[each]

Wnew = gdupdateWexact(uc[:values.shape[0]],
values,
input_target,
output_target,
W.clone(),
)

new_weights['unet'][each] = Wnew
print(Wnew.size())

new_weights['modifier_token'] = embeds
os.makedirs(f'{save_path}/{outpath}', exist_ok=True)
torch.save(new_weights, f'{save_path}/{outpath}/delta.bin')

if prompts is not None:
if os.path.exists(prompts):
diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', prompts, prompt=None, compress=False,
freeze_model='crossattn_kv', batch_size=1)
else:
diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', from_file=None, prompt=prompts,
compress=False, freeze_model='crossattn_kv', batch_size=1)


def parse_args():
parser = argparse.ArgumentParser('', add_help=False)
parser.add_argument('--paths', help='+ separated list of checkpoints', required=True,
type=str)
parser.add_argument('--save_path', help='folder name to save optimized weights', default='optimized_logs',
type=str)
parser.add_argument('--categories', help='+ separated list of categories of the models', required=True,
type=str)
parser.add_argument('--prompts', help='prompts for composition model (can be a file or string)', default=None,
type=str)
parser.add_argument('--ckpt', required=True,
type=str)
parser.add_argument('--regularization_prompt', default='./data/regularization_captions.txt',
type=str)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
paths = args.paths
categories = args.categories
if ' ' in categories:
temp = categories.replace(' ', '_')
else:
temp = categories
outpath = '_'.join(['optimized', temp])
compose(paths, categories, outpath, args.ckpt, args.regularization_prompt, args.prompts, args.save_path)
21 changes: 18 additions & 3 deletions src/diffusers_model_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@
# limitations under the License.
from typing import Callable, Optional
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, \
CLIPVisionModelWithProjection, CLIPImageProcessor
from accelerate.logging import get_logger

from diffusers.models import AutoencoderKL, UNet2DConditionModel
Expand Down Expand Up @@ -551,6 +552,8 @@ def __init__(
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
modifier_token: list = [],
Expand All @@ -564,9 +567,21 @@ def __init__(
tokenizer_2,
unet,
scheduler,
force_zeros_for_empty_prompt,
add_watermarker,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
add_watermarker=add_watermarker,
)
# super().__init__(vae,
# text_encoder,
# text_encoder_2,
# tokenizer,
# tokenizer_2,
# unet,
# scheduler,
# force_zeros_for_empty_prompt,
# add_watermarker,
# )

# change attn class
self.modifier_token = modifier_token
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline

sys.path.append('./')
from src.diffusers_model_pipeline import CustomDiffusionPipeline, CustomDiffusionXLPipeline
Expand All @@ -18,11 +19,12 @@
def sample(ckpt, delta_ckpt, from_file, prompt, compress, batch_size, freeze_model, sdxl=False):
model_id = ckpt
if sdxl:
pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
else:
pipe = CustomDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_model(delta_ckpt, compress)

outdir = os.path.dirname(delta_ckpt)
generator = torch.Generator(device='cuda').manual_seed(42)

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@
from src.diffusers_data_pipeline import CustomDiffusionDataset, PromptDataset, collate_fn
from src import retrieve

check_min_version("0.21.4")
# check_min_version("0.21.4")

logger = get_logger(__name__)

Expand Down
9 changes: 4 additions & 5 deletions src/diffusers_training_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# Apache License
# Version 2.0, January 2004
# http://www.apache.org/licenses/

# TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

# 1. Definitions.
Expand Down Expand Up @@ -220,6 +219,7 @@
import math
import os
import shutil
import json
import warnings
from pathlib import Path

Expand Down Expand Up @@ -259,7 +259,7 @@
from src import retrieve

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.4")
# check_min_version("0.21.4")

logger = get_logger(__name__)

Expand All @@ -269,13 +269,13 @@ def create_custom_diffusion(unet, freeze_model):
if freeze_model == 'crossattn':
if 'attn2' in name:
params.requires_grad = True
print(name)
# print(name)
else:
params.requires_grad = False
elif freeze_model == "crossattn_kv":
if 'attn2.to_k' in name or 'attn2.to_v' in name:
params.requires_grad = True
print(name)
# print(name)
else:
params.requires_grad = False
else:
Expand Down Expand Up @@ -830,7 +830,6 @@ def main(args):
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images

for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
Expand Down
Loading