Skip to content

Commit be2cd9e

Browse files
author
Nupur Kumari
committed
diffusion update
1 parent 98d18e4 commit be2cd9e

File tree

3 files changed

+9
-18
lines changed

3 files changed

+9
-18
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ pip install clip-retrieval tqdm
107107

108108
Our code was developed on the following commit `#21f890f9da3cfbeaba8e2ac3c425ee9e998d5229` of [stable-diffusion](https://github.com/CompVis/stable-diffusion).
109109

110-
For downloading the stable-diffusion model checkpoint, please refer [here](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original).
110+
Download the stable-diffusion model checkpoint
111+
`wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt`
112+
For more details, please see [here](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original).
111113

112114
**Dataset:** We release some of the datasets used in the paper [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip).
113115
Images taken from Unsplash are under the [Unsplash LICENSE](https://unsplash.com/license). The Moongate dataset can be downloaded from [here](https://github.com/odegeasslbc/FastGAN-pytorch).
@@ -132,7 +134,7 @@ python src/get_deltas.py --path logs/<folder-name> --newtoken 1
132134
python sample.py --prompt "<new1> cat playing with a ball" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>
133135
```
134136

135-
Our results in the paper are not based on the [clip-retrieval](https://github.com/rom1504/clip-retrieval) for retrieving real images as the regularization samples. But this also leads to similar results.
137+
The `<pretrained-model-path>` is the path to the pretrained `sd-v1-4.ckpt` model. The results in our paper were not obtained using [clip-retrieval](https://github.com/rom1504/clip-retrieval) to retrieve real images as regularization samples, but this approach leads to similar results.
136138

137139
**Generated images as regularization**
138140
```

src/diffusers_model_pipeline.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -415,16 +415,6 @@ def __init__(
415415
requires_safety_checker)
416416

417417
# change attn class
418-
def change_attn(unet):
419-
for layer in unet.children():
420-
if type(layer) == CrossAttention:
421-
bound_method = set_use_memory_efficient_attention_xformers.__get__(layer, layer.__class__)
422-
setattr(layer, 'set_use_memory_efficient_attention_xformers', bound_method)
423-
else:
424-
change_attn(layer)
425-
426-
change_attn(self.unet)
427-
self.unet.set_attn_processor(CustomDiffusionAttnProcessor())
428418
self.modifier_token = modifier_token
429419
self.modifier_token_id = modifier_token_id
430420

src/diffusers_training.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -645,8 +645,7 @@ def main(args):
645645
class_images_dir.mkdir(parents=True, exist_ok=True)
646646
if args.real_prior:
647647
if accelerator.is_main_process:
648-
name = '_'.join(concept['class_prompt'].split())
649-
if not Path(os.path.join(class_images_dir, name)).exists() or len(list(Path(os.path.join(class_images_dir, name)).iterdir())) < args.num_class_images:
648+
if not Path(os.path.join(class_images_dir, 'images')).exists() or len(list(Path(os.path.join(class_images_dir, 'images')).iterdir())) < args.num_class_images:
650649
retrieve.retrieve(concept['class_prompt'], class_images_dir, args.num_class_images)
651650
concept['class_prompt'] = os.path.join(class_images_dir, 'caption.txt')
652651
concept['class_data_dir'] = os.path.join(class_images_dir, 'images.txt')
@@ -674,7 +673,7 @@ def main(args):
674673
num_new_images = args.num_class_images - cur_class_images
675674
logger.info(f"Number of class images to sample: {num_new_images}.")
676675

677-
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
676+
sample_dataset = PromptDataset(concept['class_prompt'], num_new_images)
678677
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
679678

680679
sample_dataloader = accelerator.prepare(sample_dataloader)
@@ -741,7 +740,6 @@ def main(args):
741740
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
742741
)
743742

744-
# We only train the additional adapter LoRA layers
745743
vae.requires_grad_(False)
746744
if not args.train_text_encoder and args.modifier_token is None:
747745
text_encoder.requires_grad_(False)
@@ -1032,12 +1030,13 @@ def main(args):
10321030
args.pretrained_model_name_or_path,
10331031
unet=accelerator.unwrap_model(unet),
10341032
text_encoder=accelerator.unwrap_model(text_encoder),
1033+
tokenizer=tokenizer,
10351034
revision=args.revision,
10361035
modifier_token=args.modifier_token,
10371036
modifier_token_id=modifier_token_id,
10381037
)
10391038
save_path = os.path.join(args.output_dir, f"delta-{global_step}.bin")
1040-
pipeline.save_pretrained(save_path)
1039+
pipeline.save_pretrained(save_path, freeze_model=args.freeze_model)
10411040

10421041
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
10431042
progress_bar.set_postfix(**logs)
@@ -1061,7 +1060,7 @@ def main(args):
10611060
modifier_token_id=modifier_token_id,
10621061
)
10631062
save_path = os.path.join(args.output_dir, f"delta.bin")
1064-
pipeline.save_pretrained(save_path)
1063+
pipeline.save_pretrained(save_path, freeze_model=args.freeze_model)
10651064
if args.validation_prompt is not None:
10661065
logger.info(
10671066
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"

0 commit comments

Comments
 (0)