
Commit 9188846

Revert "Fix torchToTosa lowering for avgpool2d to handle unsupported parameters (#3822)"
This reverts commit 7f9f99c.
Parent: 604d9a6

2 files changed (+16, -72 lines)

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 0 additions & 22 deletions
@@ -5466,28 +5466,6 @@ class ConvertAtenAvgPool2dOp
                      DenseI64ArrayAttr &kernel,
                      DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
                      Type &outputTy) const override {
-
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answers (SWA) when the `count_include_pad` value is `true.`
-    bool countIncludePad;
-    if (!matchPattern(op.getCountIncludePad(),
-                      m_TorchConstantBool(&countIncludePad)) ||
-        countIncludePad) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
-              "`count_include_pad` value should be `False`.");
-    }
-
-    // Currently, we can not represent `divisor_override` with the existing TOSA
-    // AvgPool2d specification. Without the below check, we produce silent wrong
-    // answers (SWA) when the `divisor_override` value is other than `None.`
-    if (!isa<Torch::NoneType>(op.getDivisorOverride().getType())) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `divisor_override` value, for tosa AvgPool2dOp "
-              "`divisor_override` value should be `None`.");
-    }
-
     SmallVector<int64_t, 2> dilationArray{1, 1};
     if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
                                                  tosa::AvgPool2dOp>(
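
For context on the checks being reverted: the deleted comments describe a "silent wrong answer" when `count_include_pad` is true. A minimal PyTorch sketch of that divergence (my illustration, not part of this commit): with nonzero padding, the flag changes the divisor used for edge windows.

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)

# Padded zeros counted in the average: edge outputs drop below 1.0.
a = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=True)
# Padded zeros excluded from the divisor: every output stays exactly 1.0.
b = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)

print(a[0, 0, 0, 0])  # 4/9 ~= 0.444: corner window covers 4 real elements, 5 pads
print(b[0, 0, 0, 0])  # 1.0: divides by the 4 real elements only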

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 16 additions & 50 deletions
@@ -852,35 +852,37 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
 // -----

 // CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 7
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0
 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false
-// CHECK: %[[VAL_6:.*]] = torch.constant.none
-// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_10]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
-// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
-// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
-// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
-// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
-// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,1,1],f32>
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
+// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
+// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
+// CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32>
 // CHECK: }
 func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
   %int7 = torch.constant.int 7
   %int1 = torch.constant.int 1
   %int0 = torch.constant.int 0
   %false = torch.constant.bool false
+  %true = torch.constant.bool true
   %none = torch.constant.none
   %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
   %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
   %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-  %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
+  %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
   return %0 : !torch.vtensor<[1,512,1,1],f32>
 }

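Worth noting about the updated test: it uses padding = 0, so there are no padded elements and `count_include_pad` cannot change the numeric result, which is consistent with the unchanged tosa.avg_pool2d expansion above. A quick PyTorch check of that equivalence (my illustration, not part of this commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 7, 7)
a = F.avg_pool2d(x, kernel_size=7, stride=1, padding=0, count_include_pad=True)
b = F.avg_pool2d(x, kernel_size=7, stride=1, padding=0, count_include_pad=False)
assert torch.equal(a, b)  # identical: the flag only matters when padding > 0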
@@ -1999,42 +2001,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso

 // -----

-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
-func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool false
-  %divisor_override = torch.constant.int 9
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> {
 // CHECK: %[[VAL_0:.*]] = torch.constant.int 0
 // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
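
The second deleted negative test pinned `divisor_override` to 9. For reference, PyTorch's semantics for that argument: each window is summed and divided by the override instead of by the window size. A small sketch (my illustration, not part of this commit):

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 5, 5)
out = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                   count_include_pad=False, divisor_override=9)

print(out[0, 0, 2, 2])  # 1.0: interior window sums 9 ones, divided by 9
print(out[0, 0, 0, 0])  # ~0.444: corner window sums 4 ones, still divided by 9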
