Commit f1a6c9f

committed
update readme
1 parent bc24274 commit f1a6c9f

15 files changed

+189
-12
lines changed

README.md

Lines changed: 156 additions & 2 deletions
@@ -1,5 +1,159 @@
1-
# SmoothCache
2-
Implementation of SmoothCache, a project aimed at speeding-up Diffusion Transformer (DiT) based GenAI models with error-guided caching.
1+
<!-- <div align="center">
2+
<img src="https://github.com/Roblox/SmoothCache/blob/main/assets/TeaserFigureFlat.png" width="100%" ></img>
3+
<br>
4+
<em>
5+
(Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos)
6+
</em>
7+
</div>
8+
<br> -->
9+
10+
![Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos](assets/TeaserFigureFlat.png)
11+
12+
**Figure 1. Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos**
13+
14+
15+
# Introduction
16+
We introduce **SmoothCache**, a straightforward acceleration technique for DiT architecture models that is **training-free, flexible, and performant**. By leveraging layer-wise representation error, our method identifies redundancies in the diffusion process and generates a static caching scheme that reuses output feature maps, reducing the need for computationally expensive operations. This solution works across different models and modalities, can be easily dropped into existing Diffusion Transformer pipelines, can be stacked on different solvers, and requires no additional training or datasets. **SmoothCache** consistently outperforms various solvers designed to accelerate the diffusion process, while matching or surpassing the performance of existing modality-specific caching techniques.
17+
18+
![Illustration of SmoothCache. When the layer representation loss obtained from the calibration pass is below some threshold α, the corresponding layer is cached and used in place of the same computation on a future timestep. The figure on the left shows how the layer representation error impacts whether certain layers are eligible for caching. The error of the attention (attn) layer is higher in earlier timesteps, so our schedule caches the later timesteps accordingly. The figure on the right shows the application of the caching schedule to the DiT-XL architecture. The output of the attn layer at time t − 1 is cached and re-used in place of computing FFN t − 2, since the corresponding error is below α. This cached output is introduced in the model using the properties of the residual connection.](assets/SmoothCache2.png)
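The thresholding idea in the caption above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: `schedule_from_errors` is a hypothetical helper, the error values are made up, and the rule (cache whenever the error is below α, always compute the first step) is a simplification rather than the project's calibration code.

```python
from typing import Dict, List


def schedule_from_errors(errors: Dict[str, List[float]], alpha: float) -> Dict[str, List[int]]:
    """Turn per-timestep layer errors into a 0/1 caching schedule.

    1 = compute the layer, 0 = reuse the cached output.
    Hypothetical helper: the real calibration procedure may differ.
    """
    schedule = {}
    for component, errs in errors.items():
        # Cache (0) only when the representation error is below the threshold.
        steps = [1 if err >= alpha else 0 for err in errs]
        if steps:
            steps[0] = 1  # nothing has been cached yet on the first step
        schedule[component] = steps
    return schedule


# Made-up error values, for illustration only.
calibration_errors = {
    "attn": [0.90, 0.60, 0.20, 0.10],
    "mlp": [0.80, 0.30, 0.40, 0.05],
}
example_schedule = schedule_from_errors(calibration_errors, alpha=0.35)
print(example_schedule)
```

The output has the same shape as the JSON schedules shipped with the project: one list of 0/1 flags per component, one flag per timestep.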
19+
20+
## Quick Start
21+
22+
### Install
23+
```bash
24+
pip install SmoothCache
25+
```
26+
27+
### Usage
28+
29+
We have implemented drop-in SmoothCache helper classes that apply easily to the [Huggingface Diffuser DiTPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/dit) and the [original DiT implementation](https://github.com/facebookresearch/DiT).
30+
31+
Generally, only 3 additional lines need to be added to the original sampler scripts:
32+
```python
33+
from SmoothCache import <DESIREDCacheHelper>
34+
cache_helper = <DESIREDCacheHelper>(<MODEL_HANDLER>, schedule=schedule)
35+
cache_helper.enable()
36+
# Original sampler code.
37+
cache_helper.disable()
38+
```
39+
40+
#### Usage example with Huggingface Diffuser DiTPipeline
41+
```python
42+
import json
43+
import torch
44+
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
45+
46+
# Import DiffuserCacheHelper
47+
from SmoothCache import DiffuserCacheHelper
48+
49+
# Load the DiT pipeline and scheduler
50+
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
51+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
52+
pipe = pipe.to("cuda")
53+
54+
# Load the caching schedule, then initialize the DiffuserCacheHelper with the model
55+
with open("smoothcache_schedules/50-N-3-threshold-0.35.json", "r") as f:
56+
    schedule = json.load(f)
57+
cache_helper = DiffuserCacheHelper(pipe.transformer, schedule=schedule)
58+
59+
# Enable the caching helper
60+
cache_helper.enable()
61+
# Prepare the input
62+
words = ["Labrador retriever"]
63+
class_ids = pipe.get_label_ids(words)
64+
65+
# Generate images with the pipeline
66+
generator = torch.manual_seed(33)
67+
image = pipe(class_labels=class_ids, num_inference_steps=50, generator=generator).images[0]
68+
69+
# Restore the original forward method and disable the helper
70+
# disable() should be paired up with enable()
71+
cache_helper.disable()
72+
```
73+
74+
#### Usage example with original DiT implementation
75+
```python
76+
import torch
77+
78+
torch.backends.cuda.matmul.allow_tf32 = True
79+
torch.backends.cudnn.allow_tf32 = True
80+
from torchvision.utils import save_image
81+
from diffusion import create_diffusion
82+
from diffusers.models import AutoencoderKL
83+
from download import find_model
84+
from models import DiT_models
85+
import argparse
86+
from SmoothCache import DiTCacheHelper # Import DiTCacheHelper
87+
import json
88+
89+
# Setup PyTorch (`args` comes from the argparse setup in DiT's sample.py, omitted here):
90+
torch.manual_seed(args.seed)
91+
torch.set_grad_enabled(False)
92+
device = "cuda" if torch.cuda.is_available() else "cpu"
93+
94+
if args.ckpt is None:
95+
    assert (
96+
        args.model == "DiT-XL/2"
97+
    ), "Only DiT-XL/2 models are available for auto-download."
98+
    assert args.image_size in [256, 512]
99+
    assert args.num_classes == 1000
100+
101+
# Load model:
102+
latent_size = args.image_size // 8
103+
model = DiT_models[args.model](
104+
    input_size=latent_size, num_classes=args.num_classes
105+
).to(device)
106+
ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
107+
state_dict = find_model(ckpt_path)
108+
model.load_state_dict(state_dict)
109+
model.eval() # important!
110+
with open("smoothcache_schedules/50-N-3-threshold-0.35.json", "r") as f:
111+
    schedule = json.load(f)
112+
cache_helper = DiTCacheHelper(model, schedule=schedule)
113+
114+
# The number of timesteps must match the length of the provided schedules
115+
diffusion = create_diffusion(str(len(schedule[cache_helper.components_to_wrap[0]])))
116+
117+
# Enable the caching helper
118+
cache_helper.enable()
119+
120+
# Sample images (z, model_kwargs, and vae are set up as in DiT's sample.py, omitted here):
121+
samples = diffusion.p_sample_loop(
122+
    model.forward_with_cfg,
123+
    z.shape,
124+
    z,
125+
    clip_denoised=False,
126+
    model_kwargs=model_kwargs,
127+
    progress=True,
128+
    device=device,
129+
)
130+
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
131+
samples = vae.decode(samples / 0.18215).sample
132+
133+
# Disable the caching helper after sampling
134+
cache_helper.disable()
135+
# Save and display images:
136+
save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))
137+
```
138+
139+
## Visualization
140+
141+
(WIP)
142+
143+
144+
145+
## Evaluation
146+
147+
### Image Generation with DiT-XL/2-256x256
148+
![Table 1. Results for DiT-XL-256x256 using DDIM sampling.
149+
Note that L2C is not training-free](assets/table1.png)
150+
151+
### Video Generation with OpenSora
152+
![Table 2. Results For OpenSora on Rectified Flow](assets/table2.png)
153+
154+
### Audio Generation with Stable Audio Open
155+
![Table 3. Results For Stable Audio Open on DPMSolver++(3M) SDE on 3 datasets](assets/table3.png)
156+
3157

4158
# License
5159
SmoothCache is licensed under the [Apache-2.0](LICENSE) license.

SmoothCache/smooth_cache_helper.py

Lines changed: 3 additions & 10 deletions
@@ -63,7 +63,7 @@ def reset_state(self):
6363
    def is_skip_step(self, full_name):
6464
        # Extract component name and block index from full_name
6565
        names = full_name.split('.')
66-
        component_name = names[-1] # e.g., 'attn' or 'mlp'
66+
        component_name = names[-1] # e.g., 'attn' or 'mlp', etc.
6767
        block_index = names[-2] # e.g., '0', '1', '2', etc.
6868
        schedule_key_with_index = f"{component_name}-{block_index}"
6969
        schedule_key_without_index = component_name
@@ -76,34 +76,27 @@ def is_skip_step(self, full_name):
7676
            # Use the general schedule for the component
7777
            schedule_key = schedule_key_without_index
7878
        else:
79-
            # If neither key is in the schedule, do not skip
8079
            return False
8180

82-
        # Get the current timestep for this module
83-
        current_step = self.current_steps.get(full_name, 0) - 1 # Adjust index to start from 0
84-
85-
        # Retrieve the schedule list for the selected key
81+
        # Get the current timestep for this module, adjusting the index to start from 0
82+
        current_step = self.current_steps.get(full_name, 0) - 1
8683
        schedule_list = self.schedule[schedule_key]
8784

8885
        if current_step < 0 or current_step >= len(schedule_list):
89-
            # If current_step is out of bounds, do not skip
9086
            return False
9187

9288
        # 1 means run normally, 0 means use cached result (skip computation)
9389
        skip = schedule_list[current_step] == 0
9490

9591
        return skip
9692

97-
98-
9993
    def wrap_components(self):
10094
        # Wrap specified components within each block class
10195
        for block_name, block in self.model.named_modules():
10296
            if any(isinstance(block, cls) for cls in self.block_classes):
10397
                self.wrap_block_components(block, block_name)
10498

10599
    def wrap_block_components(self, block, block_name):
106-
        #TODO: verify block exists
107100
        if len(self.components_to_wrap) > 0:
108101
            for comp_name in self.components_to_wrap:
109102
                if hasattr(block, comp_name):
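
The lookup in `is_skip_step` can be exercised without a model by rewriting it as a standalone function. The sketch below follows the lines visible in this diff; the branch that prefers a per-block key such as `attn-3` over the general `attn` key falls in the elided part of the hunk, so its exact form is an assumption, and `calls_so_far` is a hypothetical stand-in for `self.current_steps.get(full_name, 0)`.

```python
from typing import Dict, List


def is_skip_step(schedule: Dict[str, List[int]], full_name: str, calls_so_far: int) -> bool:
    """Return True when the cached output should be reused for this call."""
    names = full_name.split(".")
    component_name, block_index = names[-1], names[-2]
    key_with_index = f"{component_name}-{block_index}"

    if key_with_index in schedule:
        key = key_with_index  # assumed: a per-block entry takes precedence
    elif component_name in schedule:
        key = component_name  # fall back to the general schedule
    else:
        return False  # component is not scheduled, always compute

    current_step = calls_so_far - 1  # adjust index to start from 0
    steps = schedule[key]
    if current_step < 0 or current_step >= len(steps):
        return False  # out of bounds, always compute

    # 1 means run normally, 0 means use the cached result
    return steps[current_step] == 0


demo_schedule = {"attn": [1, 0, 1], "mlp-3": [1, 1, 0]}
print(is_skip_step(demo_schedule, "blocks.3.attn", 2))  # second call of attn in block 3
```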

assets/SmoothCache2.png

Lines changed: 3 additions & 0 deletions

assets/TeaserFigureFlat.png

Lines changed: 3 additions & 0 deletions

assets/table1.png

Lines changed: 3 additions & 0 deletions

assets/table2.png

Lines changed: 3 additions & 0 deletions

assets/table3.png

Lines changed: 3 additions & 0 deletions
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
1+
{"attn": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1],
2+
"mlp": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]}
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
1+
{"attn": [1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
2+
"mlp": [1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1]}
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
1+
{"mlp": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1],
2+
"attn": [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]}
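
Schedule files like the ones above map each component name to a list of 0/1 flags, one per timestep, where 1 means run normally and 0 means use the cached result. A quick sanity check before handing a schedule to a helper might look like the sketch below. `summarize_schedule` is a hypothetical function, the toy schedule is made up, and the requirement that every component list has the same length is an assumption drawn from the README note that the number of timesteps should be consistent with the provided schedules.

```python
from typing import Dict, List


def summarize_schedule(schedule: Dict[str, List[int]]) -> Dict[str, float]:
    """Validate a 0/1 schedule and return the fraction of skipped steps per component."""
    lengths = {len(steps) for steps in schedule.values()}
    if len(lengths) > 1:
        # Assumed constraint: all components cover the same number of timesteps.
        raise ValueError(f"components disagree on step count: {sorted(lengths)}")

    summary = {}
    for component, steps in schedule.items():
        if any(flag not in (0, 1) for flag in steps):
            raise ValueError(f"{component}: entries must be 0 or 1")
        # 0 marks a step where the cached output is reused.
        summary[component] = steps.count(0) / len(steps) if steps else 0.0
    return summary


# Toy schedule for illustration, not one of the shipped files.
toy_schedule = {"attn": [1, 0, 1, 0, 1, 1], "mlp": [1, 0, 0, 1, 1, 1]}
skipped = summarize_schedule(toy_schedule)
print(skipped)
```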
