12
12
13
13
import torch
14
14
from torch .testing ._internal .common_quantization import TestHelperModules
15
- from torch .testing ._internal .common_utils import IS_WINDOWS , TestCase , run_tests
15
+ from torch .testing ._internal .common_utils import IS_WINDOWS , run_tests
16
16
17
17
from torchao .quantization .pt2e import (
18
- CUSTOM_KEY ,
19
- NUMERIC_DEBUG_HANDLE_KEY ,
20
- compare_results ,
21
- extract_results_from_loggers ,
22
18
generate_numeric_debug_handle ,
23
19
prepare_for_propagation_comparison ,
24
20
)
25
- from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
26
- from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
27
- from torchao .testing .pt2e ._xnnpack_quantizer import (
28
- XNNPACKQuantizer ,
29
- get_symmetric_quantization_config ,
30
- )
21
+ from torchao .testing .pt2e .utils import PT2ENumericDebuggerTestCase
31
22
from torchao .utils import TORCH_VERSION_AT_LEAST_2_7
32
23
33
24
if TORCH_VERSION_AT_LEAST_2_7 :
36
27
37
28
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_7 , "Requires torch 2.7+" )
38
29
@unittest .skipIf (IS_WINDOWS , "Windows not yet supported for torch.compile" )
39
- class TestNumericDebugger (TestCase ):
40
- def _assert_each_node_has_debug_handle (self , model ) -> None :
41
- def _assert_node_has_debug_handle (node ):
42
- self .assertTrue (
43
- CUSTOM_KEY in node .meta
44
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
45
- f"Node { node } doesn't have debug handle" ,
46
- )
47
-
48
- bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
49
-
50
- def _extract_debug_handles (self , model ) -> dict [str , int ]:
51
- debug_handle_map : dict [str , int ] = {}
52
-
53
- def _extract_debug_handles_from_node (node ):
54
- nonlocal debug_handle_map
55
- if (
56
- CUSTOM_KEY in node .meta
57
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
58
- ):
59
- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
60
- NUMERIC_DEBUG_HANDLE_KEY
61
- ]
62
-
63
- bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
64
-
65
- return debug_handle_map
66
-
67
- def _extract_debug_handles_with_prev_decomp_op (self , model ) -> dict [str , int ]:
68
- prev_decomp_op_to_debug_handle_map : dict [str , int ] = {}
69
-
70
- def _extract_debug_handles_with_prev_decomp_op_from_node (node ):
71
- nonlocal prev_decomp_op_to_debug_handle_map
72
- if (
73
- CUSTOM_KEY in node .meta
74
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
75
- ):
76
- prev_decomp_op = str (node .meta .get ("nn_module_stack" ))
77
- debug_handle = node .meta [CUSTOM_KEY ][NUMERIC_DEBUG_HANDLE_KEY ]
78
- if prev_decomp_op not in prev_decomp_op_to_debug_handle_map :
79
- prev_decomp_op_to_debug_handle_map [prev_decomp_op ] = debug_handle
80
- else :
81
- assert (
82
- prev_decomp_op_to_debug_handle_map [prev_decomp_op ]
83
- == debug_handle
84
- ), f"Node { node } has different debug handle { debug_handle } "
85
- "than previous node sharing the same decomp op {prev_decomp_op}"
86
-
87
- bfs_trace_with_node_process (
88
- model , _extract_debug_handles_with_prev_decomp_op_from_node
89
- )
90
- return prev_decomp_op_to_debug_handle_map
91
-
30
+ class TestNumericDebuggerInfra (PT2ENumericDebuggerTestCase ):
92
31
@unittest .skip (
93
32
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
94
33
)
@@ -113,36 +52,6 @@ def test_control_flow(self):
113
52
114
53
self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
115
54
116
- def test_quantize_pt2e_preserve_handle (self ):
117
- m = TestHelperModules .Conv2dThenConv1d ()
118
- example_inputs = m .example_inputs ()
119
- ep = export_for_training (m , example_inputs , strict = True )
120
- generate_numeric_debug_handle (ep )
121
- m = ep .module ()
122
-
123
- quantizer = XNNPACKQuantizer ().set_global (
124
- get_symmetric_quantization_config (is_per_channel = False )
125
- )
126
- m = prepare_pt2e (m , quantizer )
127
- debug_handle_map = self ._extract_debug_handles (m )
128
- res_counter = Counter (debug_handle_map .values ())
129
- repeated_debug_handle_ids = [1 , 2 , 3 ]
130
- # 3 ids were repeated because we copy over the id from node to its output observer
131
- # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
132
- for dh_id in repeated_debug_handle_ids :
133
- self .assertEqual (res_counter [dh_id ], 2 )
134
-
135
- m (* example_inputs )
136
- m = convert_pt2e (m )
137
- self ._assert_each_node_has_debug_handle (ep )
138
- debug_handle_map = self ._extract_debug_handles (m )
139
- res_counter = Counter (debug_handle_map .values ())
140
- # same set of ids were repeated, because we copy over the id from observer/fake_quant to
141
- # dequantize node
142
- repeated_debug_handle_ids = [1 , 2 , 3 ]
143
- for dh_id in repeated_debug_handle_ids :
144
- self .assertEqual (res_counter [dh_id ], 2 )
145
-
146
55
def test_copy_preserve_handle (self ):
147
56
m = TestHelperModules .Conv2dThenConv1d ()
148
57
example_inputs = m .example_inputs ()
@@ -262,61 +171,6 @@ def test_prepare_for_propagation_comparison(self):
262
171
self .assertTrue ("conv2d" in [logger .node_name for logger in loggers ])
263
172
self .assertEqual (res , ref )
264
173
265
- def test_extract_results_from_loggers (self ):
266
- m = TestHelperModules .Conv2dThenConv1d ()
267
- example_inputs = m .example_inputs ()
268
- ep = export_for_training (m , example_inputs , strict = True )
269
- generate_numeric_debug_handle (ep )
270
- m = ep .module ()
271
- m_ref_logger = prepare_for_propagation_comparison (m )
272
-
273
- quantizer = XNNPACKQuantizer ().set_global (
274
- get_symmetric_quantization_config (is_per_channel = False )
275
- )
276
- m = prepare_pt2e (m , quantizer )
277
- m (* example_inputs )
278
- m = convert_pt2e (m )
279
- m_quant_logger = prepare_for_propagation_comparison (m )
280
-
281
- m_ref_logger (* example_inputs )
282
- m_quant_logger (* example_inputs )
283
- ref_results = extract_results_from_loggers (m_ref_logger )
284
- quant_results = extract_results_from_loggers (m_quant_logger )
285
- comparison_results = compare_results (ref_results , quant_results )
286
- for node_summary in comparison_results .values ():
287
- if len (node_summary .results ) > 0 :
288
- self .assertGreaterEqual (node_summary .results [0 ].sqnr , 35 )
289
-
290
- def test_extract_results_from_loggers_list_output (self ):
291
- m = TestHelperModules .Conv2dWithSplit ()
292
- example_inputs = m .example_inputs ()
293
- ep = export_for_training (m , example_inputs , strict = True )
294
- generate_numeric_debug_handle (ep )
295
- m = ep .module ()
296
- m_ref_logger = prepare_for_propagation_comparison (m )
297
-
298
- quantizer = XNNPACKQuantizer ().set_global (
299
- get_symmetric_quantization_config (is_per_channel = False )
300
- )
301
- m = prepare_pt2e (m , quantizer )
302
- m (* example_inputs )
303
- m = convert_pt2e (m )
304
- m_quant_logger = prepare_for_propagation_comparison (m )
305
-
306
- m_ref_logger (* example_inputs )
307
- m_quant_logger (* example_inputs )
308
- ref_results = extract_results_from_loggers (m_ref_logger )
309
- quant_results = extract_results_from_loggers (m_quant_logger )
310
- comparison_results = compare_results (ref_results , quant_results )
311
- for node_summary in comparison_results .values ():
312
- if len (node_summary .results ) > 0 :
313
- sqnr = node_summary .results [0 ].sqnr
314
- if isinstance (sqnr , list ):
315
- for sqnr_i in sqnr :
316
- self .assertGreaterEqual (sqnr_i , 35 )
317
- else :
318
- self .assertGreaterEqual (sqnr , 35 )
319
-
320
174
def test_added_node_gets_unique_id (self ) -> None :
321
175
m = TestHelperModules .Conv2dThenConv1d ()
322
176
example_inputs = m .example_inputs ()
0 commit comments