Skip to content

Commit 3d81fbd

Browse files
SHSH
authored andcommitted
Update the code(now can be used with single node single gpu, no mix precision)
1 parent bada230 commit 3d81fbd

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

src/diffusers_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@
246246
from src.diffusers_data_pipeline import CustomDiffusionDataset, PromptDataset, collate_fn
247247
from src import retrieve
248248

249-
check_min_version("0.21.4")
249+
# check_min_version("0.21.4")
250250

251251
logger = get_logger(__name__)
252252

src/diffusers_training_sdxl.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# Apache License
1313
# Version 2.0, January 2004
1414
# http://www.apache.org/licenses/
15-
1615
# TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1716

1817
# 1. Definitions.
@@ -220,6 +219,7 @@
220219
import math
221220
import os
222221
import shutil
222+
import json
223223
import warnings
224224
from pathlib import Path
225225

@@ -259,7 +259,7 @@
259259
from src import retrieve
260260

261261
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
262-
check_min_version("0.21.4")
262+
# check_min_version("0.21.4")
263263

264264
logger = get_logger(__name__)
265265

@@ -269,13 +269,13 @@ def create_custom_diffusion(unet, freeze_model):
269269
if freeze_model == 'crossattn':
270270
if 'attn2' in name:
271271
params.requires_grad = True
272-
print(name)
272+
# print(name)
273273
else:
274274
params.requires_grad = False
275275
elif freeze_model == "crossattn_kv":
276276
if 'attn2.to_k' in name or 'attn2.to_v' in name:
277277
params.requires_grad = True
278-
print(name)
278+
# print(name)
279279
else:
280280
params.requires_grad = False
281281
else:
@@ -830,7 +830,6 @@ def main(args):
830830
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
831831
):
832832
images = pipeline(example["prompt"]).images
833-
834833
for i, image in enumerate(images):
835834
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
836835
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"

train.sh

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
## launch training script (2 GPUs recommended, increase --max_train_steps to 500 if 1 GPU)
2+
export CUDA_VISIBLE_DEVICES=0
3+
export MODEL_NAME="/data/home/chensh/data/huggingface_model/stable-diffusion-xl-base-1.0"
4+
5+
export INSTANCE_DIR="./data/cat"
6+
export INSTANCE_PROMPT="photo of a <new1> cat"
7+
export CLASS_DIR="./sample_reg/samples_cat/"
8+
export CLASS_PROMPT="cat"
9+
export OUTPUT_DIR="./logs/cat"
10+
export modifier_token="<new1>"
11+
12+
#export INSTANCE_DIR="./data/wooden_pot"
13+
#export INSTANCE_PROMPT="photo of a <new2> wooden pot"
14+
#export CLASS_DIR="./data/prior_woodenpot/"
15+
#export CLASS_PROMPT="wooden pot"
16+
#export OUTPUT_DIR="./logs/wooden_pot"
17+
#export modifier_token="<new2>"
18+
19+
accelerate launch src/diffusers_training_sdxl.py \
20+
--pretrained_model_name_or_path=$MODEL_NAME \
21+
--instance_data_dir=$INSTANCE_DIR \
22+
--class_data_dir=$CLASS_DIR \
23+
--output_dir=$OUTPUT_DIR \
24+
--with_prior_preservation --prior_loss_weight=1.0 \
25+
--instance_prompt="${INSTANCE_PROMPT}" \
26+
--class_prompt="${CLASS_PROMPT}" \
27+
--resolution=1024 \
28+
--train_batch_size=1 \
29+
--learning_rate=1e-5 \
30+
--lr_warmup_steps=0 \
31+
--max_train_steps=1000 \
32+
--num_class_images=200 \
33+
--scale_lr --hflip \
34+
--modifier_token="${modifier_token}"
35+
36+
### sample
37+
#python src/diffusers_sample.py --delta_ckpt logs/cat/delta.bin --ckpt "CompVis/stable-diffusion-v1-4" --prompt "<new1> cat playing with a ball"

0 commit comments

Comments
 (0)