
[TorchToLinalg] Enable lowering of AtenMaxPool1dWithIndicesOp to linalg dialect #4215
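For context, `aten.max_pool1d_with_indices` is the 1-D max-pooling variant that returns both the pooled values and the int64 index of each selected input element. A minimal eager-mode illustration of the op this patch lowers (example inputs assumed, not taken from the patch):

```python
import torch

x = torch.tensor([[[1.0, 3.0, 2.0, 5.0, 4.0, 0.0]]])  # (N=1, C=1, L=6)
values, indices = torch.ops.aten.max_pool1d_with_indices(
    x, kernel_size=[2], stride=[2], padding=[0], dilation=[1], ceil_mode=False
)
# values  -> [[[3., 5., 4.]]]  (max over each length-2 window)
# indices -> [[[1, 3, 4]]]     (int64 positions of the maxima in the input)
```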


Open · wants to merge 4 commits into main
10 changes: 9 additions & 1 deletion lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -225,6 +225,10 @@ template <> struct DimensionTraits<AtenMaxPool1dOp> {
static_assert(Dim == Dim);
};

template <>
struct DimensionTraits<AtenMaxPool1dWithIndicesOp>
: DimensionTraits<AtenMaxPool1dOp> {};

template <> struct DimensionTraits<AtenMaxPool2dOp> {
static constexpr int64_t Dim = 2;
// unused const variable warning suppression:
@@ -250,7 +254,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;

static const bool withIndices =
llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
llvm::is_one_of<OpTy, AtenMaxPool1dWithIndicesOp,
AtenMaxPool2dWithIndicesOp,
AtenMaxPool3dWithIndicesOp>::value;

private:
@@ -1687,8 +1692,11 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);

target.addIllegalOp<AtenMaxPool1dWithIndicesOp>();
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dWithIndicesOp>>(typeConverter,
context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
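With the pattern registered and the op marked illegal, the indices variant can take the same Linalg lowering path as the plain max-pool ops. A hedged end-to-end sketch, assuming torch-mlir's FX importer frontend (`torch_mlir.fx.export_and_import`):

```python
import torch
from torch_mlir import fx  # torch-mlir's FX-based importer frontend

class Pool(torch.nn.Module):
    def forward(self, x):
        # Returns (pooled values, int64 indices).
        return torch.ops.aten.max_pool1d_with_indices(x, [2], [2])

# Lower to the linalg-on-tensors backend contract; before this patch the
# conversion would fail on the then-unhandled illegal aten op.
module = fx.export_and_import(
    Pool(), torch.randn(1, 4, 8), output_type="linalg-on-tensors"
)
print(module)
```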
23 changes: 23 additions & 0 deletions lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -132,6 +132,29 @@ LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[0];

if (ceilMode) {
// Match PyTorch output shape with extra padding. See
// https://github.com/pytorch/pytorch/blob/c5de6ff079e3e5b453d6ff5190c90f02db458928/aten/src/ATen/native/Pool.h#L79
const int64_t inputSize = inputShape[inputRank - 1];
const int64_t numerator =
(inputSize + 2 * padding[0] - dilation[0] * (kernelSize[0] - 1) - 1);
const int64_t floor_output_size = numerator / stride[0] + 1;
// Integer ceiling division: adding (stride - 1) before dividing rounds up,
// so no floating-point std::ceil is needed.
const int64_t adj = stride[0] - 1;
int64_t ceil_output_size = (numerator + adj) / stride[0] + 1;

// Ensure last pooling starts inside input
if ((ceil_output_size - 1) * stride[0] >= inputSize + padding[0]) {
ceil_output_size--;
}

// Add extra padding to make output size same as torch
if (ceil_output_size > floor_output_size) {
const int64_t sizeDiff = ceil_output_size - floor_output_size;
const int64_t extraPadding = sizeDiff * stride[0];
stablehloPadding[stablehloPadding.size() - 1] += extraPadding;
}
}

Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
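The extra-padding computation above can be sanity-checked in isolation. A standalone Python sketch of the same arithmetic (hypothetical helper, not part of the patch):

```python
def extra_ceil_padding(input_size, kernel, stride, padding, dilation):
    # Mirrors the C++ logic: compare floor- and ceil-mode output sizes and
    # return how many extra columns of right padding reduce_window needs so
    # its output length matches PyTorch's ceil_mode result.
    numerator = input_size + 2 * padding - dilation * (kernel - 1) - 1
    floor_out = numerator // stride + 1
    ceil_out = -(-numerator // stride) + 1  # integer ceiling division
    # Drop the last window if it would start past the (left-padded) input.
    if (ceil_out - 1) * stride >= input_size + padding:
        ceil_out -= 1
    return max(0, (ceil_out - floor_out) * stride)

# input 10, kernel 3, stride 2, no padding, dilation 1:
# floor mode yields 4 windows, ceil mode 5, so 2 extra padded columns.
assert extra_ceil_padding(10, 3, 2, 0, 1) == 2
```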
11 changes: 11 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8395,6 +8395,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12872,6 +12877,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
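In the generated dtype function above, `%int4` is PyTorch's ScalarType code for `torch.int64`, so the second tuple element pins the indices dtype to int64 while the values keep the input dtype. A quick check of that contract against eager PyTorch (illustrative snippet, not part of the patch):

```python
import torch

# max_pool1d_with_indices returns (values, indices); the indices are always
# int64, which is what ScalarType code 4 encodes above.
vals, idx = torch.ops.aten.max_pool1d_with_indices(torch.randn(1, 2, 8), [2])
assert vals.dtype == torch.float32 and idx.dtype == torch.int64
```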
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2995,6 +2995,8 @@
"LogCumsumExpModule_basic",
"LogCumsumExpStaticNegativeDimModule_basic",
"LogCumsumExpStaticFloat64DtypeModule_basic",
"MaxPool1dWithIndicesModule_basic",
"MaxPool1dWithIndicesCeilModeModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
@@ -3724,6 +3726,8 @@
"LogCumsumExpStaticNegativeDimModule_basic",
"LogCumsumExpStaticFloat64DtypeModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dWithIndicesModule_basic",
"MaxPool1dWithIndicesCeilModeModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
@@ -4533,6 +4537,8 @@
"Matmul_4d",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool1dWithIndicesModule_basic",
"MaxPool1dWithIndicesCeilModeModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1258,6 +1258,10 @@ def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: L
def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
return pool1d(self, kernel_size, stride, padding, ceil_mode)

def aten〇max_pool1d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
maxpool1d = indices = pool1d(self, kernel_size, stride, padding, ceil_mode)
return maxpool1d, indices

def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
return adaptive_avg_pool1d(self, output_size)

@@ -3530,6 +3534,10 @@ def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis
self_rank, self_dtype = self_rank_dtype
return self_dtype

def aten〇max_pool1d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
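The shape function reuses `pool1d` and returns the same list for both results, since the values and indices tensors always share a shape. A standalone sanity check against eager PyTorch (illustrative only, shapes assumed):

```python
import torch

# Values and indices always share one shape, which is why the shape function
# returns the pool1d result twice as a tuple.
vals, idx = torch.ops.aten.max_pool1d_with_indices(
    torch.randn(2, 3, 16), kernel_size=[4], stride=[2], padding=[1]
)
# L_out = (16 + 2*1 - 1*(4-1) - 1) // 2 + 1 = 8
assert vals.shape == idx.shape == (2, 3, 8)
```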
49 changes: 49 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -180,6 +180,55 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
# ==============================================================================


class MaxPool1dWithIndicesModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.max_pool1d_with_indices(
x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=False
)


@register_test_case(module_factory=lambda: MaxPool1dWithIndicesModule())
def MaxPool1dWithIndicesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 64, 112, low=-1))
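For this configuration, PyTorch's pooling formula gives an output length of 54. A quick standalone check outside the e2e harness (illustrative; the input only approximates `tu.rand`):

```python
import torch

x = torch.rand(1, 64, 112) * 2 - 1  # roughly tu.rand(1, 64, 112, low=-1)
vals, idx = torch.ops.aten.max_pool1d_with_indices(
    x, kernel_size=[6], stride=[2], padding=[3], dilation=[2], ceil_mode=False
)
# L_out = (112 + 2*3 - 2*(6-1) - 1) // 2 + 1 = 54
assert vals.shape == idx.shape == (1, 64, 54)
```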


class MaxPool1dWithIndicesCeilModeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.max_pool1d_with_indices(
x, kernel_size=[4], stride=[2], padding=[2], dilation=2, ceil_mode=True
)


@register_test_case(module_factory=lambda: MaxPool1dWithIndicesCeilModeModule())
def MaxPool1dWithIndicesCeilModeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 25, 37, low=-1))
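The same arithmetic for the ceil-mode test, worked out in a standalone snippet (illustrative; the input only approximates `tu.rand`):

```python
import torch

x = torch.rand(3, 25, 37) * 2 - 1  # roughly tu.rand(3, 25, 37, low=-1)
vals, idx = torch.ops.aten.max_pool1d_with_indices(
    x, kernel_size=[4], stride=[2], padding=[2], dilation=[2], ceil_mode=True
)
# numerator = 37 + 2*2 - 2*(4-1) - 1 = 34; ceil(34 / 2) + 1 = 18
assert vals.shape == idx.shape == (3, 25, 18)
```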


# ==============================================================================


class MaxPool1dModule(torch.nn.Module):

def __init__(self):