Commit 1c00f0f

Added flux demo (#3418)
1 parent b63e06c commit 1c00f0f

File tree

17 files changed: +719 −183 lines


MODULE.bazel

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.
 new_local_repository(
     name = "cuda",
     build_file = "@//third_party/cuda:BUILD",
-    path = "/usr/local/cuda-12.8/",
+    path = "/usr/local/cuda-12.9/",
 )

 # for Jetson

examples/apps/flux_demo.py

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
import argparse
import os
import re
import sys
import time

import gradio as gr
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from accelerate.hooks import remove_hook_from_module
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *

DEVICE = "cuda:0"


def compile_model(
    args,
) -> tuple[
    FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
]:

    if args.dtype == "fp8":
        enabled_precisions = {torch.float8_e4m3fn, torch.float16}
        ptq_config = mtq.FP8_DEFAULT_CFG

    elif args.dtype == "int8":
        enabled_precisions = {torch.int8, torch.float16}
        ptq_config = mtq.INT8_DEFAULT_CFG
        ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None

    elif args.dtype == "fp16":
        enabled_precisions = {torch.float16}

    print(f"\nUsing {args.dtype}")

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
    ).to(torch.float16)

    if args.low_vram_mode:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(DEVICE)

    backbone = pipe.transformer
    backbone.eval()

    def filter_func(name):
        pattern = re.compile(
            r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
        )
        return pattern.match(name) is not None

    def do_calibrate(
        pipe,
        prompt: str,
    ) -> None:
        """
        Run calibration steps on the pipeline using the given prompts.
        """
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=20,
            generator=torch.Generator("cuda").manual_seed(0),
        ).images[0]

    def forward_loop(mod):
        # Switch the pipeline's backbone, run calibration
        pipe.transformer = mod
        do_calibrate(
            pipe=pipe,
            prompt="a dog running in a park",
        )

    if args.dtype != "fp16":
        backbone = mtq.quantize(backbone, ptq_config, forward_loop)
        mtq.disable_quantizer(backbone, filter_func)

    batch_size = 2 if args.dynamic_shapes else 1
    if args.dynamic_shapes:
        BATCH = torch.export.Dim("batch", min=1, max=8)
        dynamic_shapes = {
            "hidden_states": {0: BATCH},
            "encoder_hidden_states": {0: BATCH},
            "pooled_projections": {0: BATCH},
            "timestep": {0: BATCH},
            "txt_ids": {},
            "img_ids": {},
            "guidance": {0: BATCH},
            "joint_attention_kwargs": {},
            "return_dict": None,
        }
    else:
        dynamic_shapes = None

    settings = {
        "strict": False,
        "allow_complex_guards_as_runtime_asserts": True,
        "enabled_precisions": enabled_precisions,
        "truncate_double": True,
        "min_block_size": 1,
        "debug": False,
        "use_python_runtime": True,
        "immutable_weights": False,
        "offload_module_to_cpu": True,
    }
    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.enable_sequential_cpu_offload()
        remove_hook_from_module(pipe.transformer, recurse=True)
        pipe.transformer.to(DEVICE)
    trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
    if dynamic_shapes:
        trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
    pipe.transformer = trt_gm

    image = pipe(
        "Test",
        output_type="pil",
        num_inference_steps=2,
        num_images_per_prompt=batch_size,
    ).images

    torch.cuda.empty_cache()

    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.to(DEVICE)

    return pipe, backbone, trt_gm


def launch_gradio(pipeline, backbone, trt_gm):

    def generate_image(prompt, inference_step, batch_size=2):
        start_time = time.time()
        image = pipeline(
            prompt,
            output_type="pil",
            num_inference_steps=inference_step,
            num_images_per_prompt=batch_size,
        ).images
        end_time = time.time()
        return image, end_time - start_time

    def model_change(model):
        if model == "Torch Model":
            pipeline.transformer = backbone
            backbone.to(DEVICE)
        else:
            backbone.to("cpu")
            pipeline.transformer = trt_gm
            torch.cuda.empty_cache()

    def load_lora(path):
        pipeline.load_lora_weights(
            path,
            adapter_name="lora1",
        )
        pipeline.set_adapters(["lora1"], adapter_weights=[1])
        pipeline.fuse_lora()
        pipeline.unload_lora_weights()
        print("LoRA loaded! Begin refitting")
        generate_image(pipeline, ["Test"], 2)
        print("Refitting Finished!")

    # Create Gradio interface
    with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
        gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

        with gr.Row():
            with gr.Column():
                # Input components
                prompt_input = gr.Textbox(
                    label="Prompt", placeholder="Enter your prompt here...", lines=3
                )
                model_dropdown = gr.Dropdown(
                    choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                    value="Torch-TensorRT Accelerated Model",
                    label="Model Variant",
                )

                lora_upload_path = gr.Textbox(
                    label="LoRA Path",
                    placeholder="Enter the LoRA checkpoint path here. It could be a local path or a Hugging Face URL.",
                    value="gokaygokay/Flux-Engrave-LoRA",
                    lines=2,
                )
                num_steps = gr.Slider(
                    minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
                )
                batch_size = gr.Slider(
                    minimum=1, maximum=8, value=1, step=1, label="Batch Size"
                )

                generate_btn = gr.Button("Generate Image")
                load_lora_btn = gr.Button("Load LoRA")

            with gr.Column():
                # Output component
                output_image = gr.Gallery(label="Generated Image")
                time_taken = gr.Textbox(
                    label="Generation Time (seconds)", interactive=False
                )

        # Connect the button to the generation function
        model_dropdown.change(model_change, inputs=[model_dropdown])
        load_lora_btn.click(
            fn=load_lora,
            inputs=[
                lora_upload_path,
            ],
        )

        # Update generate button click to include time output
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt_input,
                num_steps,
                batch_size,
            ],
            outputs=[output_image, time_taken],
        )
    demo.launch()


def main(args):
    pipe, backbone, trt_gm = compile_model(args)
    launch_gradio(pipe, backbone, trt_gm)


# Launch the interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run Flux quantization with different dtypes"
    )

    parser.add_argument(
        "--dtype",
        choices=["fp8", "int8", "fp16"],
        default="fp16",
        help="Select the data type to use (fp8 or int8 or fp16)",
    )
    parser.add_argument(
        "--low_vram_mode",
        action="store_true",
        help="Use low VRAM mode when you have a small GPU (<=32GB)",
    )
    parser.add_argument(
        "--dynamic_shapes",
        "-d",
        action="store_true",
        help="Use dynamic shapes",
    )
    args = parser.parse_args()
    main(args)
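For context, the demo is driven by its argparse entry point (e.g. python flux_demo.py --dtype fp8 --dynamic_shapes). The sketch below is a hypothetical programmatic driver, not part of the commit: it builds the same argparse.Namespace the CLI would produce and calls the two functions added above.

# Hypothetical driver for the demo added in this commit; assumes flux_demo.py
# is importable from the working directory and a CUDA GPU with enough memory
# for FLUX.1-dev is available.
import argparse

import flux_demo

args = argparse.Namespace(
    dtype="fp8",           # one of "fp8", "int8", "fp16"
    low_vram_mode=False,   # True for GPUs with <= 32 GB of memory
    dynamic_shapes=True,   # compile with a dynamic batch dimension (1..8)
)

# compile_model() quantizes the FLUX transformer (for fp8/int8), wraps it in a
# MutableTorchTensorRTModule, and returns the patched pipeline; launch_gradio()
# then serves the Gradio UI defined above.
pipe, backbone, trt_gm = flux_demo.compile_model(args)
flux_demo.launch_gradio(pipe, backbone, trt_gm)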

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 4 additions & 5 deletions
@@ -22,6 +22,7 @@
 import torch
 import torch_tensorrt as torch_trt
 import torchvision.models as models
+from diffusers import DiffusionPipeline

 np.random.seed(5)
 torch.manual_seed(5)
@@ -31,7 +32,7 @@
 # Initialize the Mutable Torch TensorRT Module with settings.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 settings = {
-    "use_python": False,
+    "use_python_runtime": False,
     "enabled_precisions": {torch.float32},
     "immutable_weights": False,
 }
@@ -40,7 +41,6 @@
 mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
 # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
 mutable_module(*inputs)
-
 # %%
 # Make modifications to the mutable module.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -73,13 +73,12 @@
 # Stable Diffusion with Huggingface
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-from diffusers import DiffusionPipeline

 with torch.no_grad():
     settings = {
         "use_python_runtime": True,
         "enabled_precisions": {torch.float16},
-        "debug": True,
+        "debug": False,
         "immutable_weights": False,
     }

@@ -106,7 +105,7 @@
             "text_embeds": {0: BATCH},
             "time_ids": {0: BATCH},
         },
-        "return_dict": False,
+        "return_dict": None,
     }
     pipe.unet.set_expected_dynamic_shape_range(
         args_dynamic_shapes, kwargs_dynamic_shapes
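The last hunk adjusts the dynamic-shape spec this example passes to set_expected_dynamic_shape_range. As a hedged illustration of that call pattern (the tensor names, dimensions, and module variable below are assumptions, not the example's exact spec), following the convention used in flux_demo.py above where {} leaves an input static and None marks a non-tensor argument:

# Hypothetical dynamic-shape spec for a MutableTorchTensorRTModule; names and
# ranges are illustrative only.
import torch

BATCH = torch.export.Dim("batch", min=1, max=8)

# Positional inputs: the first tensor argument gets a dynamic batch dimension.
args_dynamic_shapes = ({0: BATCH},)
# Keyword inputs: nested dicts mirror the module's kwargs.
kwargs_dynamic_shapes = {
    "encoder_hidden_states": {0: BATCH},
    "added_cond_kwargs": {
        "text_embeds": {0: BATCH},
        "time_ids": {0: BATCH},
    },
    "return_dict": None,
}

# mutable_module is assumed to be the torch_trt.MutableTorchTensorRTModule
# created earlier in the example.
mutable_module.set_expected_dynamic_shape_range(
    args_dynamic_shapes, kwargs_dynamic_shapes
)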

examples/dynamo/refit_engine_example.py

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@
 )

 # Check the output
+model2.to("cuda")
 expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
 for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
     assert torch.allclose(

examples/dynamo/torch_export_flux_dev.py

Lines changed: 6 additions & 5 deletions
@@ -114,21 +114,22 @@
     min_block_size=1,
     use_fp32_acc=True,
     use_explicit_typing=True,
+    immutable_weights=False,
+    offload_module_to_cpu=True,
 )

 # %%
 # Post Processing
 # ---------------------------
 # Release the GPU memory occupied by the exported program and the pipe.transformer
 # Set the transformer in the Flux pipeline to the Torch-TRT compiled model
-
-del ep
-backbone.to("cpu")
+pipe.transformer = None
 pipe.to(DEVICE)
-torch.cuda.empty_cache()
 pipe.transformer = trt_gm
+del ep
+torch.cuda.empty_cache()
 pipe.transformer.config = config
-
+trt_gm.device = torch.device("cuda")
 # %%
 # Image generation using prompt
 # ---------------------------

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 1 deletion
@@ -693,6 +693,7 @@ def compile(
     )

     gm = exported_program.module()
+    # Move the weights in the state_dict to CPU
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
@@ -914,7 +915,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             parse_graph_io(submodule, subgraph_data)
             dryrun_tracker.tensorrt_graph_count += 1
             dryrun_tracker.per_subgraph_data.append(subgraph_data)
-
+        torch.cuda.empty_cache()
         # Create TRT engines from submodule
         if not settings.dryrun:
             trt_module = convert_module(
