# limitations under the License.
from typing import Callable, Optional
import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from accelerate.logging import get_logger

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.attention import Attention
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
@@ -277,7 +279,7 @@ def set_use_memory_efficient_attention_xformers(
class CustomDiffusionAttnProcessor:
    def __call__(
        self,
-        attn: CrossAttention,
+        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
@@ -291,8 +293,8 @@ def __call__(
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
-            if attn.cross_attention_norm:
-                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -322,7 +324,7 @@ class CustomDiffusionXFormersAttnProcessor:
    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
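
Note: the CrossAttention -> Attention rename in the hunks above tracks a diffusers refactor that moved the class out of diffusers.models.cross_attention. A version-tolerant import, sketched under the assumption that only these two historical paths exist, could look like:

# Compatibility sketch (an assumption, not part of this commit): pick up the
# attention class under whichever name the installed diffusers provides.
try:
    from diffusers.models.attention import Attention  # newer diffusers releases
except ImportError:
    from diffusers.models.cross_attention import CrossAttention as Attention  # older releases
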
@@ -496,3 +498,144 @@ def load_model(self, save_path, compress=False):
                    params.data += st['unet'][name]['u'] @ st['unet'][name]['v']
                elif name in st['unet']:
                    params.data.copy_(st['unet'][f'{name}'])
+
+
+class CustomDiffusionXLPipeline(StableDiffusionXLPipeline):
+    r"""
+    Pipeline for custom diffusion model.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+            Whether the negative prompt embeddings should always be forced to zero. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1.0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it defaults to `True` when the package is installed; otherwise no
+            watermarker is used.
+        modifier_token (`list`):
+            List of new modifier tokens added or to be added to `text_encoder`.
+        modifier_token_id (`list`):
+            List of ids of the new modifier tokens added or to be added to `text_encoder`.
+        modifier_token_id_2 (`list`):
+            List of ids of the new modifier tokens added or to be added to `text_encoder_2`.
+    """
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+        modifier_token: list = [],
+        modifier_token_id: list = [],
+        modifier_token_id_2: list = [],
+    ):
+        super().__init__(
+            vae,
+            text_encoder,
+            text_encoder_2,
+            tokenizer,
+            tokenizer_2,
+            unet,
+            scheduler,
+            force_zeros_for_empty_prompt,
+            add_watermarker,
+        )
+
+        # track the modifier tokens and their ids in both text encoders
+        self.modifier_token = modifier_token
+        self.modifier_token_id = modifier_token_id
+        self.modifier_token_id_2 = modifier_token_id_2
+
+    def save_pretrained(self, save_path, freeze_model="crossattn_kv", save_text_encoder=False, all=False):
+        if all:
+            super().save_pretrained(save_path)
+        else:
+            delta_dict = {'unet': {}, 'modifier_token': {}}
+            if self.modifier_token is not None:
+                print(self.modifier_token_id, self.modifier_token)
+                for i in range(len(self.modifier_token_id)):
+                    delta_dict['modifier_token'][self.modifier_token[i]] = []
+                    learned_embeds = self.text_encoder.get_input_embeddings().weight[self.modifier_token_id[i]]
+                    learned_embeds_2 = self.text_encoder_2.get_input_embeddings().weight[self.modifier_token_id_2[i]]
+                    delta_dict['modifier_token'][self.modifier_token[i]].append(learned_embeds.detach().cpu())
+                    delta_dict['modifier_token'][self.modifier_token[i]].append(learned_embeds_2.detach().cpu())
+            if save_text_encoder:
+                delta_dict['text_encoder'] = self.text_encoder.state_dict()
+                delta_dict['text_encoder_2'] = self.text_encoder_2.state_dict()
+            for name, params in self.unet.named_parameters():
+                if freeze_model == "crossattn":
+                    if 'attn2' in name:
+                        delta_dict['unet'][name] = params.cpu().clone()
+                elif freeze_model == "crossattn_kv":
+                    if 'attn2.to_k' in name or 'attn2.to_v' in name:
+                        delta_dict['unet'][name] = params.cpu().clone()
+                else:
+                    raise ValueError(
+                        "freeze_model argument only supports crossattn_kv or crossattn"
+                    )
+            torch.save(delta_dict, save_path)
+
+    def load_model(self, save_path, compress=False):
+        st = torch.load(save_path)
+        if 'text_encoder' in st:
+            self.text_encoder.load_state_dict(st['text_encoder'])
+            self.text_encoder_2.load_state_dict(st['text_encoder_2'])
+        if 'modifier_token' in st:
+            modifier_tokens = list(st['modifier_token'].keys())
+            modifier_token_id = []
+            modifier_token_id_2 = []
+            for modifier_token in modifier_tokens:
+                num_added_tokens = self.tokenizer.add_tokens(modifier_token)
+                num_added_tokens_2 = self.tokenizer_2.add_tokens(modifier_token)
+                if num_added_tokens == 0 or num_added_tokens_2 == 0:
+                    raise ValueError(
+                        f"The tokenizer already contains the token {modifier_token}. Please pass a different"
+                        " `modifier_token` that is not already in the tokenizer."
+                    )
+
+                modifier_token_id.append(self.tokenizer.convert_tokens_to_ids(modifier_token))
+                modifier_token_id_2.append(self.tokenizer_2.convert_tokens_to_ids(modifier_token))
+            # Resize the token embeddings as we are adding new special tokens to the tokenizer
+            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+            self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
+            token_embeds = self.text_encoder.get_input_embeddings().weight.data
+            for i, id_ in enumerate(modifier_token_id):
+                token_embeds[id_] = st['modifier_token'][modifier_tokens[i]][0]
+            token_embeds = self.text_encoder_2.get_input_embeddings().weight.data
+            for i, id_ in enumerate(modifier_token_id_2):
+                token_embeds[id_] = st['modifier_token'][modifier_tokens[i]][1]
+
+        for name, params in self.unet.named_parameters():
+            if 'attn2' in name:
+                if compress and ('to_k' in name or 'to_v' in name):
+                    params.data += st['unet'][name]['u'] @ st['unet'][name]['v']
+                elif name in st['unet']:
+                    params.data.copy_(st['unet'][f'{name}'])
+
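
For context, a minimal usage sketch of the new pipeline (the module name, model id, file names, and the <new1> token are illustrative assumptions, not part of this commit):

import torch
from diffusers_model_pipeline import CustomDiffusionXLPipeline  # hypothetical module name for this file

# Load SDXL base weights into the custom pipeline class.
pipe = CustomDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Restore a previously saved delta: modifier-token embeddings plus
# cross-attention weights (see load_model above).
pipe.load_model("delta.bin")

# Generate with the learned modifier token.
image = pipe("a photo of a <new1> cat", num_inference_steps=50).images[0]
image.save("out.png")

# Counterpart operation after fine-tuning: save only the cross-attention
# K/V deltas and the new token embeddings.
pipe.save_pretrained("delta.bin", freeze_model="crossattn_kv")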