
Commit 1f2b35e

Added torch_tensorrt.dynamo.Debugger; cleaned up settings.debug
1 parent 8de3947

13 files changed: +228 additions, -54 deletions


py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
     load_cross_compiled_exported_program,
     save_cross_compiled_exported_program,
 )
+from ._debugger import Debugger
 from ._exporter import export
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
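
Since Debugger is re-exported from the package __init__, it should be importable directly from the dynamo namespace; a minimal sketch:

from torch_tensorrt.dynamo import Debugger  # re-exported by the line added above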

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 9 deletions
@@ -520,7 +520,13 @@ def compile(
     """

     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "The 'debug' argument is deprecated and will be removed in a future release. "
+            "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(

@@ -642,7 +648,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,

@@ -745,7 +750,7 @@ def compile_module(

     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
+        gm, settings.torch_executed_ops
     )

     dryrun_tracker.total_ops_in_graph = total_ops

@@ -797,7 +802,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         logger.info("Partitioning the graph via the fast partitioner")
         partitioned_module, supported_ops = partitioning.fast_partition(
             gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,

@@ -818,7 +822,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
             require_full_compilation=settings.require_full_compilation,

@@ -925,17 +928,21 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         )

         trt_modules[name] = trt_module
+        from torch_tensorrt.dynamo._debugger import (
+            DEBUG_FILE_DIR,
+            SAVE_ENGINE_PROFILE,
+        )

-        if settings.debug and settings.engine_vis_dir:
+        if SAVE_ENGINE_PROFILE:
             if settings.use_python_runtime:
                 logger.warning(
                     "Profiling can only be enabled when using the C++ runtime"
                 )
             else:
-                if not os.path.exists(settings.engine_vis_dir):
-                    os.makedirs(settings.engine_vis_dir)
+                path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
+                os.makedirs(path, exist_ok=True)
                 trt_module.enable_profiling(
-                    profiling_results_dir=settings.engine_vis_dir,
+                    profiling_results_dir=path,
                     profile_format="trex",
                 )

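Per the deprecation message above, debugging moves from the debug flag to the new context manager. A minimal migration sketch; exported_program and inputs are placeholders, not values from this commit:

import torch_tensorrt

# Deprecated: now only emits a DeprecationWarning
# trt_gm = torch_tensorrt.dynamo.compile(exported_program, inputs, debug=True)

# Replacement suggested by the warning text
with torch_tensorrt.dynamo.Debugger(level="debug"):
    trt_gm = torch_tensorrt.dynamo.compile(exported_program, inputs)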
py/torch_tensorrt/dynamo/_debugger.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+import logging
+import os
+import tempfile
+from logging.config import dictConfig
+from typing import Any, List, Optional
+
+import torch
+from torch_tensorrt.dynamo.lowering import (
+    ATEN_POST_LOWERING_PASSES,
+    ATEN_PRE_LOWERING_PASSES,
+)
+
+_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]")
+GRAPH_LEVEL = 5
+logging.addLevelName(GRAPH_LEVEL, "GRAPHS")
+
+# Debugger States
+DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name
+SAVE_ENGINE_PROFILE = False
+
+
+class Debugger:
+    def __init__(
+        self,
+        level: str,
+        capture_fx_graph_before: Optional[List[str]] = None,
+        capture_fx_graph_after: Optional[List[str]] = None,
+        save_engine_profile: bool = False,
+        logging_dir: Optional[str] = None,
+    ):
+
+        if level != "graphs" and (capture_fx_graph_after or save_engine_profile):
+            _LOGGER.warning(
+                "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
+            )
+
+        if level == "debug":
+            self.level = logging.DEBUG
+        elif level == "info":
+            self.level = logging.INFO
+        elif level == "warning":
+            self.level = logging.WARNING
+        elif level == "error":
+            self.level = logging.ERROR
+        elif level == "internal_errors":
+            self.level = logging.CRITICAL
+        elif level == "graphs":
+            self.level = GRAPH_LEVEL
+        else:
+            raise ValueError(
+                f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
+            )
+
+        self.capture_fx_graph_before = capture_fx_graph_before
+        self.capture_fx_graph_after = capture_fx_graph_after
+        global SAVE_ENGINE_PROFILE
+        SAVE_ENGINE_PROFILE = save_engine_profile
+
+        if logging_dir is not None:
+            global DEBUG_FILE_DIR
+            DEBUG_FILE_DIR = logging_dir
+        os.makedirs(DEBUG_FILE_DIR, exist_ok=True)
+
+    def __enter__(self) -> None:
+        self.original_lvl = _LOGGER.getEffectiveLevel()
+        self.rt_level = torch.ops.tensorrt.get_logging_level()
+        dictConfig(self.get_config())
+
+        if self.level == GRAPH_LEVEL:
+            self.old_pre_passes, self.old_post_passes = (
+                ATEN_PRE_LOWERING_PASSES.passes,
+                ATEN_POST_LOWERING_PASSES.passes,
+            )
+            pre_pass_names = [p.__name__ for p in self.old_pre_passes]
+            post_pass_names = [p.__name__ for p in self.old_post_passes]
+            path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+            if self.capture_fx_graph_before is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_before if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
+                    post_vis_passes, path
+                )
+            if self.capture_fx_graph_after is not None:
+                pre_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in pre_pass_names
+                ]
+                post_vis_passes = [
+                    p for p in self.capture_fx_graph_after if p in post_pass_names
+                ]
+                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path)
+                ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
+        dictConfig(self.get_default_config())
+        torch.ops.tensorrt.set_logging_level(self.rt_level)
+        if self.level == GRAPH_LEVEL and self.capture_fx_graph_after:
+            ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
+                self.old_pre_passes,
+                self.old_post_passes,
+            )
+
+    def get_config(self) -> dict[str, Any]:
+        config = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "file": {
+                    "level": self.level,
+                    "class": "logging.FileHandler",
+                    "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log",
+                    "formatter": "standard",
+                },
+                "console": {
+                    "level": self.level,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["file", "console"],
+                    "level": self.level,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        return config
+
+    def get_default_config(self) -> dict[str, Any]:
+        config = {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "brief": {
+                    "format": "%(asctime)s - %(levelname)s - %(message)s",
+                    "datefmt": "%H:%M:%S",
+                },
+                "standard": {
+                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+                    "datefmt": "%Y-%m-%d %H:%M:%S",
+                },
+            },
+            "handlers": {
+                "console": {
+                    "level": self.original_lvl,
+                    "class": "logging.StreamHandler",
+                    "formatter": "brief",
+                },
+            },
+            "loggers": {
+                "": {  # root logger
+                    "handlers": ["console"],
+                    "level": self.original_lvl,
+                    "propagate": True,
+                },
+            },
+            "force": True,
+        }
+        return config
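
A usage sketch for the new context manager, inferred from the constructor above; the pass name, directory, and compile arguments are placeholders rather than values taken from this commit:

import torch_tensorrt
from torch_tensorrt.dynamo import Debugger

# level="graphs" maps to the custom GRAPHS level (5) and unlocks FX-graph
# capture and engine profiling; the other levels map to the standard
# logging levels (debug, info, warning, error, internal_errors).
with Debugger(
    level="graphs",
    capture_fx_graph_before=["example_pass"],  # must match a lowering pass __name__
    save_engine_profile=True,  # sets the module-level SAVE_ENGINE_PROFILE flag
    logging_dir="/tmp/torchtrt_debug",  # overrides the default temporary DEBUG_FILE_DIR
):
    trt_gm = torch_tensorrt.dynamo.compile(exported_program, inputs)

On exit, the context manager restores the previous logging configuration and the runtime logging level; the original lowering-pass lists are restored only when capture_fx_graph_after was given, per the __exit__ condition above.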

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
 DLA_SRAM_SIZE = 1048576
 ENGINE_CAPABILITY = EngineCapability.STANDARD
 WORKSPACE_SIZE = 0
-ENGINE_VIS_DIR = None
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
 MAX_AUX_STREAMS = None

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 0 additions & 7 deletions
@@ -39,7 +39,6 @@
     check_module_output,
     get_model_device,
     get_torch_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )

@@ -72,7 +71,6 @@ def construct_refit_mapping(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
     )

@@ -266,9 +264,6 @@ def refit_module_weights(
         not settings.immutable_weights
     ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."

-    if settings.debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_tensorrt_device(settings.device)
     if arg_inputs:
         if not isinstance(arg_inputs, collections.abc.Sequence):

@@ -304,7 +299,6 @@ def refit_module_weights(
     try:
         new_partitioned_module, supported_ops = partitioning.fast_partition(
             new_gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )

@@ -320,7 +314,6 @@ def refit_module_weights(
     if not settings.use_fast_partitioner:
         new_partitioned_module, supported_ops = partitioning.global_partition(
             new_gm,
-            verbose=settings.debug,
             min_block_size=settings.min_block_size,
             torch_executed_ops=settings.torch_executed_ops,
         )

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 3 deletions
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass, field
 from typing import Collection, Optional, Set, Tuple, Union

@@ -7,7 +8,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,

@@ -18,7 +18,6 @@
     ENABLE_WEIGHT_STREAMING,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
-    ENGINE_VIS_DIR,
     HARDWARE_COMPATIBLE,
     IMMUTABLE_WEIGHTS,
     L2_LIMIT_FOR_TILING,

@@ -102,7 +101,7 @@ class CompilationSettings:
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
+    debug: bool = logging.root.manager.root.level <= logging.DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
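
The new default ties CompilationSettings.debug to the root logger instead of a fixed constant. Note the expression is evaluated when the dataclass is defined, so it captures the root level at import time; a minimal sketch, assuming torch_tensorrt has not yet been imported when the level is set:

import logging

# Raise the root logger to DEBUG before the first torch_tensorrt import so
# the class-level default (root level <= DEBUG) evaluates to True.
logging.getLogger().setLevel(logging.DEBUG)

from torch_tensorrt.dynamo import CompilationSettings

print(CompilationSettings().debug)  # True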

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         input_specs: Sequence[Input],
-        logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 0 additions & 2 deletions
@@ -3,7 +3,6 @@
 import logging
 from typing import Any, List, Optional, Sequence

-import tensorrt as trt
 import torch
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES

@@ -60,7 +59,6 @@ def interpret_module_to_result(
     interpreter = TRTInterpreter(
         module,
         inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
         engine_cache=engine_cache,

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 4 additions & 5 deletions
@@ -13,14 +13,15 @@
 )
 from torch.fx.passes.tools_common import CALLABLE_NODE_OPS, NodeSet
 from torch_tensorrt.dynamo._defaults import (
-    DEBUG,
     MIN_BLOCK_SIZE,
     REQUIRE_FULL_COMPILATION,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
-from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    ConverterRegistry,
+)

 logger = logging.getLogger(__name__)

@@ -250,7 +251,6 @@ def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:

 def partition(
     gm: torch.fx.GraphModule,
-    verbose: bool = DEBUG,
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,

@@ -286,7 +286,6 @@ def partition(

     partitioned_graph = partitioner.partition_graph()

-    if verbose:
-        supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)
+    supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)

     return partitioned_graph, supported_ops
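
With the verbose parameter removed, partition() now prints the support overview unconditionally. A hypothetical call after this change; gm stands in for a traced torch.fx.GraphModule:

from torch_tensorrt.dynamo import partitioning

# verbose= is no longer accepted; the support overview prints regardless
partitioned_gm, supported_ops = partitioning.fast_partition(
    gm,
    min_block_size=5,
    torch_executed_ops=set(),
)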
