diff --git a/README.md b/README.md
index ddbb642..477c72d 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,12 @@
-# Custom Diffusion
+# Custom Diffusion for SDXL
+
+This repository is a fork of [Custom Diffusion](https://github.com/adobe-research/custom-diffusion).
+
+It fixes several bugs in the original Custom Diffusion code so that it can be used with Stable Diffusion XL (SDXL).
+
+
+
+# Custom Diffusion (Original)
 
 ### [website](https://www.cs.cmu.edu/~custom-diffusion/) | [paper](http://arxiv.org/abs/2212.04488)
diff --git a/inference.sh b/inference.sh
new file mode 100644
index 0000000..10c6a25
--- /dev/null
+++ b/inference.sh
@@ -0,0 +1,2 @@
+
+CUDA_VISIBLE_DEVICES=0 python src/diffusers_sample.py --delta_ckpt logs/wooden_pot/delta.bin --sdxl --ckpt "/data/home/chensh/data/huggingface_model/stable-diffusion-xl-base-1.0" --prompt "<new1> cat playing with a ball"
\ No newline at end of file
diff --git a/src/diffusers_composenW_sdxl.py b/src/diffusers_composenW_sdxl.py
new file mode 100644
index 0000000..5a2db5a
--- /dev/null
+++ b/src/diffusers_composenW_sdxl.py
@@ -0,0 +1,222 @@
+# Copyright 2022 Adobe Research. All rights reserved.
+# To view a copy of the license, visit LICENSE.md.
+
+
+import sys
+import os
+import argparse
+from collections import defaultdict
+
+import torch
+from scipy.linalg import lu_factor, lu_solve
+
+sys.path.append('./')
+from diffusers import StableDiffusionXLPipeline
+from src import diffusers_sample
+
+
+def gdupdateWexact(K, V, Ktarget1, Vtarget1, W, device='cuda'):
+    # Closed-form constrained least squares: update W so that the target text
+    # features (Ktarget1) map to the target values (Vtarget1), while keeping the
+    # change minimal as measured on the regularization features K.
+    input_ = K
+    output = V
+    C = input_.T @ input_
+    d = []
+    lu, piv = lu_factor(C.cpu().numpy())
+    for i in range(Ktarget1.size(0)):
+        sol = lu_solve((lu, piv), Ktarget1[i].reshape(-1, 1).cpu().numpy())
+        d.append(torch.from_numpy(sol).to(K.device))
+
+    d = torch.cat(d, 1).T
+
+    e2 = d @ Ktarget1.T
+    e1 = (Vtarget1.T - W @ Ktarget1.T)
+    delta = e1 @ torch.linalg.inv(e2)
+
+    Wnew = W + delta @ d
+    lambda_split1 = Vtarget1.size(0)
+
+    # report the residual on the target features and the regularization features
+    input_ = torch.cat([Ktarget1.T, K.T], dim=1)
+    output = torch.cat([Vtarget1, V], dim=0)
+
+    loss = torch.norm((Wnew @ input_).T - output, 2, dim=1)
+    print(loss[:lambda_split1].mean().item(), loss[lambda_split1:].mean().item())
+
+    return Wnew
+
+
+def compose(paths, category, outpath, pretrained_model_path, regularization_prompt, prompts, save_path, device='cuda'):
+    model_id = pretrained_model_path
+    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
+
+    layers_modified = []
+    for name, param in pipe.unet.named_parameters():
+        if 'attn2.to_k' in name or 'attn2.to_v' in name:
+            layers_modified.append(name)
+
+    tokenizer = pipe.tokenizer
+    tokenizer_2 = pipe.tokenizer_2
+
+    def get_text_embedding(prompts):
+        with torch.no_grad():
+            uc = []
+            for text in prompts:
+                tokens = tokenizer(text,
+                                   truncation=True,
+                                   max_length=tokenizer.model_max_length,
+                                   return_length=True,
+                                   return_overflowing_tokens=False,
+                                   padding="do_not_pad",
+                                   ).input_ids
+                tokens_2 = tokenizer_2(text,
+                                       truncation=True,
+                                       max_length=tokenizer_2.model_max_length,
+                                       return_length=True,
+                                       return_overflowing_tokens=False,
+                                       padding="do_not_pad",
+                                       ).input_ids
+                # In SDXL the cross-attention text embedding is 2048-dimensional: the
+                # concatenation of the 768-dim hidden states from text_encoder and the
+                # 1280-dim hidden states from text_encoder_2.
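+                # For captions starting with "photo of a", drop the first four tokens
+                # (<|startoftext|>, "photo", "of", "a") so that only the remaining
+                # concept tokens are encoded; otherwise drop only <|startoftext|>.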
+                if 'photo of a' in text[:15]:
+                    print(text)
+                    text_embedding_1 = pipe.text_encoder(
+                        torch.cuda.LongTensor(tokens).reshape(1, -1))[0][:, 4:].reshape(-1, 768)
+                    text_embedding_2 = pipe.text_encoder_2(
+                        torch.cuda.LongTensor(tokens_2).reshape(1, -1))[1][:, 4:].reshape(-1, 1280)
+                    uc.append(torch.cat([text_embedding_1, text_embedding_2], dim=-1))
+                else:
+                    text_embedding_1 = pipe.text_encoder(
+                        torch.cuda.LongTensor(tokens).reshape(1, -1))[0][:, 1:].reshape(-1, 768)
+                    text_embedding_2 = pipe.text_encoder_2(
+                        torch.cuda.LongTensor(tokens_2).reshape(1, -1))[1][:, 1:].reshape(-1, 1280)
+                    uc.append(torch.cat([text_embedding_1, text_embedding_2], dim=-1))
+            return torch.cat(uc, 0).float()
+
+    embeds = {}
+    count = 1
+    model2_sts = []
+    modifier_tokens = []
+    modifier_token_ids = []
+    categories = []
+    for path1, cat1 in zip(paths.split('+'), category.split('+')):
+        model2_st = torch.load(path1)
+        if 'modifier_token' in model2_st:
+            # composing models that were each fine-tuned on an individual concept
+            key = list(model2_st['modifier_token'].keys())[0]
+            _ = tokenizer.add_tokens(f'<new{count}>')
+            _ = tokenizer_2.add_tokens(f'<new{count}>')
+            modifier_token_ids.append(tokenizer.convert_tokens_to_ids(f'<new{count}>'))
+            modifier_tokens.append(True)
+            embeds[f'<new{count}>'] = model2_st['modifier_token'][key]
+        else:
+            modifier_tokens.append(False)
+
+        model2_sts.append(model2_st['unet'])
+        categories.append(cat1)
+        count += 1
+
+    pipe.text_encoder.resize_token_embeddings(len(tokenizer))
+    pipe.text_encoder_2.resize_token_embeddings(len(tokenizer_2))
+
+    token_embeds = pipe.text_encoder.get_input_embeddings().weight.data
+    for (x, y) in zip(modifier_token_ids, list(embeds.keys())):
+        token_embeds[x] = embeds[y][0]
+        print(x, y, "added embeddings")
+
+    token_embeds_2 = pipe.text_encoder_2.get_input_embeddings().weight.data
+    for (x, y) in zip(modifier_token_ids, list(embeds.keys())):
+        token_embeds_2[x] = embeds[y][1]
+        print(x, y, "added embeddings")
+
+    with open(regularization_prompt, 'r') as f:
+        prompt = [x.strip() for x in f.readlines()][:200]
+    uc = get_text_embedding(prompt)
+
+    uc_targets = []
+    uc_values = defaultdict(list)
+    for composing_model_count in range(len(model2_sts)):
+        category = categories[composing_model_count]
+        if modifier_tokens[composing_model_count]:
+            string1 = f'<new{composing_model_count + 1}> {category}'
+        else:
+            string1 = f'{category}'
+        if 'art' in string1:
+            prompt = [string1] + [f"painting in the style of {string1}"]
+        else:
+            prompt = [string1] + [f"photo of a {string1}"]
+        uc_targets.append(get_text_embedding(prompt))
+        for each in layers_modified:
+            uc_values[each].append((model2_sts[composing_model_count][each].to(device) @ uc_targets[-1].T).T)
+
+    uc_targets = torch.cat(uc_targets, 0)
+
+    # drop duplicate target features (e.g. identical prompts across concepts)
+    removal_indices = []
+    for i in range(uc_targets.size(0)):
+        for j in range(i + 1, uc_targets.size(0)):
+            if (uc_targets[i] - uc_targets[j]).abs().mean() == 0:
+                removal_indices.append(j)
+
+    removal_indices = list(set(removal_indices))
+    uc_targets = torch.stack([uc_targets[i] for i in range(uc_targets.size(0)) if i not in removal_indices], 0)
+    for each in layers_modified:
+        uc_values[each] = torch.cat(uc_values[each], 0)
+        uc_values[each] = torch.stack(
+            [uc_values[each][i] for i in range(uc_values[each].size(0)) if i not in removal_indices], 0)
+        print(uc_values[each].size(), each)
+
+    print("target size:", uc_targets.size())
+
+    new_weights = {'unet': {}}
+    for each in layers_modified:
+        W = pipe.unet.state_dict()[each].float()
+        values = (W @ uc.T).T  # outputs of the pretrained W on the regularization features
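+        # Find Wnew that maps each concept's target text features (input_target)
+        # to its fine-tuned outputs (output_target) while staying close to W on
+        # the regularization features uc.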
+        input_target = uc_targets
+        output_target = uc_values[each]
+
+        Wnew = gdupdateWexact(uc[:values.shape[0]],
+                              values,
+                              input_target,
+                              output_target,
+                              W.clone(),
+                              )
+
+        new_weights['unet'][each] = Wnew
+        print(Wnew.size())
+
+    new_weights['modifier_token'] = embeds
+    os.makedirs(f'{save_path}/{outpath}', exist_ok=True)
+    torch.save(new_weights, f'{save_path}/{outpath}/delta.bin')
+
+    if prompts is not None:
+        if os.path.exists(prompts):
+            diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', prompts, prompt=None, compress=False,
+                                    freeze_model='crossattn_kv', batch_size=1, sdxl=True)
+        else:
+            diffusers_sample.sample(model_id, f'{save_path}/{outpath}/delta.bin', from_file=None, prompt=prompts,
+                                    compress=False, freeze_model='crossattn_kv', batch_size=1, sdxl=True)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('', add_help=False)
+    parser.add_argument('--paths', help='+ separated list of checkpoints', required=True,
+                        type=str)
+    parser.add_argument('--save_path', help='folder name to save optimized weights', default='optimized_logs',
+                        type=str)
+    parser.add_argument('--categories', help='+ separated list of categories of the models', required=True,
+                        type=str)
+    parser.add_argument('--prompts', help='prompts for composition model (can be a file or string)', default=None,
+                        type=str)
+    parser.add_argument('--ckpt', help='path to the pretrained model', required=True,
+                        type=str)
+    parser.add_argument('--regularization_prompt', default='./data/regularization_captions.txt',
+                        type=str)
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    paths = args.paths
+    categories = args.categories
+    if ' ' in categories:
+        temp = categories.replace(' ', '_')
+    else:
+        temp = categories
+    outpath = '_'.join(['optimized', temp])
+    compose(paths, categories, outpath, args.ckpt, args.regularization_prompt, args.prompts, args.save_path)
diff --git a/src/diffusers_model_pipeline.py b/src/diffusers_model_pipeline.py
index 4a43e07..61baa6c 100644
--- a/src/diffusers_model_pipeline.py
+++ b/src/diffusers_model_pipeline.py
@@ -212,7 +212,8 @@ # limitations under the License.
 from typing import Callable, Optional
 
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, \
+    CLIPVisionModelWithProjection, CLIPImageProcessor
 from accelerate.logging import get_logger
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -551,6 +552,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
         modifier_token: list = [],
@@ -564,9 +567,11 @@ def __init__(
             tokenizer_2,
             unet,
             scheduler,
-            force_zeros_for_empty_prompt,
-            add_watermarker,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
+            add_watermarker=add_watermarker,
         )
 
         # change attn class
         self.modifier_token = modifier_token
diff --git a/src/diffusers_sample.py b/src/diffusers_sample.py
index f3c2931..26b5d18 100644
--- a/src/diffusers_sample.py
+++ b/src/diffusers_sample.py
@@ -18,11 +18,11 @@
 def sample(ckpt, delta_ckpt, from_file, prompt, compress, batch_size, freeze_model, sdxl=False):
     model_id = ckpt
     if sdxl:
-        pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+        pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+        pipe = pipe.to("cuda")
     else:
         pipe = CustomDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
     pipe.load_model(delta_ckpt, compress)
-
     outdir = os.path.dirname(delta_ckpt)
 
     generator = torch.Generator(device='cuda').manual_seed(42)
diff --git a/src/diffusers_training.py b/src/diffusers_training.py
index 6b207b3..565d50b 100644
--- a/src/diffusers_training.py
+++ b/src/diffusers_training.py
@@ -246,7 +246,7 @@
 from src.diffusers_data_pipeline import CustomDiffusionDataset, PromptDataset, collate_fn
 from src import retrieve
 
-check_min_version("0.21.4")
+# check_min_version("0.21.4")
 
 logger = get_logger(__name__)
diff --git a/src/diffusers_training_sdxl.py b/src/diffusers_training_sdxl.py
index ca601fa..b7654f6 100644
--- a/src/diffusers_training_sdxl.py
+++ b/src/diffusers_training_sdxl.py
@@ -12,7 +12,6 @@
 #     Apache License
 #     Version 2.0, January 2004
 #     http://www.apache.org/licenses/
-
 #  TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 
 # 1. Definitions.
@@ -220,6 +219,7 @@
 import math
 import os
 import shutil
+import json
 import warnings
 from pathlib import Path
@@ -259,7 +259,7 @@
 from src import retrieve
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.4")
+# check_min_version("0.21.4")
 
 logger = get_logger(__name__)
@@ -269,13 +269,13 @@ def create_custom_diffusion(unet, freeze_model):
         if freeze_model == 'crossattn':
             if 'attn2' in name:
                 params.requires_grad = True
-                print(name)
+                # print(name)
             else:
                 params.requires_grad = False
         elif freeze_model == "crossattn_kv":
             if 'attn2.to_k' in name or 'attn2.to_v' in name:
                 params.requires_grad = True
-                print(name)
+                # print(name)
             else:
                 params.requires_grad = False
         else:
@@ -830,7 +830,6 @@
         sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
     ):
         images = pipeline(example["prompt"]).images
-
         for i, image in enumerate(images):
             hash_image = hashlib.sha1(image.tobytes()).hexdigest()
             image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000..af9e5ff
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,37 @@
+## launch training script (2 GPUs recommended; if training on 1 GPU, increase --max_train_steps accordingly)
+export CUDA_VISIBLE_DEVICES=0
+export MODEL_NAME="/data/home/chensh/data/huggingface_model/stable-diffusion-xl-base-1.0"
+
+export INSTANCE_DIR="./data/cat"
+export INSTANCE_PROMPT="photo of a <new1> cat"
+export CLASS_DIR="./sample_reg/samples_cat/"
+export CLASS_PROMPT="cat"
+export OUTPUT_DIR="./logs/cat"
+export modifier_token="<new1>"
+
+#export INSTANCE_DIR="./data/wooden_pot"
+#export INSTANCE_PROMPT="photo of a <new1> wooden pot"
+#export CLASS_DIR="./data/prior_woodenpot/"
+#export CLASS_PROMPT="wooden pot"
+#export OUTPUT_DIR="./logs/wooden_pot"
+#export modifier_token="<new1>"
+
+accelerate launch src/diffusers_training_sdxl.py \
+    --pretrained_model_name_or_path=$MODEL_NAME \
+    --instance_data_dir=$INSTANCE_DIR \
+    --class_data_dir=$CLASS_DIR \
+    --output_dir=$OUTPUT_DIR \
+    --with_prior_preservation --prior_loss_weight=1.0 \
+    --instance_prompt="${INSTANCE_PROMPT}" \
+    --class_prompt="${CLASS_PROMPT}" \
+    --resolution=1024 \
+    --train_batch_size=1 \
+    --learning_rate=1e-5 \
+    --lr_warmup_steps=0 \
+    --max_train_steps=1000 \
+    --num_class_images=200 \
+    --scale_lr --hflip \
+    --modifier_token="${modifier_token}"
+
+### sample
+#python src/diffusers_sample.py --delta_ckpt logs/cat/delta.bin --sdxl --ckpt $MODEL_NAME --prompt "<new1> cat playing with a ball"
\ No newline at end of file