
Commit ec390e5

Revert to debug flag

1 parent 1f2b35e

File tree

3 files changed (+29 -60 lines)

py/torch_tensorrt/dynamo/_debugger.py renamed to py/torch_tensorrt/dynamo/Debugger.py

Lines changed: 26 additions & 32 deletions
```diff
@@ -14,67 +14,60 @@
 GRAPH_LEVEL = 5
 logging.addLevelName(GRAPH_LEVEL, "GRAPHS")
 
-# Debugger States
-DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name
-SAVE_ENGINE_PROFILE = False
-
 
 class Debugger:
     def __init__(
         self,
-        level: str,
+        log_level: str,
         capture_fx_graph_before: Optional[List[str]] = None,
         capture_fx_graph_after: Optional[List[str]] = None,
         save_engine_profile: bool = False,
         logging_dir: Optional[str] = None,
     ):
-
-        if level != "graphs" and (capture_fx_graph_after or save_engine_profile):
+        self.debug_file_dir = tempfile.TemporaryDirectory().name
+        if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile):
             _LOGGER.warning(
                 "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
             )
 
-        if level == "debug":
-            self.level = logging.DEBUG
-        elif level == "info":
-            self.level = logging.INFO
-        elif level == "warning":
-            self.level = logging.WARNING
-        elif level == "error":
-            self.level = logging.ERROR
-        elif level == "internal_errors":
-            self.level = logging.CRITICAL
-        elif level == "graphs":
-            self.level = GRAPH_LEVEL
+        if log_level == "debug":
+            self.log_level = logging.DEBUG
+        elif log_level == "info":
+            self.log_level = logging.INFO
+        elif log_level == "warning":
+            self.log_level = logging.WARNING
+        elif log_level == "error":
+            self.log_level = logging.ERROR
+        elif log_level == "internal_errors":
+            self.log_level = logging.CRITICAL
+        elif log_level == "graphs":
+            self.log_level = GRAPH_LEVEL
 
         else:
             raise ValueError(
-                f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
+                f"Invalid level: {log_level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
             )
 
         self.capture_fx_graph_before = capture_fx_graph_before
         self.capture_fx_graph_after = capture_fx_graph_after
-        global SAVE_ENGINE_PROFILE
-        SAVE_ENGINE_PROFILE = save_engine_profile
 
         if logging_dir is not None:
-            global DEBUG_FILE_DIR
-            DEBUG_FILE_DIR = logging_dir
-            os.makedirs(DEBUG_FILE_DIR, exist_ok=True)
+            self.debug_file_dir = logging_dir
+            os.makedirs(self.debug_file_dir, exist_ok=True)
 
     def __enter__(self) -> None:
         self.original_lvl = _LOGGER.getEffectiveLevel()
         self.rt_level = torch.ops.tensorrt.get_logging_level()
         dictConfig(self.get_config())
 
-        if self.level == GRAPH_LEVEL:
+        if self.log_level == GRAPH_LEVEL:
             self.old_pre_passes, self.old_post_passes = (
                 ATEN_PRE_LOWERING_PASSES.passes,
                 ATEN_POST_LOWERING_PASSES.passes,
             )
             pre_pass_names = [p.__name__ for p in self.old_pre_passes]
             post_pass_names = [p.__name__ for p in self.old_post_passes]
-            path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+            path = os.path.join(self.debug_file_dir, "lowering_passes_visualization")
             if self.capture_fx_graph_before is not None:
                 pre_vis_passes = [
                     p for p in self.capture_fx_graph_before if p in pre_pass_names
@@ -100,11 +93,12 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
 
         dictConfig(self.get_default_config())
         torch.ops.tensorrt.set_logging_level(self.rt_level)
-        if self.level == GRAPH_LEVEL and self.capture_fx_graph_after:
+        if self.log_level == GRAPH_LEVEL and self.capture_fx_graph_after:
             ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
                 self.old_pre_passes,
                 self.old_post_passes,
             )
+        self.debug_file_dir = tempfile.TemporaryDirectory().name
 
     def get_config(self) -> dict[str, Any]:
         config = {
@@ -122,21 +116,21 @@ def get_config(self) -> dict[str, Any]:
             },
             "handlers": {
                 "file": {
-                    "level": self.level,
+                    "level": self.log_level,
                     "class": "logging.FileHandler",
-                    "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log",
+                    "filename": f"{self.debug_file_dir}/torch_tensorrt_logging.log",
                     "formatter": "standard",
                 },
                 "console": {
-                    "level": self.level,
+                    "level": self.log_level,
                     "class": "logging.StreamHandler",
                     "formatter": "brief",
                 },
             },
             "loggers": {
                 "": {  # root logger
                     "handlers": ["file", "console"],
-                    "level": self.level,
+                    "level": self.log_level,
                     "propagate": True,
                 },
             },
```
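For context, a minimal usage sketch of the context manager after this change. The import path follows the deprecation message removed from `_compiler.py` below; the model, inputs, and pass name are placeholders for illustration, not part of this diff:

```python
import torch
import torch_tensorrt
from torch_tensorrt.dynamo import Debugger  # path per the removed deprecation message

model = torch.nn.Linear(4, 4).cuda().eval()  # placeholder module
inputs = [torch.randn(1, 4).cuda()]

# `log_level` replaces the old `level` argument; "graphs" is the only
# level at which the capture_fx_graph_* options take effect.
with Debugger(
    log_level="graphs",
    capture_fx_graph_before=["constant_fold"],  # hypothetical pass name
    logging_dir="/tmp/trt_debug",  # stored per instance now, not in a module global
):
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
```

Because `debug_file_dir` is now an instance attribute (reset on `__exit__`) instead of the module-level `DEBUG_FILE_DIR` global, sequential or nested `Debugger` instances no longer clobber each other's output directory.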

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 26 deletions
```diff
@@ -2,7 +2,6 @@
 
 import collections.abc
 import logging
-import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -519,14 +518,6 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
-    if debug:
-        warnings.warn(
-            "The 'debug' argument is deprecated and will be removed in a future release. "
-            "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -648,6 +639,7 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
+        "debug": debug,
        "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -928,23 +920,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             )
 
         trt_modules[name] = trt_module
-        from torch_tensorrt.dynamo._debugger import (
-            DEBUG_FILE_DIR,
-            SAVE_ENGINE_PROFILE,
-        )
-
-        if SAVE_ENGINE_PROFILE:
-            if settings.use_python_runtime:
-                logger.warning(
-                    "Profiling can only be enabled when using the C++ runtime"
-                )
-            else:
-                path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
-                os.makedirs(path, exist_ok=True)
-                trt_module.enable_profiling(
-                    profiling_results_dir=path,
-                    profile_format="trex",
-                )
 
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
```
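The net effect on the public API: `debug=True` is forwarded into `CompilationSettings` again instead of emitting a `DeprecationWarning`, and the engine-profiling hook keyed on the removed `SAVE_ENGINE_PROFILE` global is gone. A minimal sketch, with a placeholder model:

```python
import torch
import torch_tensorrt

model = torch.nn.Linear(8, 8).cuda().eval()  # placeholder module
inputs = [torch.randn(2, 8).cuda()]

# After this commit, `debug` flows through to CompilationSettings
# (the `"debug": debug` entry restored above), with no warning emitted.
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, debug=True)
```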

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,4 +1,3 @@
-import logging
 from dataclasses import dataclass, field
 from typing import Collection, Optional, Set, Tuple, Union
 
@@ -8,6 +7,7 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
+    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -101,7 +101,7 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = logging.root.manager.root.level <= logging.DEBUG
+    debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
```
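Why the old default was worth replacing: `logging.root.manager.root.level <= logging.DEBUG` is evaluated once, when the dataclass is defined, so later changes to the root logger never affect it. A standalone illustration of the pitfall (not torch_tensorrt code):

```python
import logging
from dataclasses import dataclass

# The old-style default reads the root logger level at class-definition
# time and freezes the result into the dataclass default.
@dataclass
class OldStyleSettings:
    debug: bool = logging.root.manager.root.level <= logging.DEBUG

logging.getLogger().setLevel(logging.DEBUG)
print(OldStyleSettings().debug)  # still False: the default never re-evaluates
```

Pulling the default from the `DEBUG` constant in `_defaults` makes it explicit and independent of logging state at import time (the constant's value is assumed here to be `False`, matching the usual library default).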
