|
212 | 212 | # limitations under the License.
|
213 | 213 | from typing import Callable, Optional
|
214 | 214 | import torch
|
215 |
| -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection |
| 215 | +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, \ |
| 216 | + CLIPVisionModelWithProjection, CLIPImageProcessor |
216 | 217 | from accelerate.logging import get_logger
|
217 | 218 |
|
218 | 219 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
@@ -551,22 +552,36 @@ def __init__(
|
551 | 552 | tokenizer_2: CLIPTokenizer,
|
552 | 553 | unet: UNet2DConditionModel,
|
553 | 554 | scheduler: KarrasDiffusionSchedulers,
|
| 555 | + image_encoder: CLIPVisionModelWithProjection = None, |
| 556 | + feature_extractor: CLIPImageProcessor = None, |
554 | 557 | force_zeros_for_empty_prompt: bool = True,
|
555 | 558 | add_watermarker: Optional[bool] = None,
|
556 | 559 | modifier_token: list = [],
|
557 | 560 | modifier_token_id: list = [],
|
558 | 561 | modifier_token_id_2: list = []
|
559 | 562 | ):
|
560 |
| - super().__init__(vae=vae, |
561 |
| - text_encoder=text_encoder, |
562 |
| - text_encoder_2=text_encoder_2, |
563 |
| - tokenizer=tokenizer, |
564 |
| - tokenizer_2=tokenizer_2, |
565 |
| - unet=unet, |
566 |
| - scheduler=scheduler, |
| 563 | + super().__init__(vae, |
| 564 | + text_encoder, |
| 565 | + text_encoder_2, |
| 566 | + tokenizer, |
| 567 | + tokenizer_2, |
| 568 | + unet, |
| 569 | + scheduler, |
| 570 | + image_encoder=image_encoder, |
| 571 | + feature_extractor=feature_extractor, |
567 | 572 | force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
|
568 | 573 | add_watermarker=add_watermarker,
|
569 | 574 | )
|
| 575 | + # super().__init__(vae, |
| 576 | + # text_encoder, |
| 577 | + # text_encoder_2, |
| 578 | + # tokenizer, |
| 579 | + # tokenizer_2, |
| 580 | + # unet, |
| 581 | + # scheduler, |
| 582 | + # force_zeros_for_empty_prompt, |
| 583 | + # add_watermarker, |
| 584 | + # ) |
570 | 585 |
|
571 | 586 | # change attn class
|
572 | 587 | self.modifier_token = modifier_token
|
|
0 commit comments