
Commit c8af8a9

Add tests & update xfail_sets

Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent: aff80fe

File tree (3 files changed: +52 −20 lines)

    projects/pt1/e2e_testing/xfail_sets.py
    projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
    projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py

projects/pt1/e2e_testing/xfail_sets.py (6 additions & 0 deletions)

@@ -840,6 +840,8 @@
     "ReflectionPad3dModuleRight_basic",
     "ReflectionPad3dModuleFront_basic",
     "ReflectionPad3dModuleBack_basic",
+    "ReplicationPad1dModule_2DInput_basic",
+    "ReplicationPad1dModule_3DInput_basic",
     "ReplicationPad2dModule_basic",
     "ReplicationPad2dModule_bottom0",
     "ReplicationPad2dModule_left0",
@@ -3896,6 +3898,8 @@
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "ReplicationPad1dModule_2DInput_basic",
+    "ReplicationPad1dModule_3DInput_basic",
 }

 ONNX_TOSA_CRASHING_SET = {
@@ -4725,6 +4729,8 @@
     "ReshapeCollapseModule_basic",
     "ReshapeDynamicModule_basic",
     "ReshapeExpandModule_basic",
+    "ReplicationPad1dModule_2DInput_basic",
+    "ReplicationPad1dModule_3DInput_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",

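For context on what these entries do: each string added above is the name of an e2e test case, and membership in one of these sets tells the test harness that a failure (or crash) on the corresponding backend is expected rather than a regression. The snippet below is a minimal, generic sketch of that pattern, assuming a hypothetical `classify` helper and a stand-in `XFAIL_SET`; it is not the actual torch-mlir harness code.

    # Generic illustration of how an expected-failure set is consumed by a test
    # runner. `XFAIL_SET` and `classify` are hypothetical stand-ins, not
    # torch-mlir internals.
    XFAIL_SET = {
        "ReplicationPad1dModule_2DInput_basic",
        "ReplicationPad1dModule_3DInput_basic",
    }

    def classify(name: str, passed: bool) -> str:
        """Classify one test outcome against the expected-failure set."""
        if name in XFAIL_SET:
            # An expected failure does not fail the run; an unexpected pass
            # signals that the entry can be removed from the set.
            return "XFAIL" if not passed else "XPASS (remove from set)"
        return "PASS" if passed else "FAIL"

    print(classify("ReplicationPad1dModule_2DInput_basic", passed=False))  # XFAIL
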
projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py (0 additions & 20 deletions)

@@ -685,26 +685,6 @@ def ReplicationPad2dModule_left0(module, tu: TestUtils):
 # ==============================================================================


-class ReplicationPad1dModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args(
-        [
-            None,
-            ([-1, -1, -1], torch.float32, True),
-        ]
-    )
-    def forward(self, x):
-        return torch.ops.aten.replication_pad1d(x, [3, 5])
-
-
-@register_test_case(module_factory=lambda: ReplicationPad1dModule())
-def ReplicationPad1dModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 15, 20, low=-1))
-
-
 class ReplicationPad2dModule_right0_module(torch.nn.Module):
     def __init__(self):
         super().__init__()
projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py (46 additions & 0 deletions)

@@ -13,6 +13,52 @@
 # ==============================================================================


+class ReplicationPad1dModule_3DInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.replication_pad1d(x, [3, 5])
+
+
+@register_test_case(module_factory=lambda: ReplicationPad1dModule_3DInput())
+def ReplicationPad1dModule_3DInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 15, 20, low=-1))
+
+
+# ==============================================================================
+
+
+class ReplicationPad1dModule_2DInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.replication_pad1d(x, [2, 3])
+
+
+@register_test_case(module_factory=lambda: ReplicationPad1dModule_2DInput())
+def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(7, 12, low=-1))
+
+
+# ==============================================================================
+
+
 class ReflectionPad2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

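As a sanity check on what the two new cases exercise: `torch.ops.aten.replication_pad1d` pads the last dimension of its input by replicating the edge values, so the 3-D test pads a (1, 15, 20) tensor with [3, 5] to shape (1, 15, 28), and the 2-D test pads a (7, 12) tensor with [2, 3] to shape (7, 17). The snippet below is a small standalone reproduction in plain PyTorch, outside the e2e harness; the shape and edge-value checks reflect the general replication-padding semantics rather than anything specific to this commit.

    # Standalone check of the behavior the new tests cover, using plain PyTorch
    # (independent of the torch-mlir e2e harness).
    import torch

    x3d = torch.rand(1, 15, 20)
    y3d = torch.ops.aten.replication_pad1d(x3d, [3, 5])
    print(y3d.shape)  # torch.Size([1, 15, 28]) -- last dim grows by 3 + 5

    x2d = torch.rand(7, 12)
    y2d = torch.ops.aten.replication_pad1d(x2d, [2, 3])
    print(y2d.shape)  # torch.Size([7, 17])

    # The padded values replicate the edge elements of the last dimension.
    assert torch.all(y2d[:, :2] == x2d[:, :1])    # left pad copies column 0
    assert torch.all(y2d[:, -3:] == x2d[:, -1:])  # right pad copies the last column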