Skip to content

Commit 5d680b3

Browse files
committed
update readme
1 parent bc24274 commit 5d680b3

File tree

1 file changed

+155
-2
lines changed

1 file changed

+155
-2
lines changed

README.md

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,158 @@
1-
# SmoothCache
2-
Implementation of SmoothCache, a project aimed at speeding-up Diffusion Transformer (DiT) based GenAI models with error-guided caching.
1+
<!-- <div align="center">
2+
<img src="https://github.com/Roblox/SmoothCache/blob/main/assets/TeaserFigureFlat.png" width="100%" ></img>
3+
<br>
4+
<em>
5+
(Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos)
6+
</em>
7+
</div>
8+
<br> -->
9+
10+
![Figure 1. Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos](assets/TeaserFigureFlat.png)
11+
12+
**Figure 1. Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos**
13+
14+
15+
# Introduction
16+
We introduce **SmoothCache**, a straightforward acceleration technique for DiT architecture models, that's both **training-free, flexible and performant**. By leveraging layer-wise representation error, our method identifies redundancies in the diffusion process, generates a static caching scheme to reuse output featuremaps and therefore reduces the need for computationally expensive operations. This solution works across different models and modalities, can be easily dropped into existing Diffusion Transformer pipelines, can be stacked on different solvers, and requires no additional training or datasets. **SmoothCache** consistently outperforms various solvers designed to accelerate the diffusion process, while matching or surpassing the performance of existing modality-specific caching techniques.
17+
18+
19+
## Quick Start
20+
21+
### Install
22+
```bash
23+
pip install SmoothCache
24+
```
25+
26+
### Usage
27+
28+
We have implemented drop-in SmoothCache helper classes that easily applies to [Huggingface Diffuser DiTPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/dit), and [original DiT implementations](https://github.com/facebookresearch/DiT).
29+
30+
Generally, only 3 additional lines needs to be added to the original sampler scripts:
31+
```python
32+
from SmoothCache import <DESIREDCacheHelper>
33+
cache_helper = DiffuserCacheHelper(<MODEL_HANDLER>, schedule=schedule)
34+
cache_helper.enable()
35+
# Original sampler code.
36+
cache_helper.eisable()
37+
```
38+
39+
Usage example with Huggingface Diffuser DiTPipeline:
40+
```python
41+
import json
42+
import torch
43+
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
44+
45+
# Import SmoothCacheHelper
46+
from SmoothCache import DiffuserCacheHelper
47+
48+
# Load the DiT pipeline and scheduler
49+
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
50+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
51+
pipe = pipe.to("cuda")
52+
53+
# Initialize the DiffuserCacheHelper with the model
54+
with open("smoothcache_schedules/50-N-3-threshold-0.35.json", "r") as f:
55+
schedule = json.load(f)
56+
cache_helper = DiffuserCacheHelper(pipe.transformer, schedule=schedule)
57+
58+
# Enable the caching helper
59+
cache_helper.enable()
60+
# Prepare the input
61+
words = ["Labrador retriever"]
62+
class_ids = pipe.get_label_ids(words)
63+
64+
# Generate images with the pipeline
65+
generator = torch.manual_seed(33)
66+
image = pipe(class_labels=class_ids, num_inference_steps=50, generator=generator).images[0]
67+
68+
# Restore the original forward method and disable the helper
69+
# disable() should be paired up with enable()
70+
cache_helper.disable()
71+
```
72+
73+
Usage example with original DiT implementation
74+
```python
75+
import torch
76+
77+
torch.backends.cuda.matmul.allow_tf32 = True
78+
torch.backends.cudnn.allow_tf32 = True
79+
from torchvision.utils import save_image
80+
from diffusion import create_diffusion
81+
from diffusers.models import AutoencoderKL
82+
from download import find_model
83+
from models import DiT_models
84+
import argparse
85+
from SmoothCache import DiTCacheHelper # Import DiTCacheHelper
86+
import json
87+
88+
# Setup PyTorch:
89+
torch.manual_seed(args.seed)
90+
torch.set_grad_enabled(False)
91+
device = "cuda" if torch.cuda.is_available() else "cpu"
92+
93+
if args.ckpt is None:
94+
assert (
95+
args.model == "DiT-XL/2"
96+
), "Only DiT-XL/2 models are available for auto-download."
97+
assert args.image_size in [256, 512]
98+
assert args.num_classes == 1000
99+
100+
# Load model:
101+
latent_size = args.image_size // 8
102+
model = DiT_models[args.model](
103+
input_size=latent_size, num_classes=args.num_classes
104+
).to(device)
105+
ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
106+
state_dict = find_model(ckpt_path)
107+
model.load_state_dict(state_dict)
108+
model.eval() # important!
109+
with open("smoothcache_schedules/50-N-3-threshold-0.35.json", "r") as f:
110+
schedule = json.load(f)
111+
cache_helper = DiTCacheHelper(model, schedule=schedule)
112+
113+
# number of timesteps should be consistent with provided schedules
114+
diffusion = create_diffusion(str(len(schedule[cache_helper.components_to_wrap[0]])))
115+
116+
# Enable the caching helper
117+
cache_helper.enable()
118+
119+
# Sample images:
120+
samples = diffusion.p_sample_loop(
121+
model.forward_with_cfg,
122+
z.shape,
123+
z,
124+
clip_denoised=False,
125+
model_kwargs=model_kwargs,
126+
progress=True,
127+
device=device,
128+
)
129+
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
130+
samples = vae.decode(samples / 0.18215).sample
131+
132+
# Disable the caching helper after sampling
133+
cache_helper.disable()
134+
# Save and display images:
135+
save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))
136+
```
137+
138+
## Visualization
139+
140+
(WIP)
141+
142+
143+
144+
## Quantitative Results
145+
146+
### Image Generation with DiT-XL/2-256x256
147+
![Table 1. Results For DiT-XL-256x256 on using DDIM Sampling.
148+
Note that L2C is not training free](assets/table1.png)
149+
150+
### Video Generation with OpenSora
151+
![Table 2. Results For OpenSora on Rectified Flow](assets/table2.png)
152+
153+
### Audio Generation with Stable Audio Open
154+
![Table 3. Results For Stable Audio Open on DPMSolver++(3M) SDE on 3 datasets](assets/table3.png)
155+
3156

4157
# License
5158
SmoothCache is licensed under the [Apache-2.0](LICENSE) license.

0 commit comments

Comments
 (0)