From 5d2f12429ce39f1d402ee5e83216481aab2fdb37 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Fri, 30 May 2025 00:53:42 +0530
Subject: [PATCH] Enable onnx MaxPool1D with indices lowering

Signed-off-by: Zahid Wakeel
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 10 +++--
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 42 +++++++++++++++++++
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 542df9ee4c7b..1a730a0475ed 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1276,9 +1276,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
           return failure();
 
-        if (rank == 3)
-          return rewriter.notifyMatchFailure(
-              binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp");
+        if (rank == 3) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dWithIndicesOp>(
+              binder.op, resultTypeOut, resultTypeIndices, operand,
+              kernelSizeList, stridesList, paddingList, dilationsList,
+              cstCeilMode);
+          return success();
+        }
 
         if (rank == 4) {
           rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index ad18724df52a..a336d78f55dd 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -659,6 +659,48 @@ func.func @test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5]
 
 // -----
 
+// CHECK-LABEL: func.func @test_maxpool_1d_indices_default
+func.func @test_maxpool_1d_indices_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+// CHECK: %[[VAL_0:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31],f32>, !torch.vtensor<[93],ui64>
+// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,31],f32>
+// CHECK: }
+  %0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,31],f32>, !torch.vtensor<[93], ui64>)
+  return %0#0 : !torch.vtensor<[1,3,31],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_maxpool_1d_indices_ceil_pad_stride(
+func.func @test_maxpool_1d_indices_ceil_pad_stride(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,16],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,16],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.int 5
+  // CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_8:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],ui64>
+  // CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,16],f32>
+  // CHECK: }
+  %0:2 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [5 : si64], torch.onnx.pads = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> (!torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48], ui64>)
+  return %0#0 : !torch.vtensor<[1,3,16],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_maxpool_2d_default
 func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
 // CHECK: %[[I2:.*]] = torch.constant.int 2
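
A quick way to exercise the new FileCheck cases locally is to pipe the updated test file through the conversion pass. This is only a sketch and not part of the patch: it assumes a torch-mlir build with torch-mlir-opt and FileCheck on PATH, and the authoritative flag spelling is the test file's own RUN line.

  torch-mlir-opt test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir \
      -convert-torch-onnx-to-torch -split-input-file \
      | FileCheck test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir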