
Commit b23bab2

feat: Saving modules using the AOTI format
1 parent f09be72 commit b23bab2

8 files changed, +246 -524 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -78,4 +78,5 @@ MODULE.bazel.lock
 *.whl
 .coverage
 coverage.xml
-*.log
+*.log
+*.pt2

docsrc/user_guide/saving_models.rst

Lines changed: 29 additions & 5 deletions
@@ -14,12 +14,13 @@ Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.sav
 Dynamo IR
 -------------
 
-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
-We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`), `ExportedProgram` (`torch.export.ExportedProgram`) or `PT2` formats by
 specifying the `output_format` flag. Here are the options `output_format` will accept
 
 * `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
 * `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.
+* `PT2 Format` : This is a next-generation runtime for PyTorch models, allowing them to run in Python and in C++.
 
 a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -52,8 +53,8 @@ b) Torchscript
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
-    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)
 
     # Later, you can load it and run inference
     model = torch.jit.load("trt.ts").cuda()
@@ -73,7 +74,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
 
     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
+    trt_ts = torch_tensorrt.compile(model, ir="ts", arg_inputs=inputs) # Output is a ScriptModule object
     torch.jit.save(trt_ts, "trt_model.ts")
 
     # Later, you can load it and run inference
@@ -98,3 +99,26 @@ Here's an example usage
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     model = torch_tensorrt.load(<file_path>).module()
     model(*inputs)
+
+b) PT2 Format
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PT2 is a new format that allows models to be run outside of Python in the future. It utilizes `AOTInductor <https://docs.pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_
+to generate kernels for components that will not be run in TensorRT.
+
+Here's an example of how to save and load a Torch-TensorRT module using AOTInductor in Python
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch_tensorrt.save(trt_gm, "trt.pt2", arg_inputs=inputs, output_format="aot_inductor", retrace=True)
+
+    # Later, you can load it and run inference
+    model = torch._inductor.aoti_load_package("trt.pt2")
+    model(*inputs)
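
For reference, the `save` path added in py/torch_tensorrt/_compile.py (further down in this commit) also forwards an optional `inductor_configs` keyword argument to `torch._inductor.aoti_compile_and_package`. A minimal sketch of passing it through; `MyModel` is the same placeholder used in the docs above, and the `max_autotune` key is only an illustrative AOTInductor option, not something this commit sets:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder model, as in the docs above
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

# Extra keyword arguments to save() are inspected for "inductor_configs" and
# handed to torch._inductor.aoti_compile_and_package; "max_autotune" is shown
# purely as an example config key.
torch_tensorrt.save(
    trt_gm,
    "trt.pt2",
    output_format="aot_inductor",
    retrace=True,
    arg_inputs=inputs,
    inductor_configs={"max_autotune": True},
)

model = torch._inductor.aoti_load_package("trt.pt2")
model(*inputs)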

examples/torchtrt_runtime_example/Makefile

Lines changed: 8 additions & 3 deletions

@@ -1,14 +1,19 @@
 CXX=g++
 DEP_DIR=$(PWD)/deps
-INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include
-LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib # -Wl,-rpath $(DEP_DIR)/tensorrt/lib
-LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
+CUDA_HOME=/usr/local/cuda
+INCLUDE_DIRS=-I$(DEP_DIR)/libtorch/include -I$(DEP_DIR)/torch_tensorrt/include -I$(CUDA_HOME)/include -I$(DEP_DIR)/libtorch/include/torch/csrc/api/include
+LIB_DIRS=-L$(DEP_DIR)/torch_tensorrt/lib -L$(DEP_DIR)/libtorch/lib -Wl,-rpath $(DEP_DIR)/tensorrt/lib
+LIBS=-Wl,--no-as-needed -ltorchtrt_runtime -ltorchtrt_plugins -Wl,--as-needed -ltorch -ltorch_cuda -ltorch_cpu -ltorch_global_deps -lbackend_with_compiler -lc10 -lc10_cuda
 SRCS=main.cpp
 
 TARGET=torchtrt_runtime_example
 
 $(TARGET):
 	$(CXX) $(SRCS) $(INCLUDE_DIRS) $(LIB_DIRS) $(LIBS) -o $(TARGET)
+	echo "Add to LD_LIBRARY_PATH: $(DEP_DIR)/torch_tensorrt/lib:$(DEP_DIR)/libtorch/lib:$(DEP_DIR)/tensorrt/lib:$(CUDA_HOME)/lib64"
+
+generate_pt2:
+	uv run network.py
 
 clean:
 	$(RM) $(TARGET)

examples/torchtrt_runtime_example/main.cpp

Lines changed: 48 additions & 16 deletions
@@ -3,31 +3,63 @@
 #include <memory>
 #include <sstream>
 #include <vector>
-#include "torch/script.h"
+#include "torch/torch.h"
+#include "torch/csrc/inductor/aoti_package/model_package_loader.h"
+#include "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h"
+
+/*
+ * This example demonstrates how to load and run a pre-built Torch-TensorRT
+ * AOTInductor (AOTI) model package using the PyTorch C++ API.
+ *
+ * Usage:
+ *   torchtrt_runtime_example <path-to-pre-built-trt-aoti module>
+ *
+ * Steps:
+ * 1. Parse the path to the AOTI model package from the command line.
+ * 2. Load the model package using AOTIModelPackageLoader.
+ * 3. Prepare a random CUDA tensor as input.
+ * 4. Run inference using the loaded model.
+ * 5. Print the output tensor(s) or an error message if inference fails.
+ */
 
 int main(int argc, const char* argv[]) {
+  // Check for correct number of command-line arguments
   if (argc < 2) {
-    std::cerr << "usage: samplertapp <path-to-pre-built-trt-ts module>\n";
+    std::cerr << "usage: torchtrt_runtime_example <path-to-pre-built-trt-aoti module>\n";
     return -1;
   }
 
-  std::string trt_ts_module_path = argv[1];
+  // Get the path to the TRT AOTI model package from the command line
+  std::string trt_aoti_module_path = argv[1];
 
-  torch::jit::Module trt_ts_mod;
+  // Enable inference mode for thread-local optimizations
+  c10::InferenceMode mode;
   try {
-    // Deserialize the ScriptModule from a file using torch::jit::load().
-    trt_ts_mod = torch::jit::load(trt_ts_module_path);
+    // Load the AOTI model package
+    torch::inductor::AOTIModelPackageLoader runner(trt_aoti_module_path);
+
+    // Create a random input tensor on CUDA with shape [1, 3, 5, 5] and type float32
+    std::vector<at::Tensor> inputs = {at::randn({1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32)};
+
+    // Run inference using the loaded model
+    std::vector<at::Tensor> outputs = runner.run(inputs);
+
+    // Process and print the output tensor(s)
+    if (!outputs.empty()) {
+      std::cout << "Model output: " << outputs[0] << std::endl;
+    } else {
+      std::cerr << "No output tensors received!" << std::endl;
+    }
+
   } catch (const c10::Error& e) {
-    std::cerr << "error loading the model from : " << trt_ts_module_path << std::endl;
-    return -1;
+    // Handle errors from the PyTorch C++ API
+    std::cerr << "Error running model: " << e.what() << std::endl;
+    return 1;
+  } catch (const std::exception& e) {
+    // Handle other standard exceptions
+    std::cerr << "An unexpected error occurred: " << e.what() << std::endl;
+    return 1;
   }
 
-  std::cout << "Running TRT engine" << std::endl;
-  std::vector<torch::jit::IValue> trt_inputs_ivalues;
-  trt_inputs_ivalues.push_back(at::randint(-5, 5, {1, 3, 5, 5}, {at::kCUDA}).to(torch::kFloat32));
-  torch::jit::IValue trt_results_ivalues = trt_ts_mod.forward(trt_inputs_ivalues);
-  std::cout << "==================TRT outputs================" << std::endl;
-  std::cout << trt_results_ivalues << std::endl;
-  std::cout << "=============================================" << std::endl;
-  std::cout << "TRT engine execution completed. " << std::endl;
+  return 0;
 }

examples/torchtrt_runtime_example/network.py

Lines changed: 30 additions & 10 deletions
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import torch_tensorrt as torchtrt
+import torch_tensorrt
 
 
 # create a simple norm layer.
@@ -29,21 +29,41 @@ def forward(self, x):
 
 def main():
     model = ConvGelu().eval().cuda()
-    scripted_model = torch.jit.script(model)
-
+    torch_ex_input = torch.randn([1, 3, 5, 5], device="cuda")
     compile_settings = {
-        "inputs": [torchtrt.Input([1, 3, 5, 5])],
+        "arg_inputs": [torch_ex_input],
+        "ir": "dynamo",
         "enabled_precisions": {torch.float32},
+        "min_block_size": 1,
    }
 
-    trt_ts_module = torchtrt.compile(scripted_model, **compile_settings)
-    torch.jit.save(trt_ts_module, "conv_gelu.jit")
+    cg_trt_module = torch_tensorrt.compile(model, **compile_settings)
+    torch_tensorrt.save(
+        cg_trt_module,
+        file_path="torchtrt_aoti_conv_gelu.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
 
     norm_model = Norm().eval().cuda()
-    norm_ts_module = torch.jit.script(norm_model)
-    norm_trt_ts = torchtrt.compile(norm_ts_module, **compile_settings)
-    torch.jit.save(norm_trt_ts, "norm.jit")
-    print("Generated Torchscript-TRT models.")
+    norm_trt_module = torch_tensorrt.compile(norm_model, **compile_settings)
+    torch_tensorrt.save(
+        norm_trt_module,
+        file_path="torchtrt_aoti_norm.pt2",
+        output_format="aot_inductor",
+        retrace=True,
+        arg_inputs=[torch_ex_input],
+    )
+    print("Generated TorchTRT-AOTI models.")
+
+    loaded_cg_trt_module = torch._inductor.aoti_load_package(
+        "torchtrt_aoti_conv_gelu.pt2"
+    )
+    loaded_norm_trt_module = torch._inductor.aoti_load_package("torchtrt_aoti_norm.pt2")
+    with torch.inference_mode():
+        print(loaded_cg_trt_module(torch_ex_input))
+        print(loaded_norm_trt_module(torch_ex_input))
 
 
 if __name__ == "__main__":
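
As a quick sanity check on the packages that network.py writes, the loaded AOTInductor modules can be compared against the eager models. This is only a sketch, not part of the commit; it assumes the `model` and `torch_ex_input` objects created in `main()` above, and the tolerances are illustrative:

import torch

# Compare the packaged TRT/AOTI module against the eager ConvGelu model.
# Assumes `model` and `torch_ex_input` from main() and the .pt2 file it wrote;
# the loaded module is expected to return the same output structure as forward().
loaded_cg = torch._inductor.aoti_load_package("torchtrt_aoti_conv_gelu.pt2")
with torch.inference_mode():
    expected = model(torch_ex_input)    # eager output
    actual = loaded_cg(torch_ex_input)  # packaged output
print("outputs close:", torch.allclose(expected, actual, rtol=1e-3, atol=1e-3))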

py/torch_tensorrt/_compile.py

Lines changed: 58 additions & 11 deletions
@@ -585,6 +585,7 @@ def save(
     kwarg_inputs: Optional[dict[str, Any]] = None,
     retrace: bool = False,
     pickle_protocol: int = 2,
+    **kwargs: Any,
 ) -> None:
     """
     Save the model to disk in the specified output format.
@@ -594,15 +595,15 @@ def save(
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
-        output_format (str): Format to save the model. Options include exported_program | torchscript.
+        output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
         pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
     """
     if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
     module_type = _parse_module_type(module)
-    accepted_formats = {"exported_program", "torchscript"}
+    accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
     if arg_inputs is not None and not all(
         isinstance(input, torch.Tensor) for input in arg_inputs
     ):
@@ -633,9 +634,9 @@ def save(
             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
         )
     elif module_type == _ModuleType.ts:
-        if output_format == "exported_program":
+        if output_format in ["exported_program", "aot_inductor"]:
             raise ValueError(
-                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
             )
         else:
             if arg_inputs is not None:
@@ -653,7 +654,22 @@ def save(
             logger.warning(
                 "Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
             )
-        torch.export.save(module, file_path)
+        if output_format == "exported_program":
+            torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
+        elif output_format == "aot_inductor":
+            inductor_configs = {}
+            if "inductor_configs" in kwargs:
+                inductor_configs = kwargs["inductor_configs"]
+
+            torch._inductor.aoti_compile_and_package(
+                exp_program,
+                inductor_configs=inductor_configs,
+                package_path=file_path,
+            )
+        else:
+            raise RuntimeError(
+                "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+            )
     elif module_type == _ModuleType.fx:
         # The module type is torch.fx.GraphModule
         if output_format == "torchscript":
@@ -670,9 +686,24 @@ def save(
                     "Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
                 )
             exp_program = export(module)
-            torch.export.save(
-                exp_program, file_path, pickle_protocol=pickle_protocol
-            )
+            if output_format == "exported_program":
+                torch.export.save(
+                    exp_program, file_path, pickle_protocol=pickle_protocol
+                )
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    exp_program,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
         else:
             if arg_inputs is None:
                 raise ValueError(
@@ -684,6 +715,22 @@ def save(
                 kwargs=kwarg_inputs,
                 strict=False,
             )
-            torch.export.save(
-                exp_program, file_path, pickle_protocol=pickle_protocol
-            )
+
+            if output_format == "exported_program":
+                torch.export.save(
+                    exp_program, file_path, pickle_protocol=pickle_protocol
+                )
+            elif output_format == "aot_inductor":
+                inductor_configs = {}
+                if "inductor_configs" in kwargs:
+                    inductor_configs = kwargs["inductor_configs"]
+
+                torch._inductor.aoti_compile_and_package(
+                    exp_program,
+                    inductor_configs=inductor_configs,
+                    package_path=file_path,
+                )
+            else:
+                raise RuntimeError(
+                    "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
+                )
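
Taken together, the branches above let the same compiled GraphModule be serialized to any of the three accepted formats. A short sketch of the resulting call sites; the file names and the `retrace=True` setting are illustrative rather than required:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder model
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

# Default: saved via torch.export.save
torch_tensorrt.save(trt_gm, "trt.ep", arg_inputs=inputs)

# TorchScript: traced with torch.jit.trace and saved with torch.jit.save
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", arg_inputs=inputs)

# New in this commit: AOTInductor package via torch._inductor.aoti_compile_and_package
torch_tensorrt.save(
    trt_gm, "trt.pt2", output_format="aot_inductor", retrace=True, arg_inputs=inputs
)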

pyproject.toml

Lines changed: 2 additions & 4 deletions
@@ -9,7 +9,7 @@ requires = [
    "typing-extensions>=4.7.0",
    "future>=0.18.3",
    "tensorrt-cu12>=10.11.0,<10.12.0; 'tegra' not in platform_release",
-   "tensorrt>=10.3.0,<10.4.0; 'tegra' in platform_release",
+   "tensorrt-cu12>=10.3.0,<10.4.0; 'tegra' not in platform_release",
    "torch>=2.8.0.dev,<2.9.0; 'tegra' not in platform_release",
    "torch>=2.7.0,<2.8.0; 'tegra' in platform_release",
    "pybind11==2.6.2",
@@ -63,7 +63,6 @@ dependencies = [
 
    "tensorrt>=10.11.0,<10.12.0; 'tegra' not in platform_release",
    "tensorrt>=10.3.0,<10.4.0; 'tegra' in platform_release",
-
    "tensorrt-cu12>=10.11.0,<10.12.0; 'tegra' not in platform_release",
    "tensorrt-cu12-bindings>=10.11.0,<10.12.0; 'tegra' not in platform_release",
    "tensorrt-cu12-libs>=10.11.0,<10.12.0; 'tegra' not in platform_release",
@@ -99,8 +98,7 @@ torchvision = [
    "torchvision",
 ] #Leaving torchvisions dependency unconstrained so uv can just install something that should work for the torch we have. TV's on PyT makes it hard to put version constrains in
 quantization = ["nvidia-modelopt[all]>=0.27.1"]
-monitoring-tools = ["rich>=13.7.1"]
-jupyter = ["rich[jupyter]>=13.7.1"]
+
 
 [project.urls]
 Homepage = "https://pytorch.org/tensorrt"
