2
2
3
3
import collections .abc
4
4
import logging
5
+ import os
5
6
import platform
6
7
import warnings
7
8
from typing import Any , Collection , List , Optional , Sequence , Set , Tuple , Union
31
32
from torch_tensorrt .dynamo .conversion ._ConverterRegistry import (
32
33
DYNAMO_CONVERTERS as CONVERTERS ,
33
34
)
35
+ from torch_tensorrt .dynamo .debug ._DebuggerConfig import DebuggerConfig
36
+ from torch_tensorrt .dynamo .debug ._supports_debugger import fn_supports_debugger
34
37
from torch_tensorrt .dynamo .lowering import (
35
38
get_decompositions ,
36
39
post_lowering ,
42
45
get_output_metadata ,
43
46
parse_graph_io ,
44
47
prepare_inputs ,
45
- set_log_level ,
46
48
to_torch_device ,
47
49
to_torch_tensorrt_device ,
48
50
)
@@ -64,7 +66,7 @@ def cross_compile_for_windows(
64
66
Set [Union [torch .dtype , dtype ]], Tuple [Union [torch .dtype , dtype ]]
65
67
] = _defaults .ENABLED_PRECISIONS ,
66
68
engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
67
- debug : bool = _defaults . DEBUG ,
69
+ debug : bool = False ,
68
70
num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
69
71
workspace_size : int = _defaults .WORKSPACE_SIZE ,
70
72
dla_sram_size : int = _defaults .DLA_SRAM_SIZE ,
@@ -186,7 +188,11 @@ def cross_compile_for_windows(
186
188
)
187
189
188
190
if debug :
189
- set_log_level (logger .parent , logging .DEBUG )
191
+ warnings .warn (
192
+ "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options." ,
193
+ DeprecationWarning ,
194
+ stacklevel = 2 ,
195
+ )
190
196
191
197
if "truncate_long_and_double" in kwargs .keys ():
192
198
if truncate_double is not _defaults .TRUNCATE_DOUBLE :
@@ -297,7 +303,6 @@ def cross_compile_for_windows(
297
303
"enabled_precisions" : (
298
304
enabled_precisions if enabled_precisions else _defaults .ENABLED_PRECISIONS
299
305
),
300
- "debug" : debug ,
301
306
"device" : device ,
302
307
"assume_dynamic_shape_support" : assume_dynamic_shape_support ,
303
308
"workspace_size" : workspace_size ,
@@ -399,7 +404,7 @@ def compile(
399
404
Set [Union [torch .dtype , dtype ]], Tuple [Union [torch .dtype , dtype ]]
400
405
] = _defaults .ENABLED_PRECISIONS ,
401
406
engine_capability : EngineCapability = _defaults .ENGINE_CAPABILITY ,
402
- debug : bool = _defaults . DEBUG ,
407
+ debug : bool = False ,
403
408
num_avg_timing_iters : int = _defaults .NUM_AVG_TIMING_ITERS ,
404
409
workspace_size : int = _defaults .WORKSPACE_SIZE ,
405
410
dla_sram_size : int = _defaults .DLA_SRAM_SIZE ,
@@ -518,6 +523,13 @@ def compile(
518
523
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
519
524
"""
520
525
526
+ if debug :
527
+ warnings .warn (
528
+ "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` for debugging functionality" ,
529
+ DeprecationWarning ,
530
+ stacklevel = 2 ,
531
+ )
532
+
521
533
if "truncate_long_and_double" in kwargs .keys ():
522
534
if truncate_double is not _defaults .TRUNCATE_DOUBLE :
523
535
raise ValueError (
@@ -639,7 +651,6 @@ def compile(
639
651
"enabled_precisions" : (
640
652
enabled_precisions if enabled_precisions else _defaults .ENABLED_PRECISIONS
641
653
),
642
- "debug" : debug ,
643
654
"device" : device ,
644
655
"assume_dynamic_shape_support" : assume_dynamic_shape_support ,
645
656
"workspace_size" : workspace_size ,
@@ -713,12 +724,15 @@ def compile(
713
724
return trt_gm
714
725
715
726
727
+ @fn_supports_debugger
716
728
def compile_module (
717
729
gm : torch .fx .GraphModule ,
718
730
sample_arg_inputs : Sequence [Input ],
719
731
sample_kwarg_inputs : Optional [dict [Any , Any ]] = None ,
720
732
settings : CompilationSettings = CompilationSettings (),
721
733
engine_cache : Optional [BaseEngineCache ] = None ,
734
+ * ,
735
+ _debugger_settings : Optional [DebuggerConfig ] = None ,
722
736
) -> torch .fx .GraphModule :
723
737
"""Compile a traced FX module
724
738
@@ -921,6 +935,34 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
921
935
922
936
trt_modules [name ] = trt_module
923
937
938
+ if _debugger_settings :
939
+
940
+ if _debugger_settings .save_engine_profile :
941
+ if settings .use_python_runtime :
942
+ if _debugger_settings .profile_format == "trex" :
943
+ logger .warning (
944
+ "Profiling with TREX can only be enabled when using the C++ runtime. Python runtime profiling only support cudagraph visualization."
945
+ )
946
+ trt_module .enable_profiling ()
947
+ else :
948
+ path = os .path .join (
949
+ _debugger_settings .logging_dir , "engine_visualization"
950
+ )
951
+ os .makedirs (path , exist_ok = True )
952
+ trt_module .enable_profiling (
953
+ profiling_results_dir = path ,
954
+ profile_format = _debugger_settings .profile_format ,
955
+ )
956
+
957
+ if _debugger_settings .save_layer_info :
958
+ with open (
959
+ os .path .join (
960
+ _debugger_settings .logging_dir , "engine_layer_info.json"
961
+ ),
962
+ "w" ,
963
+ ) as f :
964
+ f .write (trt_module .get_layer_info ())
965
+
924
966
# Parse the graph I/O and store it in dryrun tracker
925
967
parse_graph_io (gm , dryrun_tracker )
926
968
@@ -948,7 +990,7 @@ def convert_exported_program_to_serialized_trt_engine(
948
990
enabled_precisions : (
949
991
Set [torch .dtype | dtype ] | Tuple [torch .dtype | dtype ]
950
992
) = _defaults .ENABLED_PRECISIONS ,
951
- debug : bool = _defaults . DEBUG ,
993
+ debug : bool = False ,
952
994
assume_dynamic_shape_support : bool = _defaults .ASSUME_DYNAMIC_SHAPE_SUPPORT ,
953
995
workspace_size : int = _defaults .WORKSPACE_SIZE ,
954
996
min_block_size : int = _defaults .MIN_BLOCK_SIZE ,
@@ -1051,7 +1093,11 @@ def convert_exported_program_to_serialized_trt_engine(
1051
1093
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
1052
1094
"""
1053
1095
if debug :
1054
- set_log_level (logger .parent , logging .DEBUG )
1096
+ warnings .warn (
1097
+ "`debug` is deprecated. Please use `torch_tensorrt.dynamo.Debugger` to configure debugging options." ,
1098
+ DeprecationWarning ,
1099
+ stacklevel = 2 ,
1100
+ )
1055
1101
1056
1102
if "truncate_long_and_double" in kwargs .keys ():
1057
1103
if truncate_double is not _defaults .TRUNCATE_DOUBLE :
0 commit comments