diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 0970c9d9dd2a..7e71c8cfbd2f 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -468,6 +468,7 @@ class ConvertAtenMaxPoolOp : public ConvertAtenOp { SmallVector stablehloDilation(inputRank, 1); SmallVector stablehloKernelSize(inputRank, 1); SmallVector stablehloPadding(inputRank * 2, 0); + SmallVector ceilModePadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), stablehloDilation.begin() + inputRank - Dim); std::copy(stride.begin(), stride.end(), @@ -520,6 +521,8 @@ class ConvertAtenMaxPoolOp : public ConvertAtenOp { const int64_t extraPadding = sizeDiff * stride[i]; stablehloPadding[frontPadIdx] += extraPadding / 2; stablehloPadding[backPadIdx] += extraPadding - extraPadding / 2; + ceilModePadding[frontPadIdx] += extraPadding / 2; + ceilModePadding[backPadIdx] += extraPadding - extraPadding / 2; } } } @@ -539,6 +542,19 @@ class ConvertAtenMaxPoolOp : public ConvertAtenOp { op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); + // *** ADD THE 'ceil_mode' ATTRIBUTE *** + // The 'ceilMode' boolean variable was extracted from the PyTorch op + // earlier. + mlir::BoolAttr ceilModeAttr = rewriter.getBoolAttr(ceilMode); + reduceWindowOp->setAttr(llvm::StringRef("ceil_mode"), ceilModeAttr); + DenseIntElementsAttr ceilModePad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + stablehloPadding); + // Also add the 'ceil_mode_padding' attribute to the op to distinguish + // original padding from the extra padding added for ceil_mode. + reduceWindowOp->setAttr(llvm::StringRef("ceil_mode_padding"), ceilModePad); Block &block = reduceWindowOp.getBody().emplaceBlock(); // Add bb argument