Commit 3c10f2b

add optional arg and copyright header
1 parent a2d4d9e commit 3c10f2b

5 files changed

+65
-10
lines changed

README.md

Lines changed: 14 additions & 1 deletion
@@ -12,6 +12,10 @@
1212
**Figure 1. Accelerating Diffusion Transformer inference across multiple modalities with 50 DDIM Steps on DiT-XL-256x256, 100 DPM-Solver++(3M) SDE steps for a 10s audio sample (spectrogram shown) on Stable Audio Open, 30 Rectified Flow steps on Open-Sora 480p 2s videos**
1313

1414

15+
# Updates
16+
SmoothCache now supports generating cache schedules using a zero-intrusion external helper. See [run_calibration.py](./examples/run_calibration.py) to find out how it generates a schedule compatible with [HuggingFace Diffusers DiTPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py), without requiring any changes to the Diffusers implementation!
17+
18+
1519
# Introduction
1620
We introduce **SmoothCache**, a straightforward acceleration technique for DiT architecture models that is **training-free, flexible, and performant**. By leveraging layer-wise representation error, our method identifies redundancies in the diffusion process and generates a static caching scheme that reuses output feature maps, reducing the need for computationally expensive operations. This solution works across different models and modalities, can be easily dropped into existing Diffusion Transformer pipelines, can be stacked on different solvers, and requires no additional training or datasets. **SmoothCache** consistently outperforms various solvers designed to accelerate the diffusion process, while matching or surpassing the performance of existing modality-specific caching techniques.
1721

@@ -26,7 +30,7 @@ We introduce **SmoothCache**, a straightforward acceleration technique for DiT a
2630
pip install SmoothCache
2731
```
2832

29-
### Usage
33+
### Usage - Inference
3034

3135
Inspired by [DeepCache](https://raw.githubusercontent.com/horseee/DeepCache), we have implemented drop-in SmoothCache helper classes that apply easily to the [HuggingFace Diffusers DiTPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/dit) and the [original DiT implementation](https://github.com/facebookresearch/DiT).
3236

@@ -138,6 +142,15 @@ cache_helper.disable()
138142
save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1))
139143
```
140144

145+
### Usage - Cache Schedule Generation
146+
See [run_calibration.py](./examples/run_calibration.py), which generates a schedule for the self-attention module ([attn1](https://github.com/huggingface/diffusers/blob/37a5f1b3b69ed284086fb31fb1b49668cba6c365/src/diffusers/models/attention.py#L380))
147+
of the Diffusers [BasicTransformerBlock](https://github.com/huggingface/diffusers/blob/37a5f1b3b69ed284086fb31fb1b49668cba6c365/src/diffusers/models/attention.py#L261C7-L261C28).
148+
149+
Note that only self-attention, and not cross-attention, is enabled in the stock config of Diffusers [DiT module](https://github.com/huggingface/diffusers/blob/37a5f1b3b69ed284086fb31fb1b49668cba6c365/src/diffusers/models/transformers/dit_transformer_2d.py#L72-L73). We leave this behavior
150+
as-is for the purpose of minimal intrusion.
151+
152+
We welcome all contributions aimed at expanding SmoothCache's model coverage and module coverage.
153+
141154
## Visualization
142155

143156
(WIP)
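
Note on the README changes above: the README describes a static caching scheme derived from layer-wise representation error, and the helper's docstring describes `calibration_threshold` as a cutoff L1 error value that enables caching. The toy sketch below illustrates that general idea only. It is not SmoothCache's actual algorithm: the function name `toy_schedule`, the `max_reuse` cap, and the error values are all hypothetical and chosen for illustration.

```python
from typing import List


def toy_schedule(errors: List[float], threshold: float, max_reuse: int = 3) -> List[int]:
    """Return 1 where a step is computed and 0 where a cached output is reused.

    errors[i] is a hypothetical relative L1 error between a module's output
    at step i and at step i - 1 (errors[0] is ignored because the first step
    has nothing to reuse). max_reuse is an assumed cap on consecutive reuse.
    """
    schedule = []
    reused = 0  # number of consecutive steps served from the cache so far
    for step, err in enumerate(errors):
        can_reuse = step > 0 and err < threshold and reused < max_reuse
        if can_reuse:
            schedule.append(0)
            reused += 1
        else:
            schedule.append(1)
            reused = 0
    return schedule


# Made-up error values, one per pipeline timestep.
errors = [0.0, 0.05, 0.04, 0.30, 0.02, 0.01, 0.01, 0.01, 0.01]
print(toy_schedule(errors, threshold=0.1))  # [1, 0, 0, 1, 0, 0, 0, 1, 0]
```

Under these assumptions, a larger threshold marks more steps for reuse, and the output list has one entry per timestep, which matches the 1:1 mapping to pipeline timesteps that the `schedule_length` docstring describes.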

SmoothCache/calibration/calibration_helper.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
1-
# calibration_helper.py
1+
# Copyright 2022 Roblox Corporation
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
215
import json
316
import re
417
import statistics

SmoothCache/calibration/diffuser_calibration_helper.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1-
# SmoothCache/calibration/diffuser_calibration_helper.py
1+
# Copyright 2022 Roblox Corporation
22

3-
from typing import List, Union, Type
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Optional
416
import torch.nn as nn
517
from .calibration_helper import CalibrationHelper
618

@@ -16,7 +28,8 @@ def __init__(
1628
calibration_lookahead: int = 3,
1729
calibration_threshold: float = 0.0,
1830
schedule_length: int = 50,
19-
log_file: str = "calibration_schedule.json"
31+
log_file: str = "calibration_schedule.json",
32+
components_to_wrap: Optional[List[str]] = None
2033
):
2134
"""
2235
Diffuser-specific CalibrationHelper derived from CalibrationHelper.
@@ -25,18 +38,20 @@ def __init__(
2538
model (nn.Module): The model to wrap (e.g., pipe.transformer).
2639
calibration_lookahead (int): Steps to look back for error calculation.
2740
calibration_threshold (float): Cutoff L1 error value to enable caching.
28-
schedule_length (int): Length of the generated schedule, 1:1 mapped to pipeline timesteps
41+
schedule_length (int): Length of the generated schedule, 1:1 mapped to pipeline timesteps.
2942
log_file (str): Path to save the generated schedule JSON.
30-
43+
components_to_wrap (List[str], optional): List of component names to wrap.
44+
Defaults to ['attn1'].
45+
3146
Raises:
3247
ImportError: If diffusers' BasicTransformerBlock is unavailable.
3348
"""
3449
if BasicTransformerBlock is None:
3550
raise ImportError("Diffusers library not installed or BasicTransformerBlock not found.")
3651

3752
block_classes = [BasicTransformerBlock]
38-
components_to_wrap = ['attn1'] # Wrap 'attn1' component
39-
53+
if components_to_wrap is None:
54+
components_to_wrap = ['attn1']
4055
super().__init__(
4156
model=model,
4257
block_classes=block_classes,
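
Note on the hunk above: the new `components_to_wrap` argument defaults to `None` and is resolved to `['attn1']` inside the constructor body. A plausible reason for this pattern, sketched below, is that it avoids a list default shared between instances. `ToyHelper` is a hypothetical stand-in written for this note and is not the real `DiffuserCalibrationHelper`.

```python
from typing import List, Optional


class ToyHelper:
    """Hypothetical stand-in; mirrors only the default-argument pattern."""

    def __init__(self, components_to_wrap: Optional[List[str]] = None):
        # Resolve the default inside the body so every instance gets a fresh list.
        if components_to_wrap is None:
            components_to_wrap = ["attn1"]
        self.components_to_wrap = components_to_wrap


a = ToyHelper()
b = ToyHelper()
a.components_to_wrap.append("attn2")  # mutate one instance's list

print(a.components_to_wrap)  # ['attn1', 'attn2']
print(b.components_to_wrap)  # ['attn1']
print(ToyHelper(["ff"]).components_to_wrap)  # ['ff']
```

Had the signature used `components_to_wrap: List[str] = ['attn1']` directly, the append on `a` would also have been visible through `b`, because Python evaluates default values once at function definition time.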

SmoothCache/diffuser_cache_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Helper Class for Diffusion Transformer Implemented at
1616
https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/dit"""
1717

18+
from typing import List, Optional
1819
from .smooth_cache_helper import SmoothCacheHelper
1920

2021
try:

examples/run_calibration.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
1-
# example_calibration_run.py
1+
# Copyright 2022 Roblox Corporation
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
215
import torch
316
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
417
from SmoothCache import DiffuserCalibrationHelper
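
Note on the example script above: the README calls the calibration helper zero-intrusion, meaning it needs no changes to the Diffusers implementation. One common way to instrument a model without editing its source is to swap named attributes on each block at runtime. The sketch below shows that general technique with plain Python objects. `ToyBlock`, `wrap_components`, and the lambda components are hypothetical and do not reflect how SmoothCache is actually implemented.

```python
from typing import List


class ToyBlock:
    """Hypothetical stand-in for a transformer block with named components."""

    def __init__(self):
        self.attn1 = lambda x: x + 1  # pretend self-attention
        self.ff = lambda x: x * 2     # pretend feed-forward


def wrap_components(blocks: List[ToyBlock], names: List[str], log: List[str]) -> None:
    """Replace each named component with a wrapper that records its calls."""
    for index, block in enumerate(blocks):
        for name in names:
            original = getattr(block, name)

            # Bind the original callable and tag as defaults so each wrapper
            # keeps its own values instead of the loop's final ones.
            def wrapper(x, _orig=original, _tag=f"block{index}.{name}"):
                log.append(_tag)
                return _orig(x)

            setattr(block, name, wrapper)


calls: List[str] = []
blocks = [ToyBlock(), ToyBlock()]
wrap_components(blocks, ["attn1"], calls)  # only 'attn1' is instrumented

out = blocks[1].ff(blocks[0].attn1(3))
print(out)    # 8
print(calls)  # ['block0.attn1']
```

Only the components named in the list are touched, which is the role a list such as `components_to_wrap` would play in a helper of this kind.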
