|
7 | 7 | # pyre-unsafe
|
8 | 8 |
|
9 | 9 | from collections import Counter
|
10 |
| -from typing import Dict, Tuple |
| 10 | +from typing import Tuple |
11 | 11 |
|
12 | 12 | import torch
|
13 | 13 | from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
|
|
31 | 31 | from torch.testing._internal.common_utils import (
|
32 | 32 | instantiate_parametrized_tests,
|
33 | 33 | TemporaryFileName,
|
34 |
| - TestCase, |
35 | 34 | )
|
36 | 35 | from torchao.quantization.pt2e import (
|
37 | 36 | allow_exported_model_train_eval,
|
38 | 37 | compare_results,
|
39 |
| - CUSTOM_KEY, |
40 | 38 | extract_results_from_loggers,
|
41 | 39 | generate_numeric_debug_handle,
|
42 |
| - NUMERIC_DEBUG_HANDLE_KEY, |
43 | 40 | prepare_for_propagation_comparison,
|
44 | 41 | )
|
45 | 42 |
|
46 |
| -from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process |
47 | 43 | from torchao.quantization.pt2e.quantize_pt2e import (
|
48 | 44 | convert_pt2e,
|
49 | 45 | prepare_pt2e,
|
50 | 46 | prepare_qat_pt2e,
|
51 | 47 | )
|
52 | 48 | from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
|
53 | 49 | from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer
|
54 |
| -from torchao.testing.pt2e.utils import PT2EQuantizationTestCase |
| 50 | +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase, PT2ENumericDebuggerTestCase |
55 | 51 |
|
56 | 52 |
|
57 | 53 | class TestQuantizePT2E(PT2EQuantizationTestCase):
|
@@ -723,33 +719,7 @@ def test_save_load(self) -> None:
|
723 | 719 | instantiate_parametrized_tests(TestQuantizePT2E)
|
724 | 720 |
|
725 | 721 |
|
726 |
| -class TestNumericDebugger(TestCase): |
727 |
| - def _extract_debug_handles(self, model) -> Dict[str, int]: |
728 |
| - debug_handle_map: Dict[str, int] = {} |
729 |
| - |
730 |
| - def _extract_debug_handles_from_node(node: torch.fx.Node) -> None: |
731 |
| - nonlocal debug_handle_map |
732 |
| - if ( |
733 |
| - CUSTOM_KEY in node.meta |
734 |
| - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] |
735 |
| - ): |
736 |
| - debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ |
737 |
| - NUMERIC_DEBUG_HANDLE_KEY |
738 |
| - ] |
739 |
| - |
740 |
| - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) |
741 |
| - return debug_handle_map |
742 |
| - |
743 |
| - def _assert_each_node_has_debug_handle(self, model) -> None: |
744 |
| - def _assert_node_has_debug_handle(node: torch.fx.Node) -> None: |
745 |
| - self.assertTrue( |
746 |
| - CUSTOM_KEY in node.meta |
747 |
| - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY], |
748 |
| - f"Node {node} doesn't have debug handle", |
749 |
| - ) |
750 |
| - |
751 |
| - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) |
752 |
| - |
| 722 | +class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase): |
753 | 723 | def test_quantize_pt2e_preserve_handle(self) -> None:
|
754 | 724 | m = TestHelperModules.Conv2dThenConv1d()
|
755 | 725 | example_inputs = m.example_inputs()
|
|
0 commit comments