Commit 5683644

Changed the debug setting (#3551)

1 parent d6aa8a4 commit 5683644

File tree

11 files changed: +183 -36 lines changed


py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,9 +14,9 @@
     load_cross_compiled_exported_program,
     save_cross_compiled_exported_program,
 )
-from ._Debugger import Debugger
 from ._exporter import export
 from ._refit import refit_module_weights
 from ._settings import CompilationSettings
 from ._SourceIR import SourceIR
 from ._tracer import trace
+from .debug._Debugger import Debugger
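Since `__init__.py` re-exports the class from its new home in the `debug` subpackage, the public import path is unchanged; only code reaching into the private `torch_tensorrt.dynamo._Debugger` module is affected. A minimal check of what this hunk implies:

from torch_tensorrt.dynamo import Debugger  # public path, still works
from torch_tensorrt.dynamo.debug._Debugger import Debugger as _Debugger  # new private location

# Both names refer to the same class after this commit.
assert Debugger is _Debugger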

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 54 additions & 8 deletions
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -31,6 +32,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import fn_supports_debugger
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
     post_lowering,
@@ -42,7 +45,6 @@
     get_output_metadata,
     parse_graph_io,
     prepare_inputs,
-    set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
 )
@@ -64,7 +66,7 @@ def cross_compile_for_windows(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -186,7 +188,11 @@ def cross_compile_for_windows(
     )
 
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
@@ -297,7 +303,6 @@ def cross_compile_for_windows(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -399,7 +404,7 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -518,6 +523,13 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
+    if debug:
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
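For callers that pass `debug=True` today, the warning points at `torch_tensorrt.dynamo.Debugger`. The Debugger class itself is not part of the hunks shown here, so the keyword arguments in the sketch below (`logging_dir`, `save_engine_profile`, `engine_builder_monitor`) are assumptions mirroring the `DebuggerConfig` fields used later in this commit, and `MyModel` is a hypothetical module:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical user model
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Old style: still accepted, but after this commit it only emits a
# DeprecationWarning and no longer changes log levels or CompilationSettings.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, debug=True)

# New style (sketch): drive debugging through the Debugger context manager.
with torch_tensorrt.dynamo.Debugger(
    logging_dir="/tmp/trt_debug",  # assumed kwarg, mirrors DebuggerConfig.logging_dir
    save_engine_profile=True,      # assumed kwarg, mirrors DebuggerConfig.save_engine_profile
    engine_builder_monitor=True,   # assumed kwarg, mirrors DebuggerConfig.engine_builder_monitor
):
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)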
@@ -639,7 +651,6 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
-        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -713,12 +724,15 @@ def compile(
     return trt_gm
 
 
+@fn_supports_debugger
 def compile_module(
     gm: torch.fx.GraphModule,
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
     engine_cache: Optional[BaseEngineCache] = None,
+    *,
+    _debugger_settings: Optional[DebuggerConfig] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
@@ -921,6 +935,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
         trt_modules[name] = trt_module
 
+        if _debugger_settings:
+
+            if _debugger_settings.save_engine_profile:
+                if settings.use_python_runtime:
+                    if _debugger_settings.profile_format == "trex":
+                        logger.warning(
+                            "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
+                        )
+                    trt_module.enable_profiling()
+                else:
+                    path = os.path.join(
+                        _debugger_settings.logging_dir, "engine_visualization"
+                    )
+                    os.makedirs(path, exist_ok=True)
+                    trt_module.enable_profiling(
+                        profiling_results_dir=path,
+                        profile_format=_debugger_settings.profile_format,
+                    )
+
+            if _debugger_settings.save_layer_info:
+                with open(
+                    os.path.join(
+                        _debugger_settings.logging_dir, "engine_layer_info.json"
+                    ),
+                    "w",
+                ) as f:
+                    f.write(trt_module.get_layer_info())
+
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
@@ -948,7 +990,7 @@ def convert_exported_program_to_serialized_trt_engine(
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
-    debug: bool = _defaults.DEBUG,
+    debug: bool = False,
     assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
@@ -1051,7 +1093,11 @@ def convert_exported_program_to_serialized_trt_engine(
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
     if debug:
-        set_log_level(logger.parent, logging.DEBUG)
+        warnings.warn(
+            "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 from torch_tensorrt._enums import EngineCapability, dtype
 
 ENABLED_PRECISIONS = {dtype.f32}
-DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
 ASSUME_DYNAMIC_SHAPE_SUPPORT = False

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 2 deletions
@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -101,7 +100,6 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
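With the field removed from the dataclass, `CompilationSettings` no longer accepts `debug` at all, which is why `_compiler.py` above also stops forwarding `"debug"` into the settings dict. The expected behavior (inferred from the dataclass change, not from a test in this commit):

from torch_tensorrt.dynamo import CompilationSettings

settings = CompilationSettings(workspace_size=1 << 30)  # unaffected fields still work

try:
    CompilationSettings(debug=True)
except TypeError as err:
    # dataclasses reject unknown keyword arguments once the field is gone
    print(f"debug is no longer a CompilationSettings field: {err}")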

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 2 additions & 6 deletions
@@ -7,8 +7,8 @@
 import torch
 from torch.export import Dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import DEBUG, default_device
-from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
+from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
 
 logger = logging.getLogger(__name__)
 
@@ -70,10 +70,6 @@ def trace(
     if kwarg_inputs is None:
         kwarg_inputs = {}
 
-    debug = kwargs.get("debug", DEBUG)
-    if debug:
-        set_log_level(logger.parent, logging.DEBUG)
-
     device = to_torch_device(kwargs.get("device", default_device()))
     torch_arg_inputs = get_torch_inputs(arg_inputs, device)
     torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
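The deleted block is the last place the `debug` kwarg actually did something in `trace()`: it raised the parent `torch_tensorrt.dynamo` logger to DEBUG. If only that verbosity is wanted, roughly the same effect can be had directly through the standard `logging` module (or through the Debugger context manager):

import logging

# Roughly equivalent to the removed set_log_level(logger.parent, logging.DEBUG)
# call: logger.parent for torch_tensorrt.dynamo._tracer is torch_tensorrt.dynamo.
logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)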

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 3 deletions
@@ -46,6 +46,8 @@
     to_torch,
 )
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
+from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -70,6 +72,7 @@ class TRTInterpreterResult(NamedTuple):
     requires_output_allocator: bool
 
 
+@cls_supports_debugger
 class TRTInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
     def __init__(
         self,
@@ -78,12 +81,14 @@ def __init__(
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
         engine_cache: Optional[BaseEngineCache] = None,
+        *,
+        _debugger_settings: Optional[DebuggerConfig] = None,
     ):
         super().__init__(module)
 
         self.logger = TRT_LOGGER
         self.builder = trt.Builder(self.logger)
-
+        self._debugger_settings = _debugger_settings
         flag = 0
         if compilation_settings.use_explicit_typing:
             STRONGLY_TYPED = 1 << (int)(
@@ -204,7 +209,7 @@ def _populate_trt_builder_config(
     ) -> trt.IBuilderConfig:
         builder_config = self.builder.create_builder_config()
 
-        if self.compilation_settings.debug:
+        if self._debugger_settings and self._debugger_settings.engine_builder_monitor:
            builder_config.progress_monitor = TRTBulderMonitor()
 
         if self.compilation_settings.workspace_size != 0:
@@ -215,7 +220,8 @@
         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 trt.ProfilingVerbosity.DETAILED
-                if self.compilation_settings.debug
+                if self._debugger_settings
+                and self._debugger_settings.save_engine_profile
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
 