@@ -852,35 +852,37 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic(
- // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
+ // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 7
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
- // CHECK: %[[VAL_6:.*]] = torch.constant.none
- // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
- // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
- // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
- // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
- // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_10]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
- // CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
- // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
- // CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
- // CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
- // CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
- // CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,1,1],f32>
+ // CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+ // CHECK: %[[VAL_7:.*]] = torch.constant.none
+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
+ // CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
+ // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ // CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
+ // CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
+ // CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
+ // CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32>
// CHECK: }
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
%int7 = torch.constant.int 7
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
+ %true = torch.constant.bool true
%none = torch.constant.none
%kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
- %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
+ %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
return %0 : !torch.vtensor<[1,512,1,1],f32>
}

@@ -1999,42 +2001,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
// -----
- func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
- %int0 = torch.constant.int 0
- %int1 = torch.constant.int 1
- %int3 = torch.constant.int 3
- %false = torch.constant.bool false
- %count_include_pad = torch.constant.bool true
- %divisor_override = torch.constant.none
-
- %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
- %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
- %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
- %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
- return %3 : !torch.vtensor<[1,192,35,35],f32>
- }
-
- // -----
-
- func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
- %int0 = torch.constant.int 0
- %int1 = torch.constant.int 1
- %int3 = torch.constant.int 3
- %false = torch.constant.bool false
- %count_include_pad = torch.constant.bool false
- %divisor_override = torch.constant.int 9
-
- %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
- %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
- %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
- %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32>
- return %3 : !torch.vtensor<[1,192,35,35],f32>
- }
-
- // -----
-
// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch.constant.int 0
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false