Skip to content

Commit 7e7ea92

Browse files
authored
deduplicate torch ao debugger tests between pytorch/ao and ExecuTorch (#2390)
Summary: X-link: pytorch/executorch#11735. This diff deduplicates the numeric debugger tests for the XNNPACK quantizer between torchao and ExecuTorch. Reviewed By: jerryzh168. Differential Revision: D76634915.
1 parent 63a91d7 commit 7e7ea92

File tree

2 files changed

+74
-150
lines changed

2 files changed

+74
-150
lines changed

test/quantization/pt2e/test_numeric_debugger.py

Lines changed: 3 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,13 @@
1212

1313
import torch
1414
from torch.testing._internal.common_quantization import TestHelperModules
15-
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, run_tests
15+
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
1616

1717
from torchao.quantization.pt2e import (
18-
CUSTOM_KEY,
19-
NUMERIC_DEBUG_HANDLE_KEY,
20-
compare_results,
21-
extract_results_from_loggers,
2218
generate_numeric_debug_handle,
2319
prepare_for_propagation_comparison,
2420
)
25-
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
26-
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
27-
from torchao.testing.pt2e._xnnpack_quantizer import (
28-
XNNPACKQuantizer,
29-
get_symmetric_quantization_config,
30-
)
21+
from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
3122
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
3223

3324
if TORCH_VERSION_AT_LEAST_2_7:
@@ -36,59 +27,7 @@
3627

3728
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
3829
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
39-
class TestNumericDebugger(TestCase):
40-
def _assert_each_node_has_debug_handle(self, model) -> None:
41-
def _assert_node_has_debug_handle(node):
42-
self.assertTrue(
43-
CUSTOM_KEY in node.meta
44-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
45-
f"Node {node} doesn't have debug handle",
46-
)
47-
48-
bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
49-
50-
def _extract_debug_handles(self, model) -> dict[str, int]:
51-
debug_handle_map: dict[str, int] = {}
52-
53-
def _extract_debug_handles_from_node(node):
54-
nonlocal debug_handle_map
55-
if (
56-
CUSTOM_KEY in node.meta
57-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
58-
):
59-
debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
60-
NUMERIC_DEBUG_HANDLE_KEY
61-
]
62-
63-
bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
64-
65-
return debug_handle_map
66-
67-
def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
68-
prev_decomp_op_to_debug_handle_map: dict[str, int] = {}
69-
70-
def _extract_debug_handles_with_prev_decomp_op_from_node(node):
71-
nonlocal prev_decomp_op_to_debug_handle_map
72-
if (
73-
CUSTOM_KEY in node.meta
74-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
75-
):
76-
prev_decomp_op = str(node.meta.get("nn_module_stack"))
77-
debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
78-
if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
79-
prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
80-
else:
81-
assert (
82-
prev_decomp_op_to_debug_handle_map[prev_decomp_op]
83-
== debug_handle
84-
), f"Node {node} has different debug handle {debug_handle}"
85-
"than previous node sharing the same decomp op {prev_decomp_op}"
86-
87-
bfs_trace_with_node_process(
88-
model, _extract_debug_handles_with_prev_decomp_op_from_node
89-
)
90-
return prev_decomp_op_to_debug_handle_map
91-
30+
class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
9231
@unittest.skip(
9332
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
9433
)
@@ -113,36 +52,6 @@ def test_control_flow(self):
11352

11453
self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
11554

116-
def test_quantize_pt2e_preserve_handle(self):
117-
m = TestHelperModules.Conv2dThenConv1d()
118-
example_inputs = m.example_inputs()
119-
ep = export_for_training(m, example_inputs, strict=True)
120-
generate_numeric_debug_handle(ep)
121-
m = ep.module()
122-
123-
quantizer = XNNPACKQuantizer().set_global(
124-
get_symmetric_quantization_config(is_per_channel=False)
125-
)
126-
m = prepare_pt2e(m, quantizer)
127-
debug_handle_map = self._extract_debug_handles(m)
128-
res_counter = Counter(debug_handle_map.values())
129-
repeated_debug_handle_ids = [1, 2, 3]
130-
# 3 ids were repeated because we copy over the id from node to its output observer
131-
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
132-
for dh_id in repeated_debug_handle_ids:
133-
self.assertEqual(res_counter[dh_id], 2)
134-
135-
m(*example_inputs)
136-
m = convert_pt2e(m)
137-
self._assert_each_node_has_debug_handle(ep)
138-
debug_handle_map = self._extract_debug_handles(m)
139-
res_counter = Counter(debug_handle_map.values())
140-
# the same set of ids were repeated, because we copy over the id from observer/fake_quant to
141-
# dequantize node
142-
repeated_debug_handle_ids = [1, 2, 3]
143-
for dh_id in repeated_debug_handle_ids:
144-
self.assertEqual(res_counter[dh_id], 2)
145-
14655
def test_copy_preserve_handle(self):
14756
m = TestHelperModules.Conv2dThenConv1d()
14857
example_inputs = m.example_inputs()
@@ -262,61 +171,6 @@ def test_prepare_for_propagation_comparison(self):
262171
self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
263172
self.assertEqual(res, ref)
264173

265-
def test_extract_results_from_loggers(self):
266-
m = TestHelperModules.Conv2dThenConv1d()
267-
example_inputs = m.example_inputs()
268-
ep = export_for_training(m, example_inputs, strict=True)
269-
generate_numeric_debug_handle(ep)
270-
m = ep.module()
271-
m_ref_logger = prepare_for_propagation_comparison(m)
272-
273-
quantizer = XNNPACKQuantizer().set_global(
274-
get_symmetric_quantization_config(is_per_channel=False)
275-
)
276-
m = prepare_pt2e(m, quantizer)
277-
m(*example_inputs)
278-
m = convert_pt2e(m)
279-
m_quant_logger = prepare_for_propagation_comparison(m)
280-
281-
m_ref_logger(*example_inputs)
282-
m_quant_logger(*example_inputs)
283-
ref_results = extract_results_from_loggers(m_ref_logger)
284-
quant_results = extract_results_from_loggers(m_quant_logger)
285-
comparison_results = compare_results(ref_results, quant_results)
286-
for node_summary in comparison_results.values():
287-
if len(node_summary.results) > 0:
288-
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
289-
290-
def test_extract_results_from_loggers_list_output(self):
291-
m = TestHelperModules.Conv2dWithSplit()
292-
example_inputs = m.example_inputs()
293-
ep = export_for_training(m, example_inputs, strict=True)
294-
generate_numeric_debug_handle(ep)
295-
m = ep.module()
296-
m_ref_logger = prepare_for_propagation_comparison(m)
297-
298-
quantizer = XNNPACKQuantizer().set_global(
299-
get_symmetric_quantization_config(is_per_channel=False)
300-
)
301-
m = prepare_pt2e(m, quantizer)
302-
m(*example_inputs)
303-
m = convert_pt2e(m)
304-
m_quant_logger = prepare_for_propagation_comparison(m)
305-
306-
m_ref_logger(*example_inputs)
307-
m_quant_logger(*example_inputs)
308-
ref_results = extract_results_from_loggers(m_ref_logger)
309-
quant_results = extract_results_from_loggers(m_quant_logger)
310-
comparison_results = compare_results(ref_results, quant_results)
311-
for node_summary in comparison_results.values():
312-
if len(node_summary.results) > 0:
313-
sqnr = node_summary.results[0].sqnr
314-
if isinstance(sqnr, list):
315-
for sqnr_i in sqnr:
316-
self.assertGreaterEqual(sqnr_i, 35)
317-
else:
318-
self.assertGreaterEqual(sqnr, 35)
319-
320174
def test_added_node_gets_unique_id(self) -> None:
321175
m = TestHelperModules.Conv2dThenConv1d()
322176
example_inputs = m.example_inputs()

torchao/testing/pt2e/utils.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import copy
88
import unittest
9+
from typing import Dict
910

1011
import torch
1112
from torch.ao.quantization.backend_config import (
@@ -19,13 +20,19 @@
1920
NodeSpec,
2021
QuantizationTestCase,
2122
)
23+
from torch.testing._internal.common_utils import TestCase
2224

25+
from torchao.quantization.pt2e import (
26+
CUSTOM_KEY,
27+
NUMERIC_DEBUG_HANDLE_KEY,
28+
)
29+
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
2330
from torchao.quantization.pt2e.quantize_pt2e import (
2431
convert_pt2e,
2532
prepare_pt2e,
2633
prepare_qat_pt2e,
2734
)
28-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
35+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7
2936

3037
if TORCH_VERSION_AT_LEAST_2_5:
3138
from torch.export import export_for_training
@@ -133,3 +140,66 @@ def _test_quantizer(
133140
fx_quant_output = m_fx(*example_inputs)
134141
self.assertEqual(fx_quant_output, pt2_quant_output)
135142
return m
143+
144+
145+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class PT2ENumericDebuggerTestCase(TestCase):
    """
    Base test case class for PT2E numeric debugger tests.

    Bundles the helper methods shared by numeric debugger test suites. All
    helpers look at the debug handle stored under
    ``node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]`` and visit nodes through
    ``bfs_trace_with_node_process`` (presumably a breadth-first walk that
    invokes the given callback once per graph node -- confirm against
    ``torchao.quantization.pt2e.graph_utils``).
    """

    @staticmethod
    def _has_debug_handle(node) -> bool:
        """Return True if ``node.meta`` carries a numeric debug handle."""
        return (
            CUSTOM_KEY in node.meta
            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
        )

    def _assert_each_node_has_debug_handle(self, model) -> None:
        """Fail the test if any visited node of ``model`` lacks a debug handle."""

        def _assert_node_has_debug_handle(node):
            self.assertTrue(
                self._has_debug_handle(node),
                f"Node {node} doesn't have debug handle",
            )

        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)

    def _extract_debug_handles(self, model) -> Dict[str, int]:
        """
        Collect the debug handles of all visited nodes of ``model``.

        Returns:
            A mapping from ``str(node)`` to that node's debug handle. Nodes
            without a debug handle are omitted.
        """
        debug_handle_map: Dict[str, int] = {}

        def _extract_debug_handles_from_node(node):
            if self._has_debug_handle(node):
                debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
                    NUMERIC_DEBUG_HANDLE_KEY
                ]

        bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
        return debug_handle_map

    def _extract_debug_handles_with_prev_decomp_op(self, model) -> Dict[str, int]:
        """
        Collect debug handles grouped by the op each node was decomposed from.

        Nodes are grouped by ``str(node.meta.get("nn_module_stack"))``; nodes
        without that meta entry all share the key ``"None"``. Every node in a
        group must carry the same debug handle.

        Returns:
            A mapping from the group key to the debug handle shared by the
            nodes in that group. Nodes without a debug handle are ignored.

        Raises:
            AssertionError: if two nodes in the same group have different
                debug handles.
        """
        prev_decomp_op_to_debug_handle_map: Dict[str, int] = {}

        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
            if not self._has_debug_handle(node):
                return

            prev_decomp_op = str(node.meta.get("nn_module_stack"))
            debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
            # Record the handle for a new group, or fetch the one already
            # recorded so it can be compared with this node's handle.
            recorded_handle = prev_decomp_op_to_debug_handle_map.setdefault(
                prev_decomp_op, debug_handle
            )
            if recorded_handle != debug_handle:
                # The whole message is one parenthesized f-string so that both
                # halves are reported and ``prev_decomp_op`` is interpolated.
                self.fail(
                    f"Node {node} has different debug handle {debug_handle} "
                    f"than previous node sharing the same decomp op {prev_decomp_op}"
                )

        bfs_trace_with_node_process(
            model, _extract_debug_handles_with_prev_decomp_op_from_node
        )
        return prev_decomp_op_to_debug_handle_map

0 commit comments

Comments
 (0)