diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py
index 027d57d1b2..5f565767aa 100644
--- a/test/quantization/pt2e/test_numeric_debugger.py
+++ b/test/quantization/pt2e/test_numeric_debugger.py
@@ -12,22 +12,13 @@

 import torch
 from torch.testing._internal.common_quantization import TestHelperModules
-from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, run_tests
+from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

 from torchao.quantization.pt2e import (
-    CUSTOM_KEY,
-    NUMERIC_DEBUG_HANDLE_KEY,
-    compare_results,
-    extract_results_from_loggers,
     generate_numeric_debug_handle,
     prepare_for_propagation_comparison,
 )
-from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torchao.testing.pt2e._xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
-)
+from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

 if TORCH_VERSION_AT_LEAST_2_7:
@@ -36,59 +27,7 @@

 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
-class TestNumericDebugger(TestCase):
-    def _assert_each_node_has_debug_handle(self, model) -> None:
-        def _assert_node_has_debug_handle(node):
-            self.assertTrue(
-                CUSTOM_KEY in node.meta
-                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
-                f"Node {node} doesn't have debug handle",
-            )
-
-        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
-
-    def _extract_debug_handles(self, model) -> dict[str, int]:
-        debug_handle_map: dict[str, int] = {}
-
-        def _extract_debug_handles_from_node(node):
-            nonlocal debug_handle_map
-            if (
-                CUSTOM_KEY in node.meta
-                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
-            ):
-                debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
-                    NUMERIC_DEBUG_HANDLE_KEY
-                ]
-
-        bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
-
-        return debug_handle_map
-
-    def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
-        prev_decomp_op_to_debug_handle_map: dict[str, int] = {}
-
-        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
-            nonlocal prev_decomp_op_to_debug_handle_map
-            if (
-                CUSTOM_KEY in node.meta
-                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
-            ):
-                prev_decomp_op = str(node.meta.get("nn_module_stack"))
-                debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
-                if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
-                    prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
-                else:
-                    assert (
-                        prev_decomp_op_to_debug_handle_map[prev_decomp_op]
-                        == debug_handle
-                    ), f"Node {node} has different debug handle {debug_handle}"
-                    "than previous node sharing the same decomp op {prev_decomp_op}"
-
-        bfs_trace_with_node_process(
-            model, _extract_debug_handles_with_prev_decomp_op_from_node
-        )
-        return prev_decomp_op_to_debug_handle_map
-
+class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
     @unittest.skip(
         "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
     )
@@ -113,36 +52,6 @@ def test_control_flow(self):

         self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

-    def test_quantize_pt2e_preserve_handle(self):
-        m = TestHelperModules.Conv2dThenConv1d()
-        example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
-        generate_numeric_debug_handle(ep)
-        m = ep.module()
-
-        quantizer = XNNPACKQuantizer().set_global(
-            get_symmetric_quantization_config(is_per_channel=False)
-        )
-        m = prepare_pt2e(m, quantizer)
-        debug_handle_map = self._extract_debug_handles(m)
-        res_counter = Counter(debug_handle_map.values())
-        repeated_debug_handle_ids = [1, 2, 3]
-        # 3 ids were repeated because we copy over the id from node to its output observer
-        # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
-        for dh_id in repeated_debug_handle_ids:
-            self.assertEqual(res_counter[dh_id], 2)
-
-        m(*example_inputs)
-        m = convert_pt2e(m)
-        self._assert_each_node_has_debug_handle(ep)
-        debug_handle_map = self._extract_debug_handles(m)
-        res_counter = Counter(debug_handle_map.values())
-        # same set of ids where repeated, because we copy over the id from observer/fake_quant to
-        # dequantize node
-        repeated_debug_handle_ids = [1, 2, 3]
-        for dh_id in repeated_debug_handle_ids:
-            self.assertEqual(res_counter[dh_id], 2)
-
     def test_copy_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
@@ -262,61 +171,6 @@ def test_prepare_for_propagation_comparison(self):
         self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
         self.assertEqual(res, ref)

-    def test_extract_results_from_loggers(self):
-        m = TestHelperModules.Conv2dThenConv1d()
-        example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
-        generate_numeric_debug_handle(ep)
-        m = ep.module()
-        m_ref_logger = prepare_for_propagation_comparison(m)
-
-        quantizer = XNNPACKQuantizer().set_global(
-            get_symmetric_quantization_config(is_per_channel=False)
-        )
-        m = prepare_pt2e(m, quantizer)
-        m(*example_inputs)
-        m = convert_pt2e(m)
-        m_quant_logger = prepare_for_propagation_comparison(m)
-
-        m_ref_logger(*example_inputs)
-        m_quant_logger(*example_inputs)
-        ref_results = extract_results_from_loggers(m_ref_logger)
-        quant_results = extract_results_from_loggers(m_quant_logger)
-        comparison_results = compare_results(ref_results, quant_results)
-        for node_summary in comparison_results.values():
-            if len(node_summary.results) > 0:
-                self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
-
-    def test_extract_results_from_loggers_list_output(self):
-        m = TestHelperModules.Conv2dWithSplit()
-        example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
-        generate_numeric_debug_handle(ep)
-        m = ep.module()
-        m_ref_logger = prepare_for_propagation_comparison(m)
-
-        quantizer = XNNPACKQuantizer().set_global(
-            get_symmetric_quantization_config(is_per_channel=False)
-        )
-        m = prepare_pt2e(m, quantizer)
-        m(*example_inputs)
-        m = convert_pt2e(m)
-        m_quant_logger = prepare_for_propagation_comparison(m)
-
-        m_ref_logger(*example_inputs)
-        m_quant_logger(*example_inputs)
-        ref_results = extract_results_from_loggers(m_ref_logger)
-        quant_results = extract_results_from_loggers(m_quant_logger)
-        comparison_results = compare_results(ref_results, quant_results)
-        for node_summary in comparison_results.values():
-            if len(node_summary.results) > 0:
-                sqnr = node_summary.results[0].sqnr
-                if isinstance(sqnr, list):
-                    for sqnr_i in sqnr:
-                        self.assertGreaterEqual(sqnr_i, 35)
-                else:
-                    self.assertGreaterEqual(sqnr, 35)
-
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py
index ad49fec014..4342d81dc1 100644
--- a/torchao/testing/pt2e/utils.py
+++ b/torchao/testing/pt2e/utils.py
@@ -6,6 +6,7 @@

 import copy
 import unittest
+from typing import Dict

 import torch
 from torch.ao.quantization.backend_config import (
@@ -19,13 +20,19 @@
     NodeSpec,
     QuantizationTestCase,
 )
+from torch.testing._internal.common_utils import TestCase
+
+from torchao.quantization.pt2e import (
+    CUSTOM_KEY,
+    NUMERIC_DEBUG_HANDLE_KEY,
+)
+from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
     prepare_qat_pt2e,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7

 if TORCH_VERSION_AT_LEAST_2_5:
     from torch.export import export_for_training
@@ -133,3 +140,66 @@ def _test_quantizer(
         fx_quant_output = m_fx(*example_inputs)
         self.assertEqual(fx_quant_output, pt2_quant_output)
         return m
+
+
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
+class PT2ENumericDebuggerTestCase(TestCase):
+    """
+    Base test case class for PT2E numeric debugger tests containing common utility functions
+    for numeric debugging functionality.
+    """
+
+    def _assert_each_node_has_debug_handle(self, model) -> None:
+        """Assert that each node in the model has a debug handle."""
+
+        def _assert_node_has_debug_handle(node):
+            self.assertTrue(
+                CUSTOM_KEY in node.meta
+                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
+                f"Node {node} doesn't have debug handle",
+            )
+
+        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
+
+    def _extract_debug_handles(self, model) -> Dict[str, int]:
+        """Extract debug handles from all nodes in the model."""
+        debug_handle_map: Dict[str, int] = {}
+
+        def _extract_debug_handles_from_node(node):
+            nonlocal debug_handle_map
+            if (
+                CUSTOM_KEY in node.meta
+                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
+            ):
+                debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
+                    NUMERIC_DEBUG_HANDLE_KEY
+                ]
+
+        bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
+        return debug_handle_map
+
+    def _extract_debug_handles_with_prev_decomp_op(self, model) -> Dict[str, int]:
+        """Extract debug handles with previous decomposition operation mapping."""
+        prev_decomp_op_to_debug_handle_map: Dict[str, int] = {}
+
+        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
+            nonlocal prev_decomp_op_to_debug_handle_map
+            if (
+                CUSTOM_KEY in node.meta
+                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
+            ):
+                prev_decomp_op = str(node.meta.get("nn_module_stack"))
+                debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
+                if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
+                    prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
+                else:
+                    assert (
+                        prev_decomp_op_to_debug_handle_map[prev_decomp_op]
+                        == debug_handle
+                    ), f"Node {node} has different debug handle {debug_handle}"
+                    "than previous node sharing the same decomp op {prev_decomp_op}"
+
+        bfs_trace_with_node_process(
+            model, _extract_debug_handles_with_prev_decomp_op_from_node
+        )
+        return prev_decomp_op_to_debug_handle_map
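
For reference, the sketch below shows how a test module is intended to consume the shared `PT2ENumericDebuggerTestCase` base class introduced in `torchao/testing/pt2e/utils.py` above, mirroring the quantizer-agnostic pattern kept in `test_numeric_debugger.py` (export the model, call `generate_numeric_debug_handle`, then use the inherited `_assert_each_node_has_debug_handle` / `_extract_debug_handles` helpers). The module layout, the `ExampleNumericDebuggerTest` class name, and the test method name are illustrative assumptions and are not part of this diff.

```python
# Illustrative sketch only -- not part of this diff. Shows a downstream test
# module reusing the shared PT2ENumericDebuggerTestCase; class and test names
# here are hypothetical.
import unittest

from torch.export import export_for_training
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import run_tests

from torchao.quantization.pt2e import generate_numeric_debug_handle
from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class ExampleNumericDebuggerTest(PT2ENumericDebuggerTestCase):
    def test_debug_handles_assigned(self):
        # Same export + handle-generation pattern as the tests kept in
        # test_numeric_debugger.py.
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        # Helpers inherited from the new base class.
        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(ep)
        self.assertTrue(len(debug_handle_map) > 0)


if __name__ == "__main__":
    run_tests()
```

Because the helpers now live on the base class, only the quantizer-dependent tests (those built on `XNNPACKQuantizer`, e.g. `test_quantize_pt2e_preserve_handle` and the `extract_results_from_loggers` tests) are dropped from `test_numeric_debugger.py` in this diff; whether they are relocated elsewhere is not visible in these hunks.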