
Commit 886b514

Nupur Kumari authored and committed
update
1 parent 4345c28 commit 886b514

5 files changed (+1602 lines added, −15 removed)

README.md

Lines changed: 5 additions & 4 deletions
@@ -7,6 +7,7 @@
 
 **[NEW!]** CustomConcept101 dataset. We release a new dataset of 101 concepts along with their evaluation prompts. For more details please refer [here](customconcept101/README.md).
 
+**[NEW!]** Custom Diffusion with SDXL. The diffusers code is now updated to diffusers==0.21.4.
 
 <br>
 <div class="gif">
@@ -182,11 +183,11 @@ python sample.py --prompt "the <new2> cat sculpture in the style of a <new1> woo
 
 ```
 ## install requirements
-pip install accelerate
+pip install accelerate>=0.24.1
 pip install modelcards
-pip install transformers>=4.25.1
+pip install transformers>=4.31.0
 pip install deepspeed
-pip install diffusers==0.14.0
+pip install diffusers==0.21.4
 accelerate config
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 ```
@@ -217,7 +218,7 @@ accelerate launch src/diffusers_training.py \
 python src/diffusers_sample.py --delta_ckpt logs/cat/delta.bin --ckpt "CompVis/stable-diffusion-v1-4" --prompt "<new1> cat playing with a ball"
 ```
 
-You can also use `--enable_xformers_memory_efficient_attention` and enable `fp16` during `accelerate config` for faster training with lower VRAM requirement.
+You can also use `--enable_xformers_memory_efficient_attention` and enable `fp16` during `accelerate config` for faster training with a lower VRAM requirement. To train with SDXL, use `diffusers_training_sdxl.py` with `MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"`.
 
 **Multi-Concept fine-tuning**
 
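The added note introduces a separate SDXL training entry point. A minimal launch sketch, assuming `diffusers_training_sdxl.py` accepts the same core flags as `src/diffusers_training.py`; the flag names, data directory, output directory, and prompt below are assumptions for illustration, not part of this commit:

```
## hedged sketch: flags assumed to mirror src/diffusers_training.py;
## paths and prompt are placeholders
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
accelerate launch src/diffusers_training_sdxl.py \
          --pretrained_model_name_or_path=$MODEL_NAME \
          --instance_data_dir=./data/cat \
          --output_dir=./logs/cat_sdxl \
          --instance_prompt="photo of a <new1> cat" \
          --modifier_token "<new1>"
```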
src/diffusers_model_pipeline.py

Lines changed: 149 additions & 6 deletions
@@ -212,14 +212,16 @@
 # limitations under the License.
 from typing import Callable, Optional
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from accelerate.logging import get_logger
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.attention import Attention
 from diffusers.utils.import_utils import is_xformers_available
 
 if is_xformers_available():
@@ -277,7 +279,7 @@ def set_use_memory_efficient_attention_xformers(
 class CustomDiffusionAttnProcessor:
     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
@@ -291,8 +293,8 @@ def __call__(
             encoder_hidden_states = hidden_states
         else:
             crossattn = True
-            if attn.cross_attention_norm:
-                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -322,7 +324,7 @@ class CustomDiffusionXFormersAttnProcessor:
     def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
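Both processors follow the diffusers attention-processor protocol (a callable receiving the `Attention` module), so they can be attached with the library's standard hook. A minimal sketch, assuming `diffusers==0.21.4` as pinned in the README; this wiring is illustrative, not taken from this commit:

```python
import torch
from diffusers import UNet2DConditionModel

# run from the repo root so that `src` is importable
from src.diffusers_model_pipeline import CustomDiffusionAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16
)
# set_attn_processor replaces the processor on every attention module; each
# attention forward pass is then routed through __call__(attn, hidden_states, ...).
unet.set_attn_processor(CustomDiffusionAttnProcessor())
```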
@@ -496,3 +498,144 @@ def load_model(self, save_path, compress=False):
                     params.data += st['unet'][name]['u']@st['unet'][name]['v']
                 elif name in st['unet']:
                     params.data.copy_(st['unet'][f'{name}'])
+
+
+class CustomDiffusionXLPipeline(StableDiffusionXLPipeline):
+    r"""
+    Pipeline for custom diffusion model.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1-0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+            watermarker will be used.
+        modifier_token: list of new modifier tokens added or to be added to text_encoder
+        modifier_token_id: list of ids of new modifier tokens added or to be added to text_encoder
+        modifier_token_id_2: list of ids of new modifier tokens added or to be added to text_encoder_2
+    """
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+        modifier_token: list = [],
+        modifier_token_id: list = [],
+        modifier_token_id_2: list = []
+    ):
+        super().__init__(vae,
+                         text_encoder,
+                         text_encoder_2,
+                         tokenizer,
+                         tokenizer_2,
+                         unet,
+                         scheduler,
+                         force_zeros_for_empty_prompt,
+                         add_watermarker,
+                         )
+
+        # change attn class
+        self.modifier_token = modifier_token
+        self.modifier_token_id = modifier_token_id
+        self.modifier_token_id_2 = modifier_token_id_2
+
+    def save_pretrained(self, save_path, freeze_model="crossattn_kv", save_text_encoder=False, all=False):
+        if all:
+            super().save_pretrained(save_path)
+        else:
+            delta_dict = {'unet': {}, 'modifier_token': {}}
+            if self.modifier_token is not None:
+                print(self.modifier_token_id, self.modifier_token)
+                for i in range(len(self.modifier_token_id)):
+                    delta_dict['modifier_token'][self.modifier_token[i]] = []
+                    learned_embeds = self.text_encoder.get_input_embeddings().weight[self.modifier_token_id[i]]
+                    learned_embeds_2 = self.text_encoder_2.get_input_embeddings().weight[self.modifier_token_id_2[i]]
+                    delta_dict['modifier_token'][self.modifier_token[i]].append(learned_embeds.detach().cpu())
+                    delta_dict['modifier_token'][self.modifier_token[i]].append(learned_embeds_2.detach().cpu())
+            if save_text_encoder:
+                delta_dict['text_encoder'] = self.text_encoder.state_dict()
+                delta_dict['text_encoder_2'] = self.text_encoder_2.state_dict()
+            for name, params in self.unet.named_parameters():
+                if freeze_model == "crossattn":
+                    if 'attn2' in name:
+                        delta_dict['unet'][name] = params.cpu().clone()
+                elif freeze_model == "crossattn_kv":
+                    if 'attn2.to_k' in name or 'attn2.to_v' in name:
+                        delta_dict['unet'][name] = params.cpu().clone()
+                else:
+                    raise ValueError(
+                        "freeze_model argument only supports crossattn_kv or crossattn"
+                    )
+            torch.save(delta_dict, save_path)
+
+    def load_model(self, save_path, compress=False):
+        st = torch.load(save_path)
+        if 'text_encoder' in st:
+            self.text_encoder.load_state_dict(st['text_encoder'])
+            self.text_encoder_2.load_state_dict(st['text_encoder_2'])
+        if 'modifier_token' in st:
+            modifier_tokens = list(st['modifier_token'].keys())
+            modifier_token_id = []
+            modifier_token_id_2 = []
+            for modifier_token in modifier_tokens:
+                num_added_tokens = self.tokenizer.add_tokens(modifier_token)
+                num_added_tokens_2 = self.tokenizer_2.add_tokens(modifier_token)
+                if num_added_tokens == 0 or num_added_tokens_2 == 0:
+                    raise ValueError(
+                        f"The tokenizer already contains the token {modifier_token}. Please pass a different"
+                        " `modifier_token` that is not already in the tokenizer."
+                    )
+
+                modifier_token_id.append(self.tokenizer.convert_tokens_to_ids(modifier_token))
+                modifier_token_id_2.append(self.tokenizer_2.convert_tokens_to_ids(modifier_token))
+            # Resize the token embeddings as we are adding new special tokens to the tokenizer
+            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+            self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
+            token_embeds = self.text_encoder.get_input_embeddings().weight.data
+            for i, id_ in enumerate(modifier_token_id):
+                token_embeds[id_] = st['modifier_token'][modifier_tokens[i]][0]
+            token_embeds = self.text_encoder_2.get_input_embeddings().weight.data
+            for i, id_ in enumerate(modifier_token_id_2):
+                token_embeds[id_] = st['modifier_token'][modifier_tokens[i]][1]
+
+        for name, params in self.unet.named_parameters():
+            if 'attn2' in name:
+                if compress and ('to_k' in name or 'to_v' in name):
+                    params.data += st['unet'][name]['u']@st['unet'][name]['v']
+                elif name in st['unet']:
+                    params.data.copy_(st['unet'][f'{name}'])
+

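For reference, a minimal end-to-end sketch of loading and sampling the new pipeline, mirroring the logic of `src/diffusers_sample.py` below; the `delta.bin` path is a hypothetical fine-tuning output, not a file shipped with this commit:

```python
import torch
from src.diffusers_model_pipeline import CustomDiffusionXLPipeline

pipe = CustomDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# load_model re-adds the modifier tokens to both tokenizers, restores both
# text-encoder embeddings, and copies the saved attn2 weights into the UNet.
pipe.load_model("logs/cat_sdxl/delta.bin")  # hypothetical checkpoint path
image = pipe("<new1> cat playing with a ball").images[0]
image.save("sample.png")
```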
src/diffusers_sample.py

Lines changed: 8 additions & 4 deletions
@@ -12,12 +12,15 @@
 from PIL import Image
 
 sys.path.append('./')
-from src.diffusers_model_pipeline import CustomDiffusionPipeline
+from src.diffusers_model_pipeline import CustomDiffusionPipeline, CustomDiffusionXLPipeline
 
 
-def sample(ckpt, delta_ckpt, from_file, prompt, compress, batch_size, freeze_model):
+def sample(ckpt, delta_ckpt, from_file, prompt, compress, batch_size, freeze_model, sdxl=False):
     model_id = ckpt
-    pipe = CustomDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+    if sdxl:
+        pipe = CustomDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+    else:
+        pipe = CustomDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
     pipe.load_model(delta_ckpt, compress)
 
     outdir = os.path.dirname(delta_ckpt)
@@ -63,6 +66,7 @@ def parse_args():
     parser.add_argument('--prompt', help='prompt to generate', default=None,
                         type=str)
     parser.add_argument("--compress", action='store_true')
+    parser.add_argument("--sdxl", action='store_true')
     parser.add_argument("--batch_size", default=5, type=int)
     parser.add_argument('--freeze_model', help='crossattn or crossattn_kv', default='crossattn_kv',
                         type=str)
@@ -71,4 +75,4 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
-    sample(args.ckpt, args.delta_ckpt, args.from_file, args.prompt, args.compress, args.batch_size, args.freeze_model)
+    sample(args.ckpt, args.delta_ckpt, args.from_file, args.prompt, args.compress, args.batch_size, args.freeze_model, args.sdxl)
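With the new flag, SDXL sampling follows the same pattern as the README's example, e.g. (the delta checkpoint path here is a placeholder): `python src/diffusers_sample.py --delta_ckpt logs/cat_sdxl/delta.bin --ckpt "stabilityai/stable-diffusion-xl-base-1.0" --sdxl --prompt "<new1> cat playing with a ball"`.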

src/diffusers_training.py

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@
 from src.diffusers_data_pipeline import CustomDiffusionDataset, PromptDataset, collate_fn
 from src import retrieve
 
-check_min_version("0.14.0")
+check_min_version("0.21.4")
 
 logger = get_logger(__name__)
 
