Commit aa37194

chore: Access user settings within the lowering system (#3245)
1 parent 4be64a8 commit aa37194

24 files changed: 194 additions, 110 deletions

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 18 additions & 15 deletions
@@ -242,16 +242,6 @@ def compile(
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
         )
-    exported_program = pre_export_lowering(exported_program)
-    exported_program = exported_program.run_decompositions(
-        get_decompositions(enable_experimental_decompositions)
-    )
-    gm = exported_program.module()
-    logger.debug("Input graph: " + str(gm.graph))
-
-    # Apply lowering on the graph module
-    gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
-    logger.debug("Lowered Input graph: " + str(gm.graph))
 
     engine_cache = None
     if cache_built_engines or reuse_cached_engines:

@@ -305,6 +295,19 @@ def compile(
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
+
+    exported_program = pre_export_lowering(exported_program, settings)
+    exported_program = exported_program.run_decompositions(
+        get_decompositions(enable_experimental_decompositions)
+    )
+
+    gm = exported_program.module()
+    logger.debug("Input graph: " + str(gm.graph))
+
+    # Apply lowering on the graph module
+    gm = post_lowering(gm, settings)
+    logger.debug("Lowered Input graph: " + str(gm.graph))
+
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )

@@ -683,7 +686,10 @@ def convert_exported_program_to_serialized_trt_engine(
         "use_fp32_acc": use_fp32_acc,
     }
 
-    exported_program = pre_export_lowering(exported_program)
+    settings = CompilationSettings(**compilation_options)
+    logger.info("Compilation Settings: %s\n", settings)
+
+    exported_program = pre_export_lowering(exported_program, settings)
     # Decompose the exported program
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)

@@ -692,12 +698,9 @@ def convert_exported_program_to_serialized_trt_engine(
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
-    gm = post_lowering(gm)
+    gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
-    settings = CompilationSettings(**compilation_options)
-    logger.info("Compilation Settings: %s\n", settings)
-
     # Configure user compilation settings to converters.
     CONVERTERS.set_compilation_settings(settings)
 
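Note: the practical effect of this reordering in compile() is that CompilationSettings is built before any lowering runs, so flags such as use_fp32_acc now reach the passes through the settings object instead of loose keyword arguments. A minimal sketch of a caller under the new flow, assuming a CUDA build of torch_tensorrt; the model and shapes are illustrative:

import torch
import torch_tensorrt

class MatMul(torch.nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a @ b

model = MatMul().eval().cuda().half()
inputs = (
    torch.randn(16, 32, dtype=torch.half, device="cuda"),
    torch.randn(32, 64, dtype=torch.half, device="cuda"),
)

ep = torch.export.export(model, inputs)

# use_fp32_acc is packed into CompilationSettings, which compile() now
# constructs before invoking pre_export_lowering and post_lowering.
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    arg_inputs=inputs,
    enabled_precisions={torch.half},
    use_fp32_acc=True,
)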

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 3 additions & 3 deletions
@@ -292,15 +292,15 @@ def refit_module_weights(
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}"
         )
-    new_weight_module = pre_export_lowering(new_weight_module)
+    new_weight_module = pre_export_lowering(new_weight_module, settings)
     new_weight_module = new_weight_module.run_decompositions(
         get_decompositions(settings.enable_experimental_decompositions)
     )
     new_gm = new_weight_module.module()
     logger.debug("Input graph: " + str(new_gm.graph))
     # Apply lowering on the graph module
 
-    new_gm = post_lowering(new_gm)
+    new_gm = post_lowering(new_gm, settings)
 
     logger.info("Compilation Settings: %s\n", settings)
 

@@ -397,7 +397,7 @@ def refit_module_weights(
         if isinstance(compiled_submodule, PythonTorchTensorRTModule):
             engine = compiled_submodule.engine
         elif isinstance(compiled_submodule, TorchTensorRTModule):
-            engine_info = compiled_submodule.engine.__getstate__()[0]  # type: ignore[index]
+            engine_info = compiled_submodule.engine.__getstate__()[0]
             engine = get_engine_from_encoded_engine(
                 engine_info[ENGINE_IDX], runtime
             )
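
Note: refit re-runs the same lowering pipeline on the replacement weights, so it now forwards its settings as well. A sketch of the refit entry point following the documented usage of this period; the make_refitable flag and keyword names are assumptions:

import torch
import torch_tensorrt
from torch_tensorrt.dynamo import refit_module_weights

model = torch.nn.Linear(32, 64).eval().cuda()
args = (torch.randn(8, 32, device="cuda"),)

ep = torch.export.export(model, args)
trt_gm = torch_tensorrt.dynamo.compile(ep, arg_inputs=args, make_refitable=True)

# A second export with updated weights but identical structure.
ep2 = torch.export.export(torch.nn.Linear(32, 64).eval().cuda(), args)

# pre_export_lowering/post_lowering now receive the module's settings.
new_trt_gm = refit_module_weights(compiled_module=trt_gm, new_weight_module=ep2)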

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 4 additions & 4 deletions
@@ -77,16 +77,16 @@ def _pretraced_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
-            repair_input_aliasing(gm)
+            repair_input_aliasing(gm, settings)
 
             # Remove sym_int placeholders and inputs
-            remove_sym_nodes(gm)
+            remove_sym_nodes(gm, settings)
             torch_inputs = [
                 input for input in sample_inputs if isinstance(input, torch.Tensor)
             ]
 
             # Remove detach nodes
-            remove_detach(gm)
+            remove_detach(gm, settings)
 
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(

@@ -100,7 +100,7 @@ def _pretraced_backend(
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            gm = post_lowering(gm, use_fp32_acc=settings.use_fp32_acc)
+            gm = post_lowering(gm, settings)
 
             logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
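Note: in the torch.compile backend the same settings object now threads through repair_input_aliasing, remove_sym_nodes, remove_detach, and post_lowering. A sketch of driving this path via the documented backend integration; the options keys map onto CompilationSettings fields and the flag choice is illustrative:

import torch
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)

model = torch.nn.Linear(32, 64).eval().cuda().half()
x = torch.randn(8, 32, dtype=torch.half, device="cuda")

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.half}, "use_fp32_acc": True},
)
out = compiled(x)  # compilation (and lowering) happens on the first call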

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 4 deletions
@@ -3,7 +3,4 @@
     torch_enabled_decompositions,
 )
 from ._decompositions import get_decompositions  # noqa: F401
-from ._remove_sym_nodes import remove_sym_nodes
-from ._repair_input_aliasing import repair_input_aliasing
-from .passes import post_lowering, pre_export_lowering
-from .passes.remove_detach import remove_detach
+from .passes import *

py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1 +1,3 @@
 from ._aten_lowering_pass import *
+from .remove_sym_nodes import remove_sym_nodes
+from .repair_input_aliasing import repair_input_aliasing

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 12 additions & 7 deletions
@@ -1,7 +1,8 @@
 import logging
-from typing import Any, Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Union
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold

@@ -29,6 +30,7 @@
         replace_full_like_with_full,
         view_to_reshape,
         remove_assert_scalar,
+        accumulate_fp32_matmul,
     ]
 )
 

@@ -91,25 +93,28 @@ def _remove_lowering_pass(*, index: int) -> None:
     return
 
 
-def post_lowering(gm: torch.fx.GraphModule, **kwargs: Any) -> torch.fx.GraphModule:
+def post_lowering(
+    gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
+) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
     )
-    gm = ATEN_POST_LOWERING_PASSES(gm)
-    if kwargs.get("use_fp32_acc", False):
-        gm = accumulate_fp32_matmul(gm)
+    gm = ATEN_POST_LOWERING_PASSES(gm, settings)
 
     return gm
 
 
-def pre_export_lowering(ep: torch.export.ExportedProgram) -> torch.fx.GraphModule:
+def pre_export_lowering(
+    ep: torch.export.ExportedProgram,
+    settings: CompilationSettings = CompilationSettings(),
+) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
     )
     gm = ep.graph_module
-    gm = ATEN_PRE_LOWERING_PASSES(gm)
+    gm = ATEN_PRE_LOWERING_PASSES(gm, settings)
     return ep
 
py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py

Lines changed: 50 additions & 34 deletions

@@ -1,49 +1,65 @@
 import logging
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
 
 logger = logging.getLogger(__name__)
 
 
-def accumulate_fp32_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def accumulate_fp32_matmul(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Replace a matmul layer with fp32 accumulation nodes"""
-    matmul_targets = [
-        torch.ops.aten.mm.default,
-        torch.ops.aten.bmm.default,
-        torch.ops.aten.addmm.default,
-    ]
-    matmul_nodes = [node for node in gm.graph.nodes if node.target in matmul_targets]
-    for matmul_node in matmul_nodes:
-        # Prior to the matmul node, insert a cast to the 32-bit float32 node
-        node_inputs = matmul_node.all_input_nodes
-
-        for node_input in node_inputs:
-            with gm.graph.inserting_before(matmul_node):
-                node_32bit = gm.graph.call_function(
+    if settings.use_fp32_acc:
+        matmul_targets = [
+            torch.ops.aten.mm.default,
+            torch.ops.aten.bmm.default,
+            torch.ops.aten.addmm.default,
+        ]
+
+        matmul_nodes = [
+            node for node in gm.graph.nodes if node.target in matmul_targets
+        ]
+        for matmul_node in matmul_nodes:
+            # Prior to the matmul node, insert a cast to the 32-bit float32 node
+            node_inputs = matmul_node.all_input_nodes
+
+            for node_input in node_inputs:
+                with gm.graph.inserting_before(matmul_node):
+                    node_32bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(node_input,),
+                        kwargs={"dtype": torch.float32},
+                    )
+
+                # Replace the input to matmul node with new 32-bit cast node
+                matmul_node.replace_input_with(node_input, node_32bit)
+
+            # Add a cast back to original precision
+            with gm.graph.inserting_after(matmul_node):
+                node_orig_precision = gm.graph.call_function(
                     torch.ops.aten._to_copy.default,
-                    args=(node_input,),
-                    kwargs={"dtype": torch.float32},
+                    args=(matmul_node,),
+                    kwargs={"dtype": torch.float16},
                 )
+                matmul_node.replace_all_uses_with(
+                    node_orig_precision, propagate_meta=False
+                )
+                # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created.
+                node_orig_precision.replace_input_with(
+                    node_orig_precision.all_input_nodes[0], matmul_node
+                )
+
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after enabling matmul layers to use FP32 accumulation:\n{gm.graph}"
+        )
+    else:
+        logger.debug(
+            "Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings"
+        )
 
-            # Replace the input to matmul node with new 32-bit cast node
-            matmul_node.replace_input_with(node_input, node_32bit)
-
-            # Add a cast back to original precision
-            with gm.graph.inserting_after(matmul_node):
-                node_orig_precision = gm.graph.call_function(
-                    torch.ops.aten._to_copy.default,
-                    args=(matmul_node,),
-                    kwargs={"dtype": torch.float16},
-                )
-            matmul_node.replace_all_uses_with(node_orig_precision, propagate_meta=False)
-            # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created.
-            node_orig_precision.replace_input_with(
-                node_orig_precision.all_input_nodes[0], matmul_node
-            )
-
-    gm = clean_up_graph_after_modifications(gm)
-    logger.debug(f"Graph after changing matmuls to use FP32 accumulation:\n{gm.graph}")
     return gm

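Note: the pass is now gated on settings.use_fp32_acc internally rather than being special-cased in post_lowering. A small sketch of the gate in action; tracing a bare aten.mm call with torch.fx is an assumption used here to obtain a graph the pass will match:

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul import (
    accumulate_fp32_matmul,
)

class MM(torch.nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.mm.default(a, b)

gm = torch.fx.symbolic_trace(MM())

# With the flag on, each mm/bmm/addmm input is wrapped in a _to_copy cast
# to float32 and the result is cast back to float16.
gm = accumulate_fp32_matmul(gm, CompilationSettings(use_fp32_acc=True))
print(gm.graph)  # expect aten._to_copy nodes around aten.mm

# With the flag off (the default), the pass only logs and leaves gm intact.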

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 4 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )

@@ -19,7 +20,9 @@
 
 
 @torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def constant_fold(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Adapted from:
     https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,7 @@
 import logging
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )

@@ -9,7 +10,9 @@
 
 
 # TODO: Add relevant prims to this fusion
-def fuse_prims_broadcast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def fuse_prims_broadcast(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True"""
     modified_graph = False

py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py

Lines changed: 4 additions & 1 deletion
@@ -2,14 +2,17 @@
 from typing import Callable, Tuple
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
 
 logger = logging.getLogger(__name__)
 
 
-def lower_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def lower_linear(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
     orig, replacement = linear_replacement()
