1212
1313import torch
1414from torch .testing ._internal .common_quantization import TestHelperModules
15- from torch .testing ._internal .common_utils import IS_WINDOWS , TestCase , run_tests
15+ from torch .testing ._internal .common_utils import IS_WINDOWS , run_tests
1616
1717from torchao .quantization .pt2e import (
18- CUSTOM_KEY ,
19- NUMERIC_DEBUG_HANDLE_KEY ,
20- compare_results ,
21- extract_results_from_loggers ,
2218 generate_numeric_debug_handle ,
2319 prepare_for_propagation_comparison ,
2420)
25- from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
26- from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
27- from torchao .testing .pt2e ._xnnpack_quantizer import (
28- XNNPACKQuantizer ,
29- get_symmetric_quantization_config ,
30- )
21+ from torchao .testing .pt2e .utils import PT2ENumericDebuggerTestCase
3122from torchao .utils import TORCH_VERSION_AT_LEAST_2_7
3223
3324if TORCH_VERSION_AT_LEAST_2_7 :
3627
3728@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_7 , "Requires torch 2.7+" )
3829@unittest .skipIf (IS_WINDOWS , "Windows not yet supported for torch.compile" )
39- class TestNumericDebugger (TestCase ):
40- def _assert_each_node_has_debug_handle (self , model ) -> None :
41- def _assert_node_has_debug_handle (node ):
42- self .assertTrue (
43- CUSTOM_KEY in node .meta
44- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
45- f"Node { node } doesn't have debug handle" ,
46- )
47-
48- bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
49-
50- def _extract_debug_handles (self , model ) -> dict [str , int ]:
51- debug_handle_map : dict [str , int ] = {}
52-
53- def _extract_debug_handles_from_node (node ):
54- nonlocal debug_handle_map
55- if (
56- CUSTOM_KEY in node .meta
57- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
58- ):
59- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
60- NUMERIC_DEBUG_HANDLE_KEY
61- ]
62-
63- bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
64-
65- return debug_handle_map
66-
67- def _extract_debug_handles_with_prev_decomp_op (self , model ) -> dict [str , int ]:
68- prev_decomp_op_to_debug_handle_map : dict [str , int ] = {}
69-
70- def _extract_debug_handles_with_prev_decomp_op_from_node (node ):
71- nonlocal prev_decomp_op_to_debug_handle_map
72- if (
73- CUSTOM_KEY in node .meta
74- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
75- ):
76- prev_decomp_op = str (node .meta .get ("nn_module_stack" ))
77- debug_handle = node .meta [CUSTOM_KEY ][NUMERIC_DEBUG_HANDLE_KEY ]
78- if prev_decomp_op not in prev_decomp_op_to_debug_handle_map :
79- prev_decomp_op_to_debug_handle_map [prev_decomp_op ] = debug_handle
80- else :
81- assert (
82- prev_decomp_op_to_debug_handle_map [prev_decomp_op ]
83- == debug_handle
84- ), f"Node { node } has different debug handle { debug_handle } "
85- "than previous node sharing the same decomp op {prev_decomp_op}"
86-
87- bfs_trace_with_node_process (
88- model , _extract_debug_handles_with_prev_decomp_op_from_node
89- )
90- return prev_decomp_op_to_debug_handle_map
91-
30+ class TestNumericDebuggerInfra (PT2ENumericDebuggerTestCase ):
9231 @unittest .skip (
9332 "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
9433 )
@@ -113,36 +52,6 @@ def test_control_flow(self):
11352
11453 self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
11554
116- def test_quantize_pt2e_preserve_handle (self ):
117- m = TestHelperModules .Conv2dThenConv1d ()
118- example_inputs = m .example_inputs ()
119- ep = export_for_training (m , example_inputs , strict = True )
120- generate_numeric_debug_handle (ep )
121- m = ep .module ()
122-
123- quantizer = XNNPACKQuantizer ().set_global (
124- get_symmetric_quantization_config (is_per_channel = False )
125- )
126- m = prepare_pt2e (m , quantizer )
127- debug_handle_map = self ._extract_debug_handles (m )
128- res_counter = Counter (debug_handle_map .values ())
129- repeated_debug_handle_ids = [1 , 2 , 3 ]
130- # 3 ids were repeated because we copy over the id from node to its output observer
131- # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
132- for dh_id in repeated_debug_handle_ids :
133- self .assertEqual (res_counter [dh_id ], 2 )
134-
135- m (* example_inputs )
136- m = convert_pt2e (m )
137- self ._assert_each_node_has_debug_handle (ep )
138- debug_handle_map = self ._extract_debug_handles (m )
139- res_counter = Counter (debug_handle_map .values ())
140- # same set of ids where repeated, because we copy over the id from observer/fake_quant to
141- # dequantize node
142- repeated_debug_handle_ids = [1 , 2 , 3 ]
143- for dh_id in repeated_debug_handle_ids :
144- self .assertEqual (res_counter [dh_id ], 2 )
145-
14655 def test_copy_preserve_handle (self ):
14756 m = TestHelperModules .Conv2dThenConv1d ()
14857 example_inputs = m .example_inputs ()
@@ -262,61 +171,6 @@ def test_prepare_for_propagation_comparison(self):
262171 self .assertTrue ("conv2d" in [logger .node_name for logger in loggers ])
263172 self .assertEqual (res , ref )
264173
265- def test_extract_results_from_loggers (self ):
266- m = TestHelperModules .Conv2dThenConv1d ()
267- example_inputs = m .example_inputs ()
268- ep = export_for_training (m , example_inputs , strict = True )
269- generate_numeric_debug_handle (ep )
270- m = ep .module ()
271- m_ref_logger = prepare_for_propagation_comparison (m )
272-
273- quantizer = XNNPACKQuantizer ().set_global (
274- get_symmetric_quantization_config (is_per_channel = False )
275- )
276- m = prepare_pt2e (m , quantizer )
277- m (* example_inputs )
278- m = convert_pt2e (m )
279- m_quant_logger = prepare_for_propagation_comparison (m )
280-
281- m_ref_logger (* example_inputs )
282- m_quant_logger (* example_inputs )
283- ref_results = extract_results_from_loggers (m_ref_logger )
284- quant_results = extract_results_from_loggers (m_quant_logger )
285- comparison_results = compare_results (ref_results , quant_results )
286- for node_summary in comparison_results .values ():
287- if len (node_summary .results ) > 0 :
288- self .assertGreaterEqual (node_summary .results [0 ].sqnr , 35 )
289-
290- def test_extract_results_from_loggers_list_output (self ):
291- m = TestHelperModules .Conv2dWithSplit ()
292- example_inputs = m .example_inputs ()
293- ep = export_for_training (m , example_inputs , strict = True )
294- generate_numeric_debug_handle (ep )
295- m = ep .module ()
296- m_ref_logger = prepare_for_propagation_comparison (m )
297-
298- quantizer = XNNPACKQuantizer ().set_global (
299- get_symmetric_quantization_config (is_per_channel = False )
300- )
301- m = prepare_pt2e (m , quantizer )
302- m (* example_inputs )
303- m = convert_pt2e (m )
304- m_quant_logger = prepare_for_propagation_comparison (m )
305-
306- m_ref_logger (* example_inputs )
307- m_quant_logger (* example_inputs )
308- ref_results = extract_results_from_loggers (m_ref_logger )
309- quant_results = extract_results_from_loggers (m_quant_logger )
310- comparison_results = compare_results (ref_results , quant_results )
311- for node_summary in comparison_results .values ():
312- if len (node_summary .results ) > 0 :
313- sqnr = node_summary .results [0 ].sqnr
314- if isinstance (sqnr , list ):
315- for sqnr_i in sqnr :
316- self .assertGreaterEqual (sqnr_i , 35 )
317- else :
318- self .assertGreaterEqual (sqnr , 35 )
319-
320174 def test_added_node_gets_unique_id (self ) -> None :
321175 m = TestHelperModules .Conv2dThenConv1d ()
322176 example_inputs = m .example_inputs ()
0 commit comments