Code Appendix of Rethinking Diffusion Posterior Sampling: From Conditional Score Estimator to Maximizing a Posterior
- Python==3.10, pytorch==2.1.0, diffusers==0.30.0, transformers==4.37.2
- We provide 5 example images in ./example_imgs
- Network access is required: the scripts automatically download models from Hugging Face
- Clone the two repositories used by the blurring operators
git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse
git clone https://github.com/LeviBorodenko/motionblur motionblur
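Before running anything, it can help to confirm that the installed packages match the version pins listed at the top. The following is a minimal sketch, not part of the released code: the helper name find_mismatches is hypothetical, and it assumes pytorch is installed under its pip name torch.

```python
import sys
from importlib import metadata

# Pins copied from the requirements line; "torch" is the pip name for pytorch.
PINS = {"torch": "2.1.0", "diffusers": "0.30.0", "transformers": "4.37.2"}


def find_mismatches(pins, get_version=metadata.version):
    """Return {package: installed version or None} for every pin that is not met."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build suffixes such as "2.1.0+cu121".
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 10):
        print(f"Python 3.10 expected, found {sys.version_info.major}.{sys.version_info.minor}")
    for name, installed in find_mismatches(PINS).items():
        print(f"{name}: expected {PINS[name]}, found {installed or 'not installed'}")
```

The script prints nothing when every pin is satisfied.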
- Run DPS for SRx8
python -u main_sd2.py --data ./example_imgs --out ./example_output --mode dps --step 500 --operator srx8 --scale 4.8
- Run DPS for Gaussian deblur
python -u main_sd2.py --data ./example_imgs --out ./example_output --mode dps --step 500 --operator gdb --scale 0.6
- Run DPS for Nonlinear deblur
python -u main_sd2.py --data ./example_imgs --out ./example_output --mode dps --step 500 --operator nlb --scale 0.6
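The three DPS commands above differ only in --operator and --scale, so they can be driven from one small table. This is a convenience sketch rather than a script shipped with the code: build_dps_command and run_all are hypothetical names, and the flags and values are copied from the commands above. All runs write to the same --out directory, exactly as in the commands above, so results from different operators may need to be moved between runs.

```python
import subprocess

# Operator -> guidance scale, copied from the three DPS commands above.
DPS_SCALES = {"srx8": "4.8", "gdb": "0.6", "nlb": "0.6"}


def build_dps_command(operator, scale, data="./example_imgs",
                      out="./example_output", steps=500):
    """Assemble the argument list for one DPS run of main_sd2.py."""
    return [
        "python", "-u", "main_sd2.py",
        "--data", data,
        "--out", out,
        "--mode", "dps",
        "--step", str(steps),
        "--operator", operator,
        "--scale", scale,
    ]


def run_all(dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    for operator, scale in DPS_SCALES.items():
        cmd = build_dps_command(operator, scale)
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the loop at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Calling run_all(dry_run=False) from the repository root executes the three runs in sequence.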
- Run PSLD for SRx8
- Open pipe.py, go to def run_stsl, and tune the value of eta (currently eta=0.1):
norm = dist + 0.1 * inpaint_error
python -u main_sd2.py --data ./example_imgs --out ./example_output --mode psld --step 500 --operator srx8 --scale 4.8
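Tuning eta currently means editing the literal 0.1 in the line shown above. One way to make this easier is to lift the weight into a named parameter. The sketch below only illustrates that refactoring: combined_norm is a hypothetical helper, and dist and inpaint_error stand for whatever values pipe.py computes under those names.

```python
def combined_norm(dist, inpaint_error, eta=0.1):
    """Weighted sum used as the guidance objective: dist + eta * inpaint_error.

    Works for plain floats and for tensors that support + and *.
    """
    return dist + eta * inpaint_error


# With the default eta the result matches the hard-coded line in pipe.py.
example = combined_norm(2.0, 10.0)          # 2.0 + 0.1 * 10.0
stronger = combined_norm(2.0, 10.0, eta=0.5)  # 2.0 + 0.5 * 10.0
```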
- Run DSG for SRx8
python -u main_sd2.py --data ./example_imgs --out ./example_output --mode dsg --step 500 --operator srx8 --scale 0.08
- Run DMAP for SRx8
- Open pipe.py, go to def run_stsl, and tune the value of K (currently K=2)
norm = dist + 0.1 * inpaint_error
python -u main_sd2.py --data ./example_imgs --out ./example_output --mode dmap --step 250 --operator srx8 --scale 9.6
- Train a controlnet for SRx8
accelerate launch train_controlnet.py \
  --pretrained_model_name_or_path "stabilityai/stable-diffusion-2-base" \
  --operator srx8 \
  --output_dir ./model_out \
  --train_data ./train_data \
  --val_data ./val_data \
  --conditioning_image_column=conditioning_image \
  --image_column=image \
  --caption_column=text \
  --resolution=512 \
  --learning_rate=1e-5 \
  --train_batch_size=32 \
  --gradient_accumulation_steps=2 \
  --num_train_epochs=10000 \
  --tracker_project_name="controlnet" \
  --enable_xformers_memory_efficient_attention \
  --checkpointing_steps=500 \
  --validation_steps=500
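When adapting the training command to different hardware, the quantity to keep fixed is usually the effective batch size. The sketch below assumes train_controlnet.py follows the common accelerate convention in which --train_batch_size is counted per process; that is an assumption about the script, not something stated here, and effective_batch_size is a hypothetical helper.

```python
def effective_batch_size(per_process_batch, grad_accum_steps, num_processes=1):
    """Samples contributing to one optimizer step, under the per-process convention."""
    return per_process_batch * grad_accum_steps * num_processes


# Values from the command above: --train_batch_size=32, --gradient_accumulation_steps=2.
single_gpu = effective_batch_size(32, 2)                  # one process
four_gpus = effective_batch_size(32, 2, num_processes=4)  # four processes
```

Under this assumption, halving --train_batch_size to fit in memory while doubling --gradient_accumulation_steps leaves the effective batch size unchanged.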
- Run DPS for SRx8, with a pre-trained controlnet
python -u main_sd2cn.py --data ./example_imgs --out ./example_output --cnmodel /controlnet --step 500 --scale 4.8 --mode dps