Skip to content

Commit f8b7287

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
deduplicate torch ao debugger tests between pytorch/ao and ExecuTorch (#11735)
Summary: Pull Request resolved: #11735 This diff deduplicates numeric debugging tests on XnnPack quantizer between torchao and ExecuTorch. Differential Revision: D76634915
1 parent 057558f commit f8b7287

File tree

1 file changed

+3
-33
lines changed

1 file changed

+3
-33
lines changed

backends/xnnpack/test/quantizer/test_pt2e_quantization.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88

99
from collections import Counter
10-
from typing import Dict, Tuple
10+
from typing import Tuple
1111

1212
import torch
1313
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
@@ -31,27 +31,23 @@
3131
from torch.testing._internal.common_utils import (
3232
instantiate_parametrized_tests,
3333
TemporaryFileName,
34-
TestCase,
3534
)
3635
from torchao.quantization.pt2e import (
3736
allow_exported_model_train_eval,
3837
compare_results,
39-
CUSTOM_KEY,
4038
extract_results_from_loggers,
4139
generate_numeric_debug_handle,
42-
NUMERIC_DEBUG_HANDLE_KEY,
4340
prepare_for_propagation_comparison,
4441
)
4542

46-
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
4743
from torchao.quantization.pt2e.quantize_pt2e import (
4844
convert_pt2e,
4945
prepare_pt2e,
5046
prepare_qat_pt2e,
5147
)
5248
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
5349
from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer
54-
from torchao.testing.pt2e.utils import PT2EQuantizationTestCase
50+
from torchao.testing.pt2e.utils import PT2EQuantizationTestCase, PT2ENumericDebuggerTestCase
5551

5652

5753
class TestQuantizePT2E(PT2EQuantizationTestCase):
@@ -723,33 +719,7 @@ def test_save_load(self) -> None:
723719
instantiate_parametrized_tests(TestQuantizePT2E)
724720

725721

726-
class TestNumericDebugger(TestCase):
727-
def _extract_debug_handles(self, model) -> Dict[str, int]:
728-
debug_handle_map: Dict[str, int] = {}
729-
730-
def _extract_debug_handles_from_node(node: torch.fx.Node) -> None:
731-
nonlocal debug_handle_map
732-
if (
733-
CUSTOM_KEY in node.meta
734-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
735-
):
736-
debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
737-
NUMERIC_DEBUG_HANDLE_KEY
738-
]
739-
740-
bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
741-
return debug_handle_map
742-
743-
def _assert_each_node_has_debug_handle(self, model) -> None:
744-
def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
745-
self.assertTrue(
746-
CUSTOM_KEY in node.meta
747-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
748-
f"Node {node} doesn't have debug handle",
749-
)
750-
751-
bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
752-
722+
class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase):
753723
def test_quantize_pt2e_preserve_handle(self) -> None:
754724
m = TestHelperModules.Conv2dThenConv1d()
755725
example_inputs = m.example_inputs()

0 commit comments

Comments
 (0)