
Commit c656d25

feat: Qwen-image/FLUX 4bits w/ nunchaku + cache (#285)

* feat: add quantize examples

1 parent adf46ae commit c656d25

File tree: 11 files changed (+442 −36 lines)

README.md
Lines changed: 2 additions & 1 deletion

@@ -147,14 +147,15 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
 - **[🎉Easy New Model Integration](./docs/User_Guide.md#automatic-block-adapter)**: Features like **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, and **Patch Functor** make it highly functional and flexible. For example, we achieved 🎉 Day 1 support for [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with a 1.7x speedup w/o precision loss—even before it was available in the Diffusers library.
 - **[🎉State-of-the-Art Performance](./bench/)**: Compared with algorithms including Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer, and FoCa, cache-dit achieved **SOTA** performance w/ a **7.4x↑🎉** speedup on ClipScore!
 - **[🎉Support for 4/8-Steps Distilled Models](./bench/)**: Surprisingly, cache-dit's **DBCache** works for extremely few-step distilled models—something many other methods fail to do.
-- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, model CPU offload, sequential CPU offload, group offloading, etc.
+- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, model CPU offload, sequential CPU offload, group offloading, quantization (**[torchao](./examples/quantize/)**, **[🔥nunchaku](./examples/quantize/)**), etc.
 - **[🎉Hybrid Cache Acceleration](./docs/User_Guide.md#taylorseer-calibrator)**: Now supports hybrid **Block-wise Cache + Calibrator** schemes (e.g., DBCache or DBPrune + TaylorSeerCalibrator). DBCache or DBPrune acts as the **Indicator** that decides *when* to cache, while the Calibrator decides *how* to cache. More mainstream cache acceleration algorithms (e.g., FoCa) will be supported in the future, along with additional benchmarks—stay tuned for updates!
 - **[🤗Diffusers Ecosystem Integration](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**: 🔥**cache-dit** has joined the Diffusers community ecosystem as the **first** DiT-specific cache acceleration framework! Check out the documentation here: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>

 ![](https://github.com/vipshop/cache-dit/raw/main/assets/clip-score-bench.png)

 ## 🔥Important News

+- 2025.10.15: 🎉cache-dit now supports [**🔥nunchaku**](https://github.com/nunchaku-tech/nunchaku): Qwen-Image/FLUX.1 [4-bit examples](./examples/quantize/)
 - 2025.10.13: 🎉cache-dit achieved **SOTA** performance w/ a **7.4x↑🎉** speedup on ClipScore!
 - 2025.10.10: 🔥[**Qwen-Image-ControlNet-Inpainting**](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting) **2.3x↑🎉** speedup! Check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_controlnet_inpaint.py).
 - 2025.09.26: 🔥[**Qwen-Image-Edit-Plus(2509)**](https://github.com/QwenLM/Qwen-Image) **2.1x↑🎉** speedup! Please check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_edit_plus.py).
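The hybrid Block-wise Cache + Calibrator scheme and the torch.compile compatibility described in the bullets above reduce to a single enable_cache call plus a standard compile. The following is a minimal sketch distilled from the nunchaku examples added later in this commit; the block counts, threshold, and Taylor order are illustrative values, not tuned defaults.

import torch
from diffusers import FluxPipeline
import cache_dit
from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# DBCache is the Indicator (decides *when* to cache); the TaylorSeer
# calibrator decides *how* to cache. All numeric values are illustrative.
cache_dit.enable_cache(
    pipe,
    cache_config=DBCacheConfig(
        Fn_compute_blocks=8,  # illustrative, not a tuned default
        Bn_compute_blocks=8,  # illustrative, not a tuned default
        residual_diff_threshold=0.08,  # illustrative, not a tuned default
    ),
    calibrator_config=TaylorSeerCalibratorConfig(taylorseer_order=1),
)

# cache-dit composes with other optimizations; the examples in this commit
# apply torch.compile after enabling the cache:
cache_dit.set_compile_configs()
pipe.transformer = torch.compile(pipe.transformer)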

examples/pipeline/run_hunyuan_image_2.1.py
Lines changed: 0 additions & 5 deletions

@@ -1,6 +1,5 @@
 import os
 import sys
-import gc

 sys.path.append("..")
 sys.path.append(os.environ.get("HYIMAGE_PKG_DIR", "."))
@@ -67,10 +66,6 @@
         pipe.text_encoder,
         quant_type=args.quantize_type,
     )
-    time.sleep(0.5)
-    torch.cuda.empty_cache()
-    gc.collect()
-

 pipe.to("cuda")
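Design note: the deleted time.sleep/empty_cache/gc.collect sequence was a manual VRAM scrub between quantizing the text encoder and moving the pipeline to CUDA; this commit drops it, presumably because the explicit collection proved unnecessary.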

examples/quantize/run_flux_ao.py
Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+import os
+import sys
+
+sys.path.append("..")
+
+import time
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from utils import get_args, strify, cachify
+import cache_dit
+
+
+args = get_args()
+print(args)
+
+
+pipe: FluxPipeline = FluxPipeline.from_pretrained(
+    os.environ.get(
+        "FLUX_DIR",
+        "black-forest-labs/FLUX.1-dev",
+    ),
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+
+if args.cache:
+    cachify(args, pipe)
+
+
+if args.quantize:
+    assert isinstance(pipe.transformer, FluxTransformer2DModel)
+    pipe.transformer = cache_dit.quantize(
+        pipe.transformer,
+        quant_type=args.quantize_type,
+    )
+
+
+def run_pipe(pipe: FluxPipeline):
+    image = pipe(
+        "A cat holding a sign that says hello world",
+        num_inference_steps=28,
+        generator=torch.Generator("cpu").manual_seed(0),
+    ).images[0]
+    return image
+
+
+if args.compile:
+    assert isinstance(pipe.transformer, FluxTransformer2DModel)
+    pipe.transformer.compile_repeated_blocks(fullgraph=True)
+
+# warmup
+_ = run_pipe(pipe)
+
+
+start = time.time()
+image = run_pipe(pipe)
+end = time.time()
+
+cache_dit.summary(pipe)
+
+time_cost = end - start
+save_path = f"flux.ao.{strify(args, pipe)}.png"
+print(f"Time cost: {time_cost:.2f}s")
+print(f"Saving image to {save_path}")
+image.save(save_path)
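Design note: the script keeps its three optimizations independent and flag-gated. cachify(args, pipe) wires up caching, cache_dit.quantize swaps the transformer weights through the torchao backend, and compile_repeated_blocks(fullgraph=True) compiles only the repeated transformer blocks rather than the whole module; the untimed warmup run then absorbs one-off compilation cost so the timed run measures steady-state latency. utils.py is not part of this diff, so the following is a hypothetical sketch of what cachify presumably wraps, namely the same enable_cache call the nunchaku examples below inline directly:

# Hypothetical sketch only: utils.cachify is not shown in this diff.
import cache_dit
from cache_dit import DBCacheConfig

def cachify(args, pipe):
    # Enable DBCache on the pipeline using the parsed CLI arguments.
    cache_dit.enable_cache(
        pipe,
        cache_config=DBCacheConfig(
            Fn_compute_blocks=args.Fn,
            Bn_compute_blocks=args.Bn,
            residual_diff_threshold=args.rdt,
        ),
    )

Assuming get_args exposes flags mirroring the attribute names it reads (also an assumption), a typical invocation would be: python run_flux_ao.py --cache --quantize --compile.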
Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+import os
+import sys
+
+sys.path.append("..")
+import time
+
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+
+from nunchaku.models.transformers.transformer_flux_v2 import (
+    NunchakuFluxTransformer2DModelV2,
+)
+from utils import get_args, strify
+import cache_dit
+
+args = get_args()
+print(args)
+
+nunchaku_flux_dir = os.environ.get(
+    "NUNCHAKA_FLUX_DIR",
+    "nunchaku-tech/nunchaku-flux.1-dev",
+)
+transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
+    f"{nunchaku_flux_dir}/svdq-int4_r32-flux.1-dev.safetensors",
+)
+pipe: FluxPipeline = FluxPipeline.from_pretrained(
+    os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev"),
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+
+if args.cache:
+    from cache_dit import (
+        ParamsModifier,
+        DBCacheConfig,
+        TaylorSeerCalibratorConfig,
+    )
+
+    cache_dit.enable_cache(
+        pipe,
+        cache_config=DBCacheConfig(
+            Fn_compute_blocks=args.Fn,
+            Bn_compute_blocks=args.Bn,
+            max_warmup_steps=args.max_warmup_steps,
+            max_cached_steps=args.max_cached_steps,
+            max_continuous_cached_steps=args.max_continuous_cached_steps,
+            residual_diff_threshold=args.rdt,
+        ),
+        calibrator_config=(
+            TaylorSeerCalibratorConfig(
+                taylorseer_order=args.taylorseer_order,
+            )
+            if args.taylorseer
+            else None
+        ),
+        params_modifiers=[
+            ParamsModifier(
+                # transformer_blocks
+                cache_config=DBCacheConfig().reset(
+                    residual_diff_threshold=args.rdt
+                ),
+            ),
+            ParamsModifier(
+                # single_transformer_blocks
+                cache_config=DBCacheConfig().reset(
+                    residual_diff_threshold=args.rdt * 3
+                ),
+            ),
+        ],
+    )
+
+
+def run_pipe(pipe: FluxPipeline):
+    image = pipe(
+        "A cat holding a sign that says hello world",
+        num_inference_steps=28,
+        generator=torch.Generator("cpu").manual_seed(0),
+    ).images[0]
+    return image
+
+
+if args.compile:
+    assert isinstance(pipe.transformer, FluxTransformer2DModel)
+    cache_dit.set_compile_configs()
+    pipe.transformer = torch.compile(pipe.transformer)
+
+# warmup
+_ = run_pipe(pipe)
+
+
+start = time.time()
+image = run_pipe(pipe)
+end = time.time()
+
+cache_dit.summary(pipe)
+
+time_cost = end - start
+save_path = f"flux.nunchaku.int4.{strify(args, pipe)}.png"
+print(f"Time cost: {time_cost:.2f}s")
+print(f"Saving image to {save_path}")
+image.save(save_path)
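Design note: the two ParamsModifier entries apply per-block-group overrides in declaration order: FLUX's transformer_blocks keep the base residual_diff_threshold, while single_transformer_blocks get a 3x looser one (args.rdt * 3), presumably because cache error in the cheaper single blocks is more tolerable. Compilation here also uses plain torch.compile on the whole transformer instead of compile_repeated_blocks, which suggests the Nunchaku transformer does not expose the repeated-block compile hook.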
Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
+import os
+import sys
+
+sys.path.append("..")
+
+import time
+import torch
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
+from nunchaku.models.transformers.transformer_qwenimage import (
+    NunchakuQwenImageTransformer2DModel,
+)
+
+from utils import get_args, strify
+import cache_dit
+
+
+args = get_args()
+print(args)
+
+nunchaku_qwen_image_dir = os.environ.get(
+    "NUNCHAKA_QWEN_IMAGE_DIR",
+    "nunchaku-tech/nunchaku-qwen-image.1-dev",
+)
+transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
+    f"{nunchaku_qwen_image_dir}/svdq-int4_r32-qwen-image.safetensors"
+)
+
+# Minimize VRAM required: 20GiB
+pipe = QwenImagePipeline.from_pretrained(
+    os.environ.get(
+        "QWEN_IMAGE_DIR",
+        "Qwen/Qwen-Image",
+    ),
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+    quantization_config=PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_4bit",
+        quant_kwargs={
+            "load_in_4bit": True,
+            "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_compute_dtype": torch.bfloat16,
+        },
+        components_to_quantize=["text_encoder"],
+    ),
+).to("cuda")
+
+
+if args.cache:
+    from cache_dit import (
+        DBCacheConfig,
+        TaylorSeerCalibratorConfig,
+    )
+
+    cache_dit.enable_cache(
+        pipe,
+        cache_config=DBCacheConfig(
+            Fn_compute_blocks=args.Fn,
+            Bn_compute_blocks=args.Bn,
+            max_warmup_steps=args.max_warmup_steps,
+            max_cached_steps=args.max_cached_steps,
+            max_continuous_cached_steps=args.max_continuous_cached_steps,
+            residual_diff_threshold=args.rdt,
+        ),
+        calibrator_config=(
+            TaylorSeerCalibratorConfig(
+                taylorseer_order=args.taylorseer_order,
+            )
+            if args.taylorseer
+            else None
+        ),
+    )
+
+
+positive_magic = {
+    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
+    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
+}
+
+# Generate image
+prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""

+# using an empty string if you do not have specific concept to remove
+negative_prompt = " "
+
+
+# Generate with different aspect ratios
+aspect_ratios = {
+    "1:1": (1328, 1328),
+    "16:9": (1664, 928),
+    "9:16": (928, 1664),
+    "4:3": (1472, 1140),
+    "3:4": (1140, 1472),
+    "3:2": (1584, 1056),
+    "2:3": (1056, 1584),
+}
+
+width, height = aspect_ratios["16:9"]
+
+assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
+
+
+def run_pipe():
+    # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+    image = pipe(
+        prompt=prompt + positive_magic["en"],
+        negative_prompt=negative_prompt,
+        width=width,
+        height=height,
+        num_inference_steps=50,
+        true_cfg_scale=4.0,
+        generator=torch.Generator(device="cpu").manual_seed(42),
+    ).images[0]
+    return image
+
+
+if args.compile:
+    cache_dit.set_compile_configs()
+    pipe.transformer = torch.compile(pipe.transformer)
+
+# warmup
+run_pipe()
+
+
+start = time.time()
+image = run_pipe()
+end = time.time()
+
+stats = cache_dit.summary(pipe)
+
+time_cost = end - start
+save_path = f"qwen-image.nunchaku.{strify(args, stats)}.png"
+print(f"Time cost: {time_cost:.2f}s")
+print(f"Saving image to {save_path}")
+image.save(save_path)
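Design note: this example stacks two quantizers: the transformer loads nunchaku's SVDQuant INT4 checkpoint (svdq-int4_r32), while the text encoder is quantized to NF4 on load via diffusers' PipelineQuantizationConfig with the bitsandbytes_4bit backend. Quantizing both large components is what brings peak VRAM down to roughly the 20GiB noted in the comment.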

src/cache_dit/cache_factory/block_adapters/__init__.py
Lines changed: 4 additions & 1 deletion

@@ -12,7 +12,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
     from cache_dit.utils import is_diffusers_at_least_0_3_5

     assert isinstance(pipe.transformer, FluxTransformer2DModel)
-    if is_diffusers_at_least_0_3_5():
+    transformer_cls_name: str = pipe.transformer.__class__.__name__
+    if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
+        "Nunchaku"
+    ):
         return BlockAdapter(
             pipe=pipe,
             transformer=pipe.transformer,
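Design note: the preceding assert requires pipe.transformer to be a FluxTransformer2DModel, so Nunchaku's transformer must pass that isinstance check; the new class-name prefix test is what tells it apart, routing Nunchaku models to the fallback BlockAdapter instead of the newer diffusers path, presumably because Nunchaku's fused block layout does not match what that path assumes.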
