 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
+# pyre-strict

 import math
 import unittest
-from typing import cast, Optional
+from typing import cast, List, Optional

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -224,11 +224,11 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
     # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
     def run_memory_planning(
         self,
-        original,
-        opt_level=2,
-        mem_algo=1,  # greedy_by_size_for_offset_calculation_with_hierarchy
-        alloc_graph_input=True,
-        alloc_graph_output=True,
+        original: GraphModule,
+        opt_level: int = 2,
+        mem_algo: int = 1,  # greedy_by_size_for_offset_calculation_with_hierarchy
+        alloc_graph_input: bool = True,
+        alloc_graph_output: bool = True,
         memory_config: Optional[MemoryConfig] = None,
     ) -> GraphModule:
         if memory_config is None:
@@ -242,6 +242,7 @@ def run_memory_planning(
             alloc_graph_output=alloc_graph_output,
         )(graph_module).graph_module

+    # pyre-ignore[56]
     @parameterized.expand(
         [
             [
@@ -259,7 +260,11 @@ def run_memory_planning(
         ]
     )
     def test_optimize_cat_on_placeholders(
-        self, x_shape, y_shape, concat_dim, alloc_graph_input
+        self,
+        x_shape: List[int],
+        y_shape: List[int],
+        concat_dim: int,
+        alloc_graph_input: bool,
     ) -> None:
         concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]]
         builder = GraphBuilder()
@@ -294,7 +299,12 @@ def test_optimize_cat_on_placeholders(
     # "add_add_cat_model" : cat(x + 123, y + 456)
     # "add_add_cat_add_model": cat(x + 123, y + 456) + 789
     def get_graph_module(
-        self, model_name, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        model_name: str,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> GraphModule:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32))
@@ -346,6 +356,7 @@ def get_graph_module(

         raise ValueError(f"Unknown model name {model_name}")

+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -366,7 +377,12 @@ def get_graph_module(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_optimized(
-        self, _, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim
@@ -379,6 +395,7 @@ def test_cat_optimized(
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -392,7 +409,12 @@ def test_cat_optimized(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_not_optimized(
-        self, _, x_shape, y_shape, concated_shape, concat_dim
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim
@@ -404,6 +426,7 @@ def test_cat_not_optimized(
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (
@@ -426,7 +449,13 @@ def test_cat_not_optimized(
         name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}",
     )
     def test_cat_not_graph_output(
-        self, _, x_shape, y_shape, concated_shape, concat_dim, expected_cat_nodes
+        self,
+        _,
+        x_shape: List[int],
+        y_shape: List[int],
+        concated_shape: List[int],
+        concat_dim: int,
+        expected_cat_nodes: int,
     ) -> None:
         original = self.get_graph_module(
             "add_add_cat_add_model", x_shape, y_shape, concated_shape, concat_dim
@@ -493,13 +522,14 @@ def test_optimize_cat_with_slice(self) -> None:
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1)
         self.verify_nop_memory_alloc(graph_module)

+    # pyre-ignore[56]
     @parameterized.expand(
         [
             (True,),  # alloc_graph_input
             (False,),  # alloc_graph_input
         ],
     )
-    def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input) -> None:
+    def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input: bool) -> None:
         x_shape = [5, 6]
         y_shape = [3, 6]
         concated_shape = [8, 6]