Commit 212da9f

eigen-k authored and facebook-github-bot committed
Set pyre-strict for passes unit tests. (pytorch#11740)
Summary: Pull Request resolved: pytorch#11740

This diff fixes this test failure: https://www.internalfb.com/intern/test/844425134145806 ... and also ensures the same problem won't pop up for other unit tests.

Differential Revision: D76767018
1 parent 3a6c664 commit 212da9f
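
For context, `# pyre-strict` makes Pyre treat every unannotated definition in a module as an error, whereas `# pyre-unsafe` roughly only checks what is already annotated; that is why the diffs below add parameter and return types to the test methods and helpers. A minimal sketch of what strict mode expects from a test module (the module and test below are illustrative, not part of this commit):

# pyre-strict
#
# Under strict mode every definition needs explicit annotations;
# `def test_cat_shape(self):` without `-> None` would be flagged.

import unittest
from typing import List

import torch


class ExampleStrictModeTest(unittest.TestCase):
    def test_cat_shape(self) -> None:
        x_shape: List[int] = [2, 3]
        y_shape: List[int] = [4, 3]
        out = torch.cat([torch.ones(*x_shape), torch.ones(*y_shape)], dim=0)
        self.assertEqual(list(out.shape), [6, 3])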

File tree

7 files changed: +361 -271 lines changed


backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 112 additions & 100 deletions
Large diffs are not rendered by default.

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 43 additions & 13 deletions
@@ -4,11 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
+# pyre-strict
 
 import math
 import unittest
-from typing import cast, Optional
+from typing import cast, List, Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -224,11 +224,11 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
     # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
     def run_memory_planning(
         self,
-        original,
-        opt_level=2,
-        mem_algo=1,  # greedy_by_size_for_offset_calculation_with_hierarchy
-        alloc_graph_input=True,
-        alloc_graph_output=True,
+        original: GraphModule,
+        opt_level: int = 2,
+        mem_algo: int = 1,  # greedy_by_size_for_offset_calculation_with_hierarchy
+        alloc_graph_input: bool = True,
+        alloc_graph_output: bool = True,
         memory_config: Optional[MemoryConfig] = None,
     ) -> GraphModule:
         if memory_config is None:
@@ -242,6 +242,7 @@ def run_memory_planning(
             alloc_graph_output=alloc_graph_output,
         )(graph_module).graph_module
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             [
@@ -259,7 +260,11 @@ def run_memory_planning(
         ]
     )
     def test_optimize_cat_on_placeholders(
-        self, x_shape, y_shape, concat_dim, alloc_graph_input
+        self,
+        x_shape: List[int],
+        y_shape: List[int],
+        concat_dim: int,
+        alloc_graph_input: bool,
     ) -> None:
         concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]]
         builder = GraphBuilder()
@@ -294,7 +299,12 @@ def test_optimize_cat_on_placeholders(
     # "add_add_cat_model" : cat(x + 123, y + 456)
     # "add_add_cat_add_model": cat(x + 123, y + 456) + 789
     def get_graph_module(
-        self, model_name, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        model_name: str,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> GraphModule:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32))
@@ -346,6 +356,7 @@ def get_graph_module(
 
         raise ValueError(f"Unknown model name {model_name}")
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -366,7 +377,12 @@ def get_graph_module(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_optimized(
-        self, _, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim
@@ -379,6 +395,7 @@ def test_cat_optimized(
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -392,7 +409,12 @@ def test_cat_optimized(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_not_optimized(
-        self, _, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim
@@ -404,6 +426,7 @@ def test_cat_not_optimized(
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -426,7 +449,13 @@ def test_cat_not_optimized(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_not_graph_output(
-        self, _, x_shape, y_shape, concated_shape, concat_dim, expected_cat_nodes
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
+        expected_cat_nodes: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_add_model", x_shape, y_shape, concated_shape, concat_dim
@@ -493,13 +522,14 @@ def test_optimize_cat_with_slice(self) -> None:
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1)
         self.verify_nop_memory_alloc(graph_module)
 
+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (True,),  # alloc_graph_input
             (False,),  # alloc_graph_input
         ],
     )
-    def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input) -> None:
+    def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input: bool) -> None:
         x_shape = [5, 6]
         y_shape = [3, 6]
         concated_shape = [8, 6]
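
Pyre reports error [56] ("Invalid decoration") when it cannot type-check a decorator, for example one coming from an untyped library such as `parameterized`; the suppressions added above each `@parameterized.expand` appear to address exactly that. A minimal sketch of the pattern, assuming the `parameterized` package is installed (the test class and cases below are hypothetical, not from this commit):

import unittest
from typing import List

from parameterized import parameterized


class ExampleCatShapeTest(unittest.TestCase):
    # pyre-ignore[56]: `parameterized.expand` comes from an untyped library,
    # so strict mode cannot check the decoration and it is suppressed locally.
    @parameterized.expand(
        [
            ([4, 6], [2, 6], 0),
            ([3, 5], [3, 7], 1),
        ]
    )
    def test_shapes_match_off_concat_dim(
        self,
        x_shape: List[int],
        y_shape: List[int],
        concat_dim: int,
    ) -> None:
        # Inputs to cat must agree on every dimension except concat_dim.
        for dim, (x_size, y_size) in enumerate(zip(x_shape, y_shape)):
            if dim != concat_dim:
                self.assertEqual(x_size, y_size)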

backends/cadence/aot/tests/test_pass_filter.py

Lines changed: 13 additions & 10 deletions
@@ -4,13 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
+# pyre-strict
 
 
 import unittest
-
 from copy import deepcopy
 
+from typing import Callable, Dict
+
 from executorch.backends.cadence.aot import pass_utils
 from executorch.backends.cadence.aot.pass_utils import (
     ALL_CADENCE_PASSES,
@@ -23,24 +24,26 @@
 
 
 class TestBase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         # Before running each test, create a copy of _all_passes to later restore it after test.
         # This avoids messing up the original _all_passes when running tests.
         self._all_passes_original = deepcopy(ALL_CADENCE_PASSES)
         # Clear _all_passes to do a clean test. It'll be restored after each test in tearDown().
         pass_utils.ALL_CADENCE_PASSES.clear()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         # Restore _all_passes to original state before test.
         pass_utils.ALL_CADENCE_PASSES = self._all_passes_original
 
-    def get_filtered_passes(self, filter_):
+    def get_filtered_passes(
+        self, filter_: Callable[[ExportPass], bool]
+    ) -> Dict[ExportPass, CadencePassAttribute]:
         return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)}
 
 
 # Test pass registration
 class TestPassRegistration(TestBase):
-    def test_register_cadence_pass(self):
+    def test_register_cadence_pass(self) -> None:
         pass_attr_O0 = CadencePassAttribute(opt_level=0)
         pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True)
         pass_attr_O1_all_backends = CadencePassAttribute(
@@ -73,7 +76,7 @@ class DummyPass_Debug(ExportPass):
 
 # Test pass filtering
 class TestPassFiltering(TestBase):
-    def test_filter_none(self):
+    def test_filter_none(self) -> None:
         pass_attr_O0 = CadencePassAttribute(opt_level=0)
         pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
         pass_attr_O1_all_backends = CadencePassAttribute(
@@ -103,7 +106,7 @@ class DummyPass_O1_All_Backends(ExportPass):
         }
         self.assertEqual(O1_filter_passes, expected_passes)
 
-    def test_filter_debug(self):
+    def test_filter_debug(self) -> None:
         pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
         pass_attr_O2 = CadencePassAttribute(opt_level=2)
 
@@ -122,7 +125,7 @@ class DummyPass_O2(ExportPass):
         # chooses debug=False.
         self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2})
 
-    def test_filter_all(self):
+    def test_filter_all(self) -> None:
         @register_cadence_pass(CadencePassAttribute(opt_level=1))
         class DummyPass_O1(ExportPass):
             pass
@@ -138,7 +141,7 @@ class DummyPass_O2(ExportPass):
         # passes with opt_level <= 0
         self.assertEqual(debug_filter_passes, {})
 
-    def test_filter_opt_level_None(self):
+    def test_filter_opt_level_None(self) -> None:
         pass_attr_O1 = CadencePassAttribute(opt_level=1)
         pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True)
 
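The typed `get_filtered_passes` helper above applies a predicate to each registered pass class and returns the matching subset of `ALL_CADENCE_PASSES`. A rough usage sketch, assuming the registry maps pass classes to `CadencePassAttribute` objects carrying an `opt_level` field as the tests suggest (the filter below is illustrative, not taken from this commit):

from executorch.backends.cadence.aot.pass_utils import ALL_CADENCE_PASSES

# Illustrative filter in the same shape as get_filtered_passes: keep only
# passes registered at opt_level 0 or 1 (passes with opt_level=None are skipped).
low_opt_passes = {
    cls: attr
    for cls, attr in ALL_CADENCE_PASSES.items()
    if attr.opt_level is not None and attr.opt_level <= 1
}
print(f"{len(low_opt_passes)} Cadence passes registered at opt_level <= 1")
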
0 commit comments

Comments
 (0)