14
14
GRAPH_LEVEL = 5
15
15
logging .addLevelName (GRAPH_LEVEL , "GRAPHS" )
16
16
17
- # Debugger States
18
- DEBUG_FILE_DIR = tempfile .TemporaryDirectory ().name
19
- SAVE_ENGINE_PROFILE = False
20
-
21
17
22
18
class Debugger :
23
19
def __init__ (
24
20
self ,
25
- level : str ,
21
+ log_level : str ,
26
22
capture_fx_graph_before : Optional [List [str ]] = None ,
27
23
capture_fx_graph_after : Optional [List [str ]] = None ,
28
24
save_engine_profile : bool = False ,
29
25
logging_dir : Optional [str ] = None ,
30
26
):
31
-
32
- if level != "graphs" and (capture_fx_graph_after or save_engine_profile ):
27
+ self . debug_file_dir = tempfile . TemporaryDirectory (). name
28
+ if log_level != "graphs" and (capture_fx_graph_after or save_engine_profile ):
33
29
_LOGGER .warning (
34
30
"Capture FX Graph or Draw Engine Graph is only supported when level is 'graphs'"
35
31
)
36
32
37
- if level == "debug" :
38
- self .level = logging .DEBUG
39
- elif level == "info" :
40
- self .level = logging .INFO
41
- elif level == "warning" :
42
- self .level = logging .WARNING
43
- elif level == "error" :
44
- self .level = logging .ERROR
45
- elif level == "internal_errors" :
46
- self .level = logging .CRITICAL
47
- elif level == "graphs" :
48
- self .level = GRAPH_LEVEL
33
+ if log_level == "debug" :
34
+ self .log_level = logging .DEBUG
35
+ elif log_level == "info" :
36
+ self .log_level = logging .INFO
37
+ elif log_level == "warning" :
38
+ self .log_level = logging .WARNING
39
+ elif log_level == "error" :
40
+ self .log_level = logging .ERROR
41
+ elif log_level == "internal_errors" :
42
+ self .log_level = logging .CRITICAL
43
+ elif log_level == "graphs" :
44
+ self .log_level = GRAPH_LEVEL
49
45
50
46
else :
51
47
raise ValueError (
52
- f"Invalid level: { level } , allowed levels are: debug, info, warning, error, internal_errors, graphs"
48
+ f"Invalid level: { log_level } , allowed levels are: debug, info, warning, error, internal_errors, graphs"
53
49
)
54
50
55
51
self .capture_fx_graph_before = capture_fx_graph_before
56
52
self .capture_fx_graph_after = capture_fx_graph_after
57
- global SAVE_ENGINE_PROFILE
58
- SAVE_ENGINE_PROFILE = save_engine_profile
59
53
60
54
if logging_dir is not None :
61
- global DEBUG_FILE_DIR
62
- DEBUG_FILE_DIR = logging_dir
63
- os .makedirs (DEBUG_FILE_DIR , exist_ok = True )
55
+ self .debug_file_dir = logging_dir
56
+ os .makedirs (self .debug_file_dir , exist_ok = True )
64
57
65
58
def __enter__ (self ) -> None :
66
59
self .original_lvl = _LOGGER .getEffectiveLevel ()
67
60
self .rt_level = torch .ops .tensorrt .get_logging_level ()
68
61
dictConfig (self .get_config ())
69
62
70
- if self .level == GRAPH_LEVEL :
63
+ if self .log_level == GRAPH_LEVEL :
71
64
self .old_pre_passes , self .old_post_passes = (
72
65
ATEN_PRE_LOWERING_PASSES .passes ,
73
66
ATEN_POST_LOWERING_PASSES .passes ,
74
67
)
75
68
pre_pass_names = [p .__name__ for p in self .old_pre_passes ]
76
69
post_pass_names = [p .__name__ for p in self .old_post_passes ]
77
- path = os .path .join (DEBUG_FILE_DIR , "lowering_passes_visualization" )
70
+ path = os .path .join (self . debug_file_dir , "lowering_passes_visualization" )
78
71
if self .capture_fx_graph_before is not None :
79
72
pre_vis_passes = [
80
73
p for p in self .capture_fx_graph_before if p in pre_pass_names
@@ -100,11 +93,12 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
100
93
101
94
dictConfig (self .get_default_config ())
102
95
torch .ops .tensorrt .set_logging_level (self .rt_level )
103
- if self .level == GRAPH_LEVEL and self .capture_fx_graph_after :
96
+ if self .log_level == GRAPH_LEVEL and self .capture_fx_graph_after :
104
97
ATEN_PRE_LOWERING_PASSES .passes , ATEN_POST_LOWERING_PASSES .passes = (
105
98
self .old_pre_passes ,
106
99
self .old_post_passes ,
107
100
)
101
+ self .debug_file_dir = tempfile .TemporaryDirectory ().name
108
102
109
103
def get_config (self ) -> dict [str , Any ]:
110
104
config = {
@@ -122,21 +116,21 @@ def get_config(self) -> dict[str, Any]:
122
116
},
123
117
"handlers" : {
124
118
"file" : {
125
- "level" : self .level ,
119
+ "level" : self .log_level ,
126
120
"class" : "logging.FileHandler" ,
127
- "filename" : f"{ DEBUG_FILE_DIR } /torch_tensorrt_logging.log" ,
121
+ "filename" : f"{ self . debug_file_dir } /torch_tensorrt_logging.log" ,
128
122
"formatter" : "standard" ,
129
123
},
130
124
"console" : {
131
- "level" : self .level ,
125
+ "level" : self .log_level ,
132
126
"class" : "logging.StreamHandler" ,
133
127
"formatter" : "brief" ,
134
128
},
135
129
},
136
130
"loggers" : {
137
131
"" : { # root logger
138
132
"handlers" : ["file" , "console" ],
139
- "level" : self .level ,
133
+ "level" : self .log_level ,
140
134
"propagate" : True ,
141
135
},
142
136
},
0 commit comments