diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index cf6bae43fed6..d857f7427753 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -225,6 +225,10 @@ template <> struct DimensionTraits<AtenMaxPool1dOp> {
   static_assert(Dim == Dim);
 };
 
+template <>
+struct DimensionTraits<AtenMaxPool1dWithIndicesOp>
+    : DimensionTraits<AtenMaxPool1dOp> {};
+
 template <> struct DimensionTraits<AtenMaxPool2dOp> {
   static constexpr int64_t Dim = 2;
   // unused const variable warning suppression:
@@ -250,7 +254,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   static const bool withIndices =
-      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
+      llvm::is_one_of<OpTy, AtenMaxPool1dWithIndicesOp,
+                      AtenMaxPool2dWithIndicesOp,
                       AtenMaxPool3dWithIndicesOp>::value;
 
 private:
@@ -1687,8 +1692,11 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);
 
+  target.addIllegalOp<AtenMaxPool1dWithIndicesOp>();
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
+  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dWithIndicesOp>>(typeConverter,
+                                                                 context);
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
                                                                  context);
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp
index 0970c9d9dd2a..915a58413e46 100644
--- a/lib/Conversion/TorchToStablehlo/Pooling.cpp
+++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -132,6 +132,29 @@ LogicalResult ConvertAtenOp<AtenMaxPool1dOp>::matchAndRewrite(
   stablehloPadding[stablehloPadding.size() - 1] = padding[0];
   stablehloPadding[stablehloPadding.size() - 2] = padding[0];
 
+  if (ceilMode) {
+    // Match PyTorch output shape with extra padding. See
+    // https://github.com/pytorch/pytorch/blob/c5de6ff079e3e5b453d6ff5190c90f02db458928/aten/src/ATen/native/Pool.h#L79
+    const int64_t inputSize = inputShape[inputRank - 1];
+    const int64_t numerator =
+        (inputSize + 2 * padding[0] - dilation[0] * (kernelSize[0] - 1) - 1);
+    const int64_t floor_output_size = (numerator) / stride[0] + 1;
+    const int64_t adj = (stride[0] - 1);
+    int64_t ceil_output_size = std::ceil((numerator + adj) / stride[0]) + 1;
+
+    // Ensure last pooling starts inside input
+    if ((ceil_output_size - 1) * stride[0] >= inputSize + padding[0]) {
+      ceil_output_size--;
+    }
+
+    // Add extra padding to make output size same as torch
+    if (ceil_output_size > floor_output_size) {
+      const int64_t sizeDiff = ceil_output_size - floor_output_size;
+      const int64_t extraPadding = sizeDiff * stride[0];
+      stablehloPadding[stablehloPadding.size() - 1] += extraPadding;
+    }
+  }
+
   Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
 
   auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 9c355f4ea4a8..8e8c7af4d132 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8395,6 +8395,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+"    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %1 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12872,6 +12877,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"    return %1 : !torch.tuple<int, int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3b5f651c8903..ce2add95dad0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2995,6 +2995,8 @@
     "LogCumsumExpModule_basic",
     "LogCumsumExpStaticNegativeDimModule_basic",
     "LogCumsumExpStaticFloat64DtypeModule_basic",
+    "MaxPool1dWithIndicesModule_basic",
+    "MaxPool1dWithIndicesCeilModeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -3724,6 +3726,8 @@
     "LogCumsumExpStaticNegativeDimModule_basic",
     "LogCumsumExpStaticFloat64DtypeModule_basic",
     "MaskedScatterStaticBasic_basic",
+    "MaxPool1dWithIndicesModule_basic",
+    "MaxPool1dWithIndicesCeilModeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -4533,6 +4537,8 @@
     "Matmul_4d",
     "Matmul_matvec",
     "Matmul_vecmat",
+    "MaxPool1dWithIndicesModule_basic",
+    "MaxPool1dWithIndicesCeilModeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index a069550ec669..e46217fb7e05 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1258,6 +1258,10 @@ def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: L
 def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
     return pool1d(self, kernel_size, stride, padding, ceil_mode)
 
+def aten〇max_pool1d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
+    maxpool1d = indices = pool1d(self, kernel_size, stride, padding, ceil_mode)
+    return maxpool1d, indices
+
 def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
     return adaptive_avg_pool1d(self, output_size)
 
@@ -3530,6 +3534,10 @@ def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇max_pool1d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[int, int]:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype, torch.int64
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
 def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 9ef3cffb2193..dc4ae2efee81 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -180,6 +180,55 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class MaxPool1dWithIndicesModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool1d_with_indices(
+            x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=False
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool1dWithIndicesModule())
+def MaxPool1dWithIndicesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 64, 112, low=-1))
+
+
+class MaxPool1dWithIndicesCeilModeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool1d_with_indices(
+            x, kernel_size=[4], stride=[2], padding=[2], dilation=2, ceil_mode=True
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool1dWithIndicesCeilModeModule())
+def MaxPool1dWithIndicesCeilModeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 25, 37, low=-1))
+
+
+# ==============================================================================
+
+
 class MaxPool1dModule(torch.nn.Module):
     def __init__(self):
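
A quick standalone sketch (not part of the patch; the helper name pool1d_output_size is ours): the ceilMode block added to the StableHLO lowering reproduces PyTorch's pooling output-size arithmetic from Pool.h, including the rule that the last pooling window must start inside the input or its left padding. The check below exercises the exact kernel/stride/padding/dilation values used by the new e2e test modules.

import torch

def pool1d_output_size(input_size, kernel, stride, padding, dilation, ceil_mode):
    # Same quantities as the C++ block: numerator, floor_output_size, adj.
    numerator = input_size + 2 * padding - dilation * (kernel - 1) - 1
    floor_output_size = numerator // stride + 1
    if not ceil_mode:
        return floor_output_size
    adj = stride - 1
    ceil_output_size = (numerator + adj) // stride + 1
    # Ensure the last pooling window starts inside the input (or left padding),
    # matching the decrement in the patch.
    if (ceil_output_size - 1) * stride >= input_size + padding:
        ceil_output_size -= 1
    return ceil_output_size

# (112, k=6, s=2, p=3, d=2) and (37, k=4, s=2, p=2, d=2, ceil) mirror the
# MaxPool1dWithIndices e2e tests above.
for size, k, s, p, d, cm in [(112, 6, 2, 3, 2, False), (37, 4, 2, 2, 2, True)]:
    out, indices = torch.ops.aten.max_pool1d_with_indices(
        torch.randn(1, 1, size), [k], [s], [p], [d], cm
    )
    assert out.shape[-1] == pool1d_output_size(size, k, s, p, d, cm)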
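
Continuing the snippet above, a minimal check of the contract the new shape and dtype functions encode: both results of max_pool1d_with_indices share one shape, and the indices dtype is int64 (dtype code 4, the %int4 constant in the generated dtype function).

# Values and indices have identical shapes; indices are always int64.
vals, idx = torch.ops.aten.max_pool1d_with_indices(
    torch.randn(3, 25, 37), [4], [2], [2], [2], True
)
assert vals.shape == idx.shape
assert idx.dtype == torch.int64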