From 8153c6e871e47e00e79c26f768f866eda1ebf345 Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Thu, 24 Jul 2025 09:29:45 -0400 Subject: [PATCH] Refactor replication pad 1d/2d/3d to share implementation --- .../TorchToLinalg/TensorConstructors.cpp | 376 ++++-------------- 1 file changed, 76 insertions(+), 300 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index cdc4afde332f..2bdce2ab2d8c 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -26,6 +26,74 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +enum sliceLoc { START = 0, END = 1 }; + +static Value extractSlice(ConversionPatternRewriter &rewriter, Location loc, + Value input, int64_t dimension, sliceLoc sliceLoc) { + auto inputType = llvm::cast(input.getType()); + int64_t inputRank = inputType.getRank(); + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + + SmallVector offsets(inputRank, rewriter.getIndexAttr(0)); + if (sliceLoc == END) { + Value dimSize = inputShape[dimension]; + Value one = rewriter.create(loc, 1); + Value endIdx = rewriter.create(loc, dimSize, one); + offsets[dimension] = getAsOpFoldResult(endIdx); + } + + SmallVector allOneStrides(inputRank, rewriter.getIndexAttr(1)); + SmallVector sizes(inputRank, rewriter.getIndexAttr(0)); + for (int i = 0; i < inputRank; ++i) + sizes[i] = (i == dimension) ? 
rewriter.getIndexAttr(1) + : getAsOpFoldResult(inputShape[i]); + + Value extractedSlice = rewriter.create( + loc, input, offsets, sizes, allOneStrides); + return extractedSlice; +} + +static Value createTile(ConversionPatternRewriter &rewriter, Location loc, + Value slice, int64_t tileWidth, int64_t dimension) { + SmallVector slices(tileWidth, slice); + if (tileWidth == 1) + return slice; + return rewriter.create(loc, dimension, slices); +} + +static Value replicationPad(ConversionPatternRewriter &rewriter, Location loc, + Value input, SmallVector &padInts, + int64_t numDims) { + auto inputType = llvm::cast(input.getType()); + int64_t inputRank = inputType.getRank(); + + Value res = input; + int64_t padIdx = 0; + for (int64_t dim = inputRank - 1; dim >= inputRank - numDims; dim--) { + int64_t startTileWidth = padInts[padIdx++]; + int64_t endTileWidth = padInts[padIdx++]; + + SmallVector resultParts; + if (startTileWidth > 0) { + Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::START); + Value tile = createTile(rewriter, loc, slice, startTileWidth, dim); + resultParts.push_back(tile); + } + + resultParts.push_back(res); + + if (endTileWidth > 0) { + Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::END); + Value tile = createTile(rewriter, loc, slice, endTileWidth, dim); + resultParts.push_back(tile); + } + + if (resultParts.size() > 1) + res = rewriter.create(loc, dim, resultParts); + } + return res; +} + namespace { class ConvertAtenConstantPadNdOp : public OpConversionPattern { @@ -144,44 +212,8 @@ class ConvertAtenReplicationPad1dOp return rewriter.notifyMatchFailure( op, "pad range must have exactly two values"); - int64_t leftPad = padInts[0]; - int64_t rightPad = padInts[1]; - - int64_t dimToPad = inputRank - 1; - Value one = rewriter.create(loc, 1); - - SmallVector inputShape = getTensorSizes(rewriter, loc, input); - Value widthSize = inputShape[dimToPad]; - Value widthMinusOne = rewriter.create(loc, widthSize, one); - - // Build 
offset and size arrays for slicing - SmallVector allOneStrides(inputRank, - rewriter.getIndexAttr(1)); - SmallVector leftOffsets(inputRank, rewriter.getIndexAttr(0)); - SmallVector rightOffsets(inputRank, rewriter.getIndexAttr(0)); - SmallVector sizes(inputRank, rewriter.getIndexAttr(0)); - for (int i = 0; i < inputRank; ++i) - sizes[i] = (i == dimToPad) ? rewriter.getIndexAttr(1) - : getAsOpFoldResult(inputShape[i]); - - rightOffsets[dimToPad] = getAsOpFoldResult(widthMinusOne); - - // Extract leftmost and rightmost slices - Value leftSlice = rewriter.create( - loc, input, leftOffsets, sizes, allOneStrides); - Value rightSlice = rewriter.create( - loc, input, rightOffsets, sizes, allOneStrides); - - // Aggregate slices to concat together - SmallVector resultParts; - resultParts.reserve(leftPad + rightPad + 1); - - resultParts.append(leftPad, leftSlice); - resultParts.push_back(input); - resultParts.append(rightPad, rightSlice); - - Value result = - rewriter.create(loc, dimToPad, resultParts); + int64_t numDims = 1; + Value result = replicationPad(rewriter, loc, input, padInts, numDims); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, result); @@ -210,8 +242,7 @@ class ConvertAtenReplicationPad2dOp Value input = adaptor.getSelf(); auto inputType = llvm::cast(input.getType()); int64_t inputRank = inputType.getRank(); - unsigned numDims = inputType.getRank(); - assert(numDims >= 2 && "Not enough input dimensions"); + assert(inputRank >= 2 && "Not enough input dimensions"); SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) @@ -223,202 +254,8 @@ class ConvertAtenReplicationPad2dOp if (inputRank < 0 || padRank > (uint64_t)inputRank) return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); - SmallVector inputShape = getTensorSizes(rewriter, loc, input); - int64_t hDim = numDims - 1; - int64_t vDim = numDims - 2; - Value hDimSize = inputShape[hDim]; - 
Value vDimSize = inputShape[vDim]; - - enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; - enum tileVLoc { - TOP = 0, - VCENTER = 2, - BOTTOM = 1, - }; - // vTile denotes the vertical size of the tile - // hTile denotes the horizontal size of the tile - // The padding results are composed of following tiles: - // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] - // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], - // vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT], - // vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] - // vTile[VCENTER]hTile[HCENTER] is the original input tensor - Type indexType = rewriter.getIndexType(); - Value vTile[3]; - Value hTile[3]; - vTile[VCENTER] = vDimSize; - hTile[HCENTER] = hDimSize; - vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); - vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); - hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); - hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); - - bool hasLeftPadding = false; - bool hasRightPadding = false; - bool hasTopPadding = false; - bool hasBottomPadding = false; - - for (auto i : {TOP, VCENTER, BOTTOM}) { - for (auto j : {LEFT, HCENTER, RIGHT}) { - auto constVtile{dyn_cast_or_null( - mlir::dyn_cast(vTile[i].getDefiningOp()) - .getValue())}; - - auto constHtile{dyn_cast_or_null( - mlir::dyn_cast(hTile[j].getDefiningOp()) - .getValue())}; - auto vSize = constVtile.getInt(); - auto hSize = constHtile.getInt(); - - if ((i == TOP) && (vSize > 0)) - hasTopPadding = true; - if ((i == BOTTOM) && (vSize > 0)) - hasBottomPadding = true; - if ((j == LEFT) && (hSize > 0)) - hasLeftPadding = true; - if ((j == RIGHT) && (hSize > 0)) - hasRightPadding = true; - } - } - - auto createSub = [&](Value x, Value y) { - return rewriter.create(loc, x, y); - }; - - // Extract left and right pad tiles. 
- Value zero = getConstant(rewriter, loc, 0, indexType); - Value one = getConstant(rewriter, loc, 1, indexType); - Value hDimSizeMinusOne = createSub(hDimSize, one); - Value vDimSizeMinusOne = createSub(vDimSize, one); - SmallVector allOneStridesVal(numDims, one); - SmallVector allOneStrides = - getAsOpFoldResult(allOneStridesVal); - - SmallVector extractOffsetsLTVal(numDims, zero); - extractOffsetsLTVal[hDim] = zero; - extractOffsetsLTVal[vDim] = zero; - SmallVector extractOffsetsLT = - getAsOpFoldResult(extractOffsetsLTVal); - SmallVector extractShapeLRVal(numDims, one); - extractShapeLRVal[hDim] = one; - extractShapeLRVal[vDim] = vDimSize; - SmallVector extractShapeLR = - getAsOpFoldResult(extractShapeLRVal); - - SmallVector extractOffsetsRightVal(numDims, zero); - extractOffsetsRightVal[hDim] = hDimSizeMinusOne; - extractOffsetsRightVal[vDim] = zero; - SmallVector extractOffsetsRight = - getAsOpFoldResult(extractOffsetsRightVal); - - SmallVector extractOffsetsBottomVal(numDims, zero); - extractOffsetsBottomVal[hDim] = zero; - extractOffsetsBottomVal[vDim] = vDimSizeMinusOne; - SmallVector extractOffsetsBottom = - getAsOpFoldResult(extractOffsetsBottomVal); - - SmallVector extractShapeTBVal(numDims, one); - extractShapeTBVal[hDim] = hDimSize; - extractShapeTBVal[vDim] = one; - SmallVector extractShapeTB = - getAsOpFoldResult(extractShapeTBVal); - - SmallVector tensorsLeft; - SmallVector tensorsRight; - SmallVector tensorsCenter; - Value centerTile; - SmallVector tensorsRes; - - if (hasLeftPadding) { - Value vCenterLeftSlice = rewriter.create( - loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); - Value vLeftSlice = vCenterLeftSlice; - SmallVector extractIndices(numDims, zero); - if (hasTopPadding) { - Value topLeftValue = - rewriter.create(loc, input, extractIndices); - // pad vCenterLeftSlice on the top - SmallVector lowPadding(numDims, 0); - SmallVector highPadding(numDims, 0); - lowPadding[vDim] = padInts[2]; - vLeftSlice = 
torch_to_linalg::getPaddedTensor( - op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); - } - if (hasBottomPadding) { - extractIndices[vDim] = vDimSizeMinusOne; - Value bottomLeftValue = - rewriter.create(loc, input, extractIndices); - - // pad vLeftSlice at the bottom - SmallVector lowPadding(numDims, 0); - SmallVector highPadding(numDims, 0); - highPadding[vDim] = padInts[3]; - vLeftSlice = torch_to_linalg::getPaddedTensor( - op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); - } - for (auto i = 0; i < padInts[0]; ++i) { - tensorsLeft.push_back(vLeftSlice); - } - Value leftPadTile = - rewriter.create(loc, hDim, tensorsLeft); - tensorsRes.push_back(leftPadTile); - } - if (hasTopPadding) { - Value topHcenterSlice = rewriter.create( - loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); - for (auto i = 0; i < padInts[2]; ++i) { - tensorsCenter.push_back(topHcenterSlice); - } - } - tensorsCenter.push_back(input); - if (hasBottomPadding) { - Value bottomHcenterSlice = rewriter.create( - loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); - for (auto i = 0; i < padInts[3]; ++i) { - tensorsCenter.push_back(bottomHcenterSlice); - } - } - centerTile = rewriter.create(loc, vDim, tensorsCenter); - tensorsRes.push_back(centerTile); - - if (hasRightPadding) { - Value vCenterRightSlice = rewriter.create( - loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); - Value vRightSlice = vCenterRightSlice; - SmallVector extractIndices(numDims, zero); - extractIndices[hDim] = hDimSizeMinusOne; - if (hasTopPadding) { - Value topRightValue = rewriter.create( - loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); - - // pad vCenterRightSlice on the top - SmallVector lowPadding(numDims, 0); - SmallVector highPadding(numDims, 0); - lowPadding[vDim] = padInts[2]; - vRightSlice = torch_to_linalg::getPaddedTensor( - op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); - } - if (hasBottomPadding) { - 
extractIndices[vDim] = vDimSizeMinusOne; - Value bottomRightValue = - rewriter.create(loc, input, extractIndices); - - // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. - SmallVector lowPadding(numDims, 0); - SmallVector highPadding(numDims, 0); - highPadding[vDim] = padInts[3]; - vRightSlice = torch_to_linalg::getPaddedTensor( - op, rewriter, vRightSlice, lowPadding, highPadding, - bottomRightValue); - } - for (auto i = 0; i < padInts[1]; ++i) { - tensorsRight.push_back(vRightSlice); - } - Value rightPadTile = - rewriter.create(loc, hDim, tensorsRight); - tensorsRes.push_back(rightPadTile); - } - Value resTensor = rewriter.create(loc, hDim, tensorsRes); + int64_t numDims = 2; + Value resTensor = replicationPad(rewriter, loc, input, padInts, numDims); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, resTensor); return success(); @@ -433,43 +270,6 @@ namespace { class ConvertAtenReplicationPad3dOp : public OpConversionPattern { -private: - enum sliceLoc { START = 0, END = 1 }; - - Value extractSlice(ConversionPatternRewriter &rewriter, Location loc, - Value input, int64_t dimension, sliceLoc sliceLoc) const { - auto inputType = llvm::cast(input.getType()); - int64_t inputRank = inputType.getRank(); - SmallVector inputShape = getTensorSizes(rewriter, loc, input); - - SmallVector offsets(inputRank, rewriter.getIndexAttr(0)); - if (sliceLoc == END) { - Value dimSize = inputShape[dimension]; - Value one = rewriter.create(loc, 1); - Value endIdx = rewriter.create(loc, dimSize, one); - offsets[dimension] = getAsOpFoldResult(endIdx); - } - - SmallVector allOneStrides(inputRank, - rewriter.getIndexAttr(1)); - SmallVector sizes(inputRank, rewriter.getIndexAttr(0)); - for (int i = 0; i < inputRank; ++i) - sizes[i] = (i == dimension) ? 
rewriter.getIndexAttr(1) - : getAsOpFoldResult(inputShape[i]); - - Value extractedSlice = rewriter.create( - loc, input, offsets, sizes, allOneStrides); - return extractedSlice; - } - - Value createTile(ConversionPatternRewriter &rewriter, Location loc, - Value slice, int64_t tileWidth, int64_t dimension) const { - SmallVector slices(tileWidth, slice); - if (tileWidth == 1) - return slice; - return rewriter.create(loc, dimension, slices); - } - public: using OpConversionPattern::OpConversionPattern; @@ -483,8 +283,7 @@ class ConvertAtenReplicationPad3dOp Value input = adaptor.getSelf(); auto inputType = llvm::cast(input.getType()); int64_t inputRank = inputType.getRank(); - unsigned numDims = inputType.getRank(); - assert(numDims >= 2 && "Not enough input dimensions"); + assert(inputRank >= 2 && "Not enough input dimensions"); SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) @@ -495,31 +294,8 @@ class ConvertAtenReplicationPad3dOp return rewriter.notifyMatchFailure( op, "pad range must have exactly six values"); - Value res = input; - int64_t padIdx = 0; - for (int64_t dim = inputRank - 1; dim >= inputRank - 3; dim--) { - int64_t startTileWidth = padInts[padIdx++]; - int64_t endTileWidth = padInts[padIdx++]; - - SmallVector resultParts; - if (startTileWidth > 0) { - Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::START); - Value tile = createTile(rewriter, loc, slice, startTileWidth, dim); - resultParts.push_back(tile); - } - - resultParts.push_back(res); - - if (endTileWidth > 0) { - Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::END); - Value tile = createTile(rewriter, loc, slice, endTileWidth, dim); - resultParts.push_back(tile); - } - - if (resultParts.size() > 1) - res = rewriter.create(loc, dim, resultParts); - } - + int64_t numDims = 3; + Value res = replicationPad(rewriter, loc, input, padInts, numDims); Type resultType = getTypeConverter()->convertType(op.getType()); 
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, res);
       return success();