Commit 2fff7ad

Revert to debug flag
1 parent 74bb32d commit 2fff7ad

File tree: 3 files changed, +29 -60 lines

py/torch_tensorrt/dynamo/_debugger.py renamed to py/torch_tensorrt/dynamo/Debugger.py

Lines changed: 26 additions & 32 deletions
@@ -14,67 +14,60 @@
 GRAPH_LEVEL = 5
 logging.addLevelName(GRAPH_LEVEL, "GRAPHS")
 
-# Debugger States
-DEBUG_FILE_DIR = tempfile.TemporaryDirectory().name
-SAVE_ENGINE_PROFILE = False
-
 
 class Debugger:
     def __init__(
         self,
-        level: str,
+        log_level: str,
         capture_fx_graph_before: Optional[List[str]] = None,
         capture_fx_graph_after: Optional[List[str]] = None,
         save_engine_profile: bool = False,
         logging_dir: Optional[str] = None,
     ):
-
-        if level != "graphs" and (capture_fx_graph_after or save_engine_profile):
+        self.debug_file_dir = tempfile.TemporaryDirectory().name
+        if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile):
             _LOGGER.warning(
                 "Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
             )
 
-        if level == "debug":
-            self.level = logging.DEBUG
-        elif level == "info":
-            self.level = logging.INFO
-        elif level == "warning":
-            self.level = logging.WARNING
-        elif level == "error":
-            self.level = logging.ERROR
-        elif level == "internal_errors":
-            self.level = logging.CRITICAL
-        elif level == "graphs":
-            self.level = GRAPH_LEVEL
+        if log_level == "debug":
+            self.log_level = logging.DEBUG
+        elif log_level == "info":
+            self.log_level = logging.INFO
+        elif log_level == "warning":
+            self.log_level = logging.WARNING
+        elif log_level == "error":
+            self.log_level = logging.ERROR
+        elif log_level == "internal_errors":
+            self.log_level = logging.CRITICAL
+        elif log_level == "graphs":
+            self.log_level = GRAPH_LEVEL
 
         else:
             raise ValueError(
-                f"Invalid level: {level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
+                f"Invalid level: {log_level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
            )
 
         self.capture_fx_graph_before = capture_fx_graph_before
         self.capture_fx_graph_after = capture_fx_graph_after
-        global SAVE_ENGINE_PROFILE
-        SAVE_ENGINE_PROFILE = save_engine_profile
 
         if logging_dir is not None:
-            global DEBUG_FILE_DIR
-            DEBUG_FILE_DIR = logging_dir
-            os.makedirs(DEBUG_FILE_DIR, exist_ok=True)
+            self.debug_file_dir = logging_dir
+            os.makedirs(self.debug_file_dir, exist_ok=True)
 
     def __enter__(self) -> None:
         self.original_lvl = _LOGGER.getEffectiveLevel()
         self.rt_level = torch.ops.tensorrt.get_logging_level()
         dictConfig(self.get_config())
 
-        if self.level == GRAPH_LEVEL:
+        if self.log_level == GRAPH_LEVEL:
             self.old_pre_passes, self.old_post_passes = (
                 ATEN_PRE_LOWERING_PASSES.passes,
                 ATEN_POST_LOWERING_PASSES.passes,
             )
             pre_pass_names = [p.__name__ for p in self.old_pre_passes]
             post_pass_names = [p.__name__ for p in self.old_post_passes]
-            path = os.path.join(DEBUG_FILE_DIR, "lowering_passes_visualization")
+            path = os.path.join(self.debug_file_dir, "lowering_passes_visualization")
             if self.capture_fx_graph_before is not None:
                 pre_vis_passes = [
                     p for p in self.capture_fx_graph_before if p in pre_pass_names
@@ -100,11 +93,12 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
 
         dictConfig(self.get_default_config())
         torch.ops.tensorrt.set_logging_level(self.rt_level)
-        if self.level == GRAPH_LEVEL and self.capture_fx_graph_after:
+        if self.log_level == GRAPH_LEVEL and self.capture_fx_graph_after:
             ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
                 self.old_pre_passes,
                 self.old_post_passes,
             )
+        self.debug_file_dir = tempfile.TemporaryDirectory().name
 
     def get_config(self) -> dict[str, Any]:
         config = {
@@ -122,21 +116,21 @@ def get_config(self) -> dict[str, Any]:
             },
             "handlers": {
                 "file": {
-                    "level": self.level,
+                    "level": self.log_level,
                     "class": "logging.FileHandler",
-                    "filename": f"{DEBUG_FILE_DIR}/torch_tensorrt_logging.log",
+                    "filename": f"{self.debug_file_dir}/torch_tensorrt_logging.log",
                     "formatter": "standard",
                 },
                 "console": {
-                    "level": self.level,
+                    "level": self.log_level,
                     "class": "logging.StreamHandler",
                     "formatter": "brief",
                 },
             },
             "loggers": {
                 "": {  # root logger
                     "handlers": ["file", "console"],
-                    "level": self.level,
+                    "level": self.log_level,
                     "propagate": True,
                 },
             },
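For context, a minimal usage sketch of the context manager after this change. The Debugger signature, the `log_level` rename, and the per-instance `debug_file_dir` come from the diff above; the import path follows the deprecation notice removed in _compiler.py below, and the model, inputs, and compile call are illustrative assumptions, not part of this commit.

import torch
import torch_tensorrt
from torch_tensorrt.dynamo import Debugger  # import path assumed from the notice removed in _compiler.py

# Hypothetical stand-in model so the sketch is self-contained (assumes a CUDA device).
class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

model = TinyModel().eval().cuda()
inputs = [torch.randn(2, 4).cuda()]

with Debugger(
    log_level="graphs",            # renamed from `level` in this commit
    logging_dir="/tmp/trt_debug",  # stored per-instance as self.debug_file_dir
):
    trt_model = torch_tensorrt.compile(model, inputs=inputs)

Keeping this state on the instance rather than in the DEBUG_FILE_DIR and SAVE_ENGINE_PROFILE module globals means separate Debugger instances no longer clobber each other, and __exit__ now resets debug_file_dir to a fresh temporary directory on the way out.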

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 26 deletions
@@ -2,7 +2,6 @@
 
 import collections.abc
 import logging
-import os
 import platform
 import warnings
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
@@ -504,14 +503,6 @@ def compile(
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
     """
 
-    if debug:
-        warnings.warn(
-            "The 'debug' argument is deprecated and will be removed in a future release. "
-            "Please use the torch_tensorrt.dynamo.Debugger context manager for debugging and graph capture.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
     if "truncate_long_and_double" in kwargs.keys():
         if truncate_double is not _defaults.TRUNCATE_DOUBLE:
             raise ValueError(
@@ -642,6 +633,7 @@ def compile(
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
         ),
+        "debug": debug,
         "device": device,
         "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
@@ -907,23 +899,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             )
 
             trt_modules[name] = trt_module
-            from torch_tensorrt.dynamo._debugger import (
-                DEBUG_FILE_DIR,
-                SAVE_ENGINE_PROFILE,
-            )
-
-            if SAVE_ENGINE_PROFILE:
-                if settings.use_python_runtime:
-                    logger.warning(
-                        "Profiling can only be enabled when using the C++ runtime"
-                    )
-                else:
-                    path = os.path.join(DEBUG_FILE_DIR, "engine_visualization")
-                    os.makedirs(path, exist_ok=True)
-                    trt_module.enable_profiling(
-                        profiling_results_dir=path,
-                        profile_format="trex",
-                    )
 
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
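The deleted block at the bottom was the only consumer of the DEBUG_FILE_DIR and SAVE_ENGINE_PROFILE globals, which no longer exist after the Debugger change above. With "debug": debug restored to the settings dict, the kwarg flows into CompilationSettings again. A quick hedged check, assuming an installed torch_tensorrt (the field itself is shown in the _settings.py diff below):

from torch_tensorrt.dynamo._settings import CompilationSettings

# "debug": debug in the settings dict above feeds this dataclass field; the
# kwarg can also be passed to compile() directly, e.g. compile(..., debug=True).
settings = CompilationSettings(debug=True)
print(settings.debug)  # True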

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,3 @@
-import logging
 from dataclasses import dataclass, field
 from typing import Collection, Optional, Set, Tuple, Union
 
@@ -8,6 +7,7 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
+    DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
@@ -100,7 +100,7 @@ class CompilationSettings:
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
-    debug: bool = logging.root.manager.root.level <= logging.DEBUG
+    debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Collection[Target] = field(default_factory=set)
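The replaced default is the substantive fix here: a dataclass default is evaluated once, at class-definition time, so the old expression froze whatever level the root logger happened to have at import time. A self-contained illustration of that pitfall, using a stand-in Settings class rather than repo code:

import logging
from dataclasses import dataclass

@dataclass
class Settings:
    # Old-style default: the comparison runs once, when the class is defined.
    debug: bool = logging.root.manager.root.level <= logging.DEBUG

# Raising verbosity afterwards does not update the already-frozen default:
logging.getLogger().setLevel(logging.DEBUG)
print(Settings().debug)  # False -- captured before setLevel ran

Pinning the default to the DEBUG constant from _defaults makes it deterministic regardless of how global logging is configured.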
