From e0d6f022b32c23a1e4e75885ed7e08a2e806ea4b Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Fri, 7 Mar 2025 12:47:52 -0600 Subject: [PATCH 1/6] Initial commit does not build --- lib/Conversion/TorchToLinalg/Pooling.cpp | 217 +++++++++++++++++++++++ 1 file changed, 217 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 3c971354783a..0b80ee1ffd03 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -1616,6 +1616,223 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { }; } // namespace +namespace { +template +class ConvertRoiAlignOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = this->getTypeConverter(); + Value result = op.getResult(); + + uint64_t pooledHeight = + cast(op.getPooledHeight().getDefiningOp()).getValue(); + uint64_t pooledWidth = + cast(op.getPooledWidth().getDefiningOp()).getValue(); + uint64_t samplingRatio = + cast(op.getSamplingRatio().getDefiningOp()).getValue(); + Value pooledH = op.getPooledHeight(); + Value pooledW = op.getPooledWidth(); + Value spatialScaleVal = op.getSpatialScale(); + llvm::APFloat spatialScale = + cast(op.getSpatialScale().getDefiningOp()).getValue(); + Value rois = op.getRois(); + Value input = op.getInput(); + // RankedTensorType inputType = input.getType(); + Value offset = + rewriter.create(loc, b.getF32FloatAttr(0.0)); + Type resultType = cast(result.getType()); + Type resultElementType = resultType.getElementType(); + if (!op.getAligned()) { + offset = rewriter.create(loc, b.getF32FloatAttr(0.5)); + } + + Value lb = rewriter.create(loc, 0); + Value ub0 = rewriter.create(loc, rois, 0); + Value ub1 = rewriter.create(loc, input, 1); + Value step = rewriter.create(loc, 1); + SmallVector finalOutputShape = {ub0, ub1, pooledH, pooledW}; + Value finalOutputTensor = rewriter.create( + loc, getAsOpFoldResult(finalOutputShape), resultElementType); + auto forLoop = rewriter.create( + loc, lb, ub0, step, ValueRange{}, + [&](OpBuilder &b1, Location loc, Value iv0, ValueRange args) { + auto forLoop = b1.create( + loc, lb, ub1, step, ValueRange{}, + [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) { + // Step 1: Extract bounds for region of interest (roi) + OpFoldResult zeroAttr = b.getI64IntegerAttr(0); + OpFoldResult oneAttr = b.getI64IntegerAttr(1); + OpFoldResult twoAttr = b.getI64IntegerAttr(2); + OpFoldResult threeAttr = b.getI64IntegerAttr(3); + OpFoldResult fourAttr = b.getI64IntegerAttr(4); + OpFoldResult fiveAttr = b.getI64IntegerAttr(5); + // SmallVector offsetVals{iv0, zeroAttr}; + // SmallVector sizeVals{oneAttr, fiveAttr}; + SmallVector strideVals{oneAttr, oneAttr, oneAttr, + oneAttr}; + // Value extractRoiBounds = b.create( + // loc, rois, offsetVals, sizeVals, strideVals); + Value lowY = b.create( + loc, rois, ValueRange{iv0, oneAttr}); + Value lowX = b.create( + loc, rois, ValueRange{iv0, twoAttr}); + Value highY = b.create( + loc, rois, ValueRange{iv0, threeAttr}); + Value highX = b.create( + loc, rois, ValueRange{iv0, fourAttr}); + + lowY = b.create(loc, lowY, spatialScaleVal); + lowX = b.create(loc, lowX, spatialScaleVal); + highY = b.create(loc, highY, spatialScaleVal); + highX = b.create(loc, highX, spatialScaleVal); + + lowY = b.create(loc, lowY, offset); + lowX = b.create(loc, lowX, offset); + highY = b.create(loc, highY, offset); + highX = b.create(loc, highX, offset); + + // Step 2: Extract region of interest using bounds + Value lowY_int = b.create(loc, lowY); + Value lowX_int = b.create(loc, lowX); + Value highY_int = b.create(loc, highY); + Value highX_int = b.create(loc, highX); + lowY_int = + b.create(loc, b.getI64Type(), lowY_int); + lowX_int = + b.create(loc, b.getI64Type(), lowX_int); + highY_int = + b.create(loc, b.getI64Type(), highY_int); + highX_int = + b.create(loc, b.getI64Type(), highX_int); + + Value roiHeight = + b.create(loc, highY_int, lowY_int); + Value roiWidth = + b.create(loc, highX_int, lowX_int); + + SmallVector roiOffsetVals{zeroAttr, iv1, lowY_int, + lowX_int}; + SmallVector roiSizeVals{oneAttr, oneAttr, roiHeight, + roiWidth}; + + Value extractRoi = b.create( + loc, input, roiOffsetVals, roiSizeVals, strideVals); + + // Step 3: Perform bilinear interpolation over roi + Value roiBinH = b.create(loc, highY, lowY); + Value roiBinW = b.create(loc, highX, lowX); + Value scaleH = b.create(loc, roiBinH, pooledH); + Value scaleW = b.create(loc, roiBinW, pooledW); + scaleH = b.create(loc, scaleH); + scaleW = b.create(loc, scaleW); + scaleH = b.create(loc, b.getI64Type(), scaleH); + scaleW = b.create(loc, b.getI64Type(), scaleW); + + Value roiSampleHeight = + b.create(loc, pooledH, scaleH); + Value roiSampleWidth = + b.create(loc, pooledW, scaleW); + + SmallVector outputSizeIntValues = {roiSampleHeight, + roiSampleWidth}; + SmallVector dims = + getTensorSizesUntilDim(b, loc, extractRoi, 1); + for (unsigned i = 2; i < inputRank; i++) { + dims.push_back( + castIntToIndex(b, loc, outputSizeIntValues[i - 2])); + } + SmallVector inputSizes; + auto inputType = cast(extractRoi.getType()); + auto inputRank = inputType.getRank(); + for (unsigned i = 2; i < inputRank; i++) { + Value inputSize = getDimOp(b, loc, extractRoi, i); + inputSizes.push_back(b.create( + loc, b.getIntegerType(64), roiSizeVals[i])); + } + Value outTensor = b.create( + loc, getAsOpFoldResult(dims), inputType.getElementType()); + AffineMap idMap = b.getMultiDimIdentityMap(inputRank); + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value bilinearInterpolatedRoi = + b.create( + loc, outTensor.getType(), ValueRange{}, outTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value retVal = bilinearInterpolate( + b, op, loc, outputSizeIntValues, extractRoi, + inputSizes, ValueRange{}, "bilinear"); + b.create(loc, retVal); + }) + .getResult(0); + + // Step 4: Sum pool over interpolated values + Value sumPool, paddedInput; + SmallVector kernelSizeIntValues = {oneAttr, oneAttr, + scaleH, scaleW}; + SmallVector strideInts = {scaleH, scaleW}; + SmallVector paddingInts = {zeroAttr, zeroAttr}; + SmallVector dilationInts(oneAttr, 2); + SmallVector outTensorShape; + if (failed(createPoolingOp( + op, b, self, /*supportNonFPInput=*/true, false, + /*dimensionality=*/2, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, + b.getZeroAttr(resultElementType), outTensorShape, + paddedInput, sumPool))) + return b.notifyMatchFailure(op, "unable to compute sumpool"); + + // Step 5: elementwise division by number of sampling points + // to compute avg pool + Value outputTensor = b.create( + loc, getAsOpFoldResult(outTensorShape), resultElementType); + Value divisor = b.create(loc, scaleH, scaleW); + Value avgPool = + b.create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], + divisor); + else if (isa(resultElementType)) + avg = + b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); + + SmallVector finalStrides(inputRank, oneAttr); + SmallVector finalOffsets = { + getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, + zeroAttr}; + SmallVector finalSizes = { + oneAttr, oneAttr, getAsOpFoldResult(pooledH), + getAsOpFoldResult(pooledW)}; + SmallVector diagStrides(inputRank, oneAttr); + finalOutputTensor = b.create( + loc, finalOutputTensor, avgPool, finalOffsets, finalSizes, + finalStrides); + }); + }); + + Type resultType = typeConverter->convertType(op.getType()); + b.replaceOp(op, finalOutputTensor); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { From 31140fb5760076bcee6fbea66ca6414684877e7c Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Mon, 10 Mar 2025 01:06:38 -0500 Subject: [PATCH 2/6] Fixing issues --- lib/Conversion/TorchToLinalg/Pooling.cpp | 412 +++++++++++++++++++++++ 1 file changed, 412 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 0b80ee1ffd03..ce41c25ddfc3 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -18,6 +18,8 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Math/IR/Math.h" #include using namespace mlir; @@ -1833,6 +1835,414 @@ class ConvertRoiAlignOp : public OpConversionPattern { }; } // namespace +namespace { +template +class ConvertRoiAlignOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + static SmallVector coordinateTransform( + OpBuilder &b, OpTy op, Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, std::string coordStr, + bool alignCornersBool, SmallVector indices, bool clip) { + + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value cstOneFloat = + b.create(loc, b.getF32FloatAttr(1.0)); + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + + SmallVector proj; + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + // length_resized + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i]); + // scale = length_resized/length_original + Value scale; + if (alignCornersBool) { + // x_original = x_resized * (length_original - 1) / (length_resized - 1) + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + Value outputSizeSubOne = + b.create(loc, outputSizeFP, cstOneFloat); + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeSubOne, zero); + scale = b.create(loc, inputSubOne, outputSizeSubOne); + scale = b.create(loc, cmp, zero, scale); + coordStr = "_align_corners"; + } else if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputFP); + else + scale = scaleValues[i]; + // y_resized + Value outInt = b.create(loc, b.getI64Type(), + indices[i + dimOffset]); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value preClip; + if (coordStr == "_align_corners") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_asymmetric") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_pytorch_half_pixel" || coordStr == "" || + coordStr == "_half_pixel_symmetric") { + // half-pixel modes + // y_resized + 0.5 + Value outPlusHalf = b.create(loc, outFP, cstHalf); + // (y_resized + 0.5) / scale + Value outDivScale = b.create(loc, outPlusHalf, scale); + // _ - 0.5 + preClip = b.create(loc, outDivScale, cstHalf); + } + // for half_pixel_symmetric, need to compute offset from raw scales + if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) { + Value outputSizeFromScale = + b.create(loc, inputFP, scale); + Value adjustment = + b.create(loc, outputSizeFP, outputSizeFromScale); + Value cstTwo = b.create(loc, b.getF32FloatAttr(2.0)); + Value center = b.create(loc, inputFP, cstTwo); + Value oneMAdjustment = + b.create(loc, cstOneFloat, adjustment); + Value offset = b.create(loc, center, oneMAdjustment); + preClip = b.create(loc, offset, preClip); + } + // for pytorch half pixel , special case for length_resized == 1: + if (coordStr == "_pytorch_half_pixel") { + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeFP, cstOneFloat); + preClip = b.create(loc, cmp, zero, preClip); + } + if (clip) { + // preClip is the fp position inside the input image to extract from. + // clip to [0,inf) + Value max = b.create(loc, preClip, zero); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1]. + // proj is properly within the input image. + proj.push_back(b.create(loc, max, inputSubOne)); + } else { + proj.push_back(preClip); + } + } + return proj; + } + + static Value bilinearInterpolate(OpBuilder &b, OpTy op, Location loc, + SmallVector outputSizes, Value input, + SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value cstOneFloat = + b.create(loc, b.getF32FloatAttr(1.0)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj, high, low, highFP, lowFP; + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, false, indices, true); + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + + // for bilinear interpolation, we look for the nearest indices below and + // above proj + lowFP.push_back(b.create(loc, proj[i])); + Value projPlusOne = b.create(loc, cstOneFloat, proj[i]); + highFP.push_back(b.create(loc, projPlusOne)); + + Value lowInt = b.create(loc, b.getI64Type(), lowFP[i]); + low.push_back( + b.create(loc, b.getIndexType(), lowInt)); + + // highFP could be out-of-bounds, so make sure to clip it down before + // extracting. If highFP actually gets clipped here, then high[i] will + // extract at the last pixel, but will treat it as if it were extracted + // from one further position when computing the interpolation weights. + Value highExtract = + b.create(loc, projPlusOne, inputSubOne); + highExtract = b.create(loc, b.getI64Type(), highExtract); + high.push_back( + b.create(loc, b.getIndexType(), highExtract)); + } + + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = low[1]; + Value p00 = b.create(loc, input, indices); + + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = high[1]; + Value p01 = b.create(loc, input, indices); + + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = low[1]; + Value p10 = b.create(loc, input, indices); + + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = high[1]; + Value p11 = b.create(loc, input, indices); + + // Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)), + // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. + // We interpolate via the weighted average of pij by weights Aij + // the formula is retval = Sum(pij*Aij for i and j in range(2)) + // Note: we do not need to divide by total rect area == 1 + + // lengths : Aij == dyi*dxj + Value dy0 = b.create(loc, highFP[0], proj[0]); + Value dy1 = b.create(loc, proj[0], lowFP[0]); + Value dx0 = b.create(loc, highFP[1], proj[1]); + Value dx1 = b.create(loc, proj[1], lowFP[1]); + + // left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01) + Value dx0p00 = b.create(loc, dx0, p00); + Value dx1p01 = b.create(loc, dx1, p01); + Value sum = b.create(loc, dx0p00, dx1p01); + Value left = b.create(loc, dy0, sum); + // right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11) + Value dx0p10 = b.create(loc, dx0, p10); + Value dx1p11 = b.create(loc, dx1, p11); + sum = b.create(loc, dx0p10, dx1p11); + Value right = b.create(loc, dy1, sum); + + return b.create(loc, left, right); + } + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op->getLoc(); + Value result = op.getResult(); + + uint64_t samplingRatio = + cast(op.getSamplingRatio().getDefiningOp()).getValue(); + int64_t samplingRatioInt = static_cast(samplingRatio); + Value pooledH = op.getPooledHeight(); + Value pooledW = op.getPooledWidth(); + Value spatialScaleVal = op.getSpatialScale(); + llvm::APFloat spatialScale = + cast(op.getSpatialScale().getDefiningOp()).getValue(); + Value rois = op.getRois(); + Value input = op.getInput(); + unsigned inputRank = cast(input.getType()).getRank(); + Value offset = + rewriter.create(loc, rewriter.getF32FloatAttr(0.0)); + RankedTensorType resultType = cast(result.getType()); + Type resultElementType = resultType.getElementType(); + if (!op.getAligned()) { + offset = rewriter.create( + loc, rewriter.getF32FloatAttr(0.5)); + } + + Value lb = rewriter.create(loc, 0); + Value ub0 = rewriter.create(loc, rois, 0); + Value ub1 = rewriter.create(loc, input, 1); + Value step = rewriter.create(loc, 1); + SmallVector finalOutputShape = {ub0, ub1, pooledH, pooledW}; + Value finalOutputTensor = rewriter.create( + loc, getAsOpFoldResult(finalOutputShape), resultElementType); + rewriter.create( + loc, lb, ub0, step, ValueRange{}, + [&](OpBuilder &b, Location loc, Value iv0, ValueRange args) { + b.create( + loc, lb, ub1, step, ValueRange{}, + [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) { + // Step 1: Extract bounds for region of interest (roi) + OpFoldResult zeroAttr = b.getI64IntegerAttr(0); + OpFoldResult oneAttr = b.getI64IntegerAttr(1); + + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstTwo = rewriter.create(loc, 2); + Value cstThree = + rewriter.create(loc, 3); + Value cstFour = rewriter.create(loc, 4); + SmallVector strideVals{oneAttr, oneAttr, oneAttr, + oneAttr}; + SmallVector lowYIndices = {iv0, cstOne}; + Value lowY = + b.create(loc, rois, lowYIndices); + SmallVector lowXIndices = {iv0, cstTwo}; + Value lowX = + b.create(loc, rois, lowXIndices); + SmallVector highYIndices = {iv0, cstThree}; + Value highY = + b.create(loc, rois, highYIndices); + SmallVector highXIndices = {iv0, cstFour}; + Value highX = + b.create(loc, rois, highXIndices); + + lowY = b.create(loc, lowY, spatialScaleVal); + lowX = b.create(loc, lowX, spatialScaleVal); + highY = b.create(loc, highY, spatialScaleVal); + highX = b.create(loc, highX, spatialScaleVal); + + lowY = b.create(loc, lowY, offset); + lowX = b.create(loc, lowX, offset); + highY = b.create(loc, highY, offset); + highX = b.create(loc, highX, offset); + + // Step 2: Extract region of interest using bounds + Value lowY_int = b.create(loc, lowY); + Value lowX_int = b.create(loc, lowX); + Value highY_int = b.create(loc, highY); + Value highX_int = b.create(loc, highX); + lowY_int = + b.create(loc, b.getI64Type(), lowY_int); + lowX_int = + b.create(loc, b.getI64Type(), lowX_int); + highY_int = + b.create(loc, b.getI64Type(), highY_int); + highX_int = + b.create(loc, b.getI64Type(), highX_int); + + Value roiHeight = + b.create(loc, highY_int, lowY_int); + Value roiWidth = + b.create(loc, highX_int, lowX_int); + + SmallVector roiOffsetVals = { + getAsOpFoldResult(cstZero), getAsOpFoldResult(iv1), + getAsOpFoldResult(lowY_int), getAsOpFoldResult(lowX_int)}; + SmallVector roiSizeVals = {cstOne, cstOne, roiHeight, + roiWidth}; + + Value extractRoi = b.create( + loc, input, ValueRange{cstZero, iv1, lowY_int, lowX_int}, + ValueRange{cstOne, cstOne, roiHeight, roiWidth}, + ValueRange{cstOne, cstOne, cstOne, cstOne}); + + // Step 3: Perform bilinear interpolation over roi + Value roiBinH = b.create(loc, highY, lowY); + Value roiBinW = b.create(loc, highX, lowX); + Value scaleH = b.create(loc, roiBinH, pooledH); + Value scaleW = b.create(loc, roiBinW, pooledW); + scaleH = b.create(loc, scaleH); + scaleW = b.create(loc, scaleW); + scaleH = b.create(loc, b.getI64Type(), scaleH); + scaleW = b.create(loc, b.getI64Type(), scaleW); + if (samplingRatio > 0) { + scaleH = b.create( + loc, rewriter.getI64IntegerAttr(samplingRatio)); + scaleW = b.create( + loc, rewriter.getI64IntegerAttr(samplingRatio)); + } + + Value roiSampleHeight = + b.create(loc, pooledH, scaleH); + Value roiSampleWidth = + b.create(loc, pooledW, scaleW); + + SmallVector outputSizeIntValues = {roiSampleHeight, + roiSampleWidth}; + SmallVector dims = + getTensorSizesUntilDim(b, loc, extractRoi, 1); + for (unsigned i = 2; i < inputRank; i++) { + dims.push_back( + castIntToIndex(b, loc, outputSizeIntValues[i - 2])); + } + SmallVector inputSizes; + auto inputType = cast(extractRoi.getType()); + auto inputRank = inputType.getRank(); + for (unsigned i = 2; i < inputRank; i++) { + inputSizes.push_back(b.create( + loc, b.getIntegerType(64), roiSizeVals[i])); + } + Value outTensor = b.create( + loc, getAsOpFoldResult(dims), inputType.getElementType()); + AffineMap idMap = b.getMultiDimIdentityMap(inputRank); + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value bilinearInterpolatedRoi = + b.create( + loc, outTensor.getType(), ValueRange{}, outTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value retVal = bilinearInterpolate( + b, op, loc, outputSizeIntValues, extractRoi, + inputSizes, ValueRange{}, "bilinear"); + b.create(loc, retVal); + }) + .getResult(0); + + // Step 4: Sum pool over interpolated values + Value sumPool, paddedInput; + SmallVector kernelSizeIntValues = {cstOne, cstOne, + scaleH, scaleW}; + SmallVector strideInts = {samplingRatioInt, + samplingRatioInt}; + SmallVector paddingInts = {0, 0}; + SmallVector dilationInts(2, 1); + SmallVector outTensorShape; + if (failed(createPoolingOp( + op, rewriter, bilinearInterpolatedRoi, + /*supportNonFPInput=*/true, false, + /*dimensionality=*/2, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, + b.getZeroAttr(resultElementType), outTensorShape, + paddedInput, sumPool))) + op.emitError("unable to compute sumpool"); + + // Step 5: elementwise division by number of sampling points + // to compute avg pool + Value outputTensor = b.create( + loc, getAsOpFoldResult(outTensorShape), resultElementType); + Value divisor = b.create(loc, scaleH, scaleW); + Value avgPool = + b.create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], + divisor); + else if (isa(resultElementType)) + avg = + b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); + + SmallVector finalStrides(inputRank, oneAttr); + SmallVector finalOffsets = { + getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, + zeroAttr}; + SmallVector finalSizes = { + oneAttr, oneAttr, getAsOpFoldResult(pooledH), + getAsOpFoldResult(pooledW)}; + SmallVector diagStrides(inputRank, oneAttr); + finalOutputTensor = b.create( + loc, finalOutputTensor, avgPool, finalOffsets, finalSizes, + finalStrides); + }); + }); + rewriter.replaceOp(op, finalOutputTensor); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1882,4 +2292,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( typeConverter, context); patterns.add>( typeConverter, context); + patterns.add>(typeConverter, + context); } From 059c443ebdea9d232b612730c1f21fc0d0e11bea Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Mon, 10 Mar 2025 18:10:25 -0500 Subject: [PATCH 3/6] debug statements remove later --- lib/Conversion/TorchToLinalg/Pooling.cpp | 299 ++++------------------- 1 file changed, 51 insertions(+), 248 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index ce41c25ddfc3..ea9c2e07e5fd 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -20,6 +20,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "llvm/Support/Debug.h" #include using namespace mlir; @@ -1616,229 +1617,9 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { return success(); } }; -} // namespace - -namespace { -template -class ConvertRoiAlignOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - Location loc = op->getLoc(); - const TypeConverter *typeConverter = this->getTypeConverter(); - Value result = op.getResult(); - - uint64_t pooledHeight = - cast(op.getPooledHeight().getDefiningOp()).getValue(); - uint64_t pooledWidth = - cast(op.getPooledWidth().getDefiningOp()).getValue(); - uint64_t samplingRatio = - cast(op.getSamplingRatio().getDefiningOp()).getValue(); - Value pooledH = op.getPooledHeight(); - Value pooledW = op.getPooledWidth(); - Value spatialScaleVal = op.getSpatialScale(); - llvm::APFloat spatialScale = - cast(op.getSpatialScale().getDefiningOp()).getValue(); - Value rois = op.getRois(); - Value input = op.getInput(); - // RankedTensorType inputType = input.getType(); - Value offset = - rewriter.create(loc, b.getF32FloatAttr(0.0)); - Type resultType = cast(result.getType()); - Type resultElementType = resultType.getElementType(); - if (!op.getAligned()) { - offset = rewriter.create(loc, b.getF32FloatAttr(0.5)); - } - - Value lb = rewriter.create(loc, 0); - Value ub0 = rewriter.create(loc, rois, 0); - Value ub1 = rewriter.create(loc, input, 1); - Value step = rewriter.create(loc, 1); - SmallVector finalOutputShape = {ub0, ub1, pooledH, pooledW}; - Value finalOutputTensor = rewriter.create( - loc, getAsOpFoldResult(finalOutputShape), resultElementType); - auto forLoop = rewriter.create( - loc, lb, ub0, step, ValueRange{}, - [&](OpBuilder &b1, Location loc, Value iv0, ValueRange args) { - auto forLoop = b1.create( - loc, lb, ub1, step, ValueRange{}, - [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) { - // Step 1: Extract bounds for region of interest (roi) - OpFoldResult zeroAttr = b.getI64IntegerAttr(0); - OpFoldResult oneAttr = b.getI64IntegerAttr(1); - OpFoldResult twoAttr = b.getI64IntegerAttr(2); - OpFoldResult threeAttr = b.getI64IntegerAttr(3); - OpFoldResult fourAttr = b.getI64IntegerAttr(4); - OpFoldResult fiveAttr = b.getI64IntegerAttr(5); - // SmallVector offsetVals{iv0, zeroAttr}; - // SmallVector sizeVals{oneAttr, fiveAttr}; - SmallVector strideVals{oneAttr, oneAttr, oneAttr, - oneAttr}; - // Value extractRoiBounds = b.create( - // loc, rois, offsetVals, sizeVals, strideVals); - Value lowY = b.create( - loc, rois, ValueRange{iv0, oneAttr}); - Value lowX = b.create( - loc, rois, ValueRange{iv0, twoAttr}); - Value highY = b.create( - loc, rois, ValueRange{iv0, threeAttr}); - Value highX = b.create( - loc, rois, ValueRange{iv0, fourAttr}); - - lowY = b.create(loc, lowY, spatialScaleVal); - lowX = b.create(loc, lowX, spatialScaleVal); - highY = b.create(loc, highY, spatialScaleVal); - highX = b.create(loc, highX, spatialScaleVal); - - lowY = b.create(loc, lowY, offset); - lowX = b.create(loc, lowX, offset); - highY = b.create(loc, highY, offset); - highX = b.create(loc, highX, offset); - - // Step 2: Extract region of interest using bounds - Value lowY_int = b.create(loc, lowY); - Value lowX_int = b.create(loc, lowX); - Value highY_int = b.create(loc, highY); - Value highX_int = b.create(loc, highX); - lowY_int = - b.create(loc, b.getI64Type(), lowY_int); - lowX_int = - b.create(loc, b.getI64Type(), lowX_int); - highY_int = - b.create(loc, b.getI64Type(), highY_int); - highX_int = - b.create(loc, b.getI64Type(), highX_int); - - Value roiHeight = - b.create(loc, highY_int, lowY_int); - Value roiWidth = - b.create(loc, highX_int, lowX_int); - - SmallVector roiOffsetVals{zeroAttr, iv1, lowY_int, - lowX_int}; - SmallVector roiSizeVals{oneAttr, oneAttr, roiHeight, - roiWidth}; - - Value extractRoi = b.create( - loc, input, roiOffsetVals, roiSizeVals, strideVals); - - // Step 3: Perform bilinear interpolation over roi - Value roiBinH = b.create(loc, highY, lowY); - Value roiBinW = b.create(loc, highX, lowX); - Value scaleH = b.create(loc, roiBinH, pooledH); - Value scaleW = b.create(loc, roiBinW, pooledW); - scaleH = b.create(loc, scaleH); - scaleW = b.create(loc, scaleW); - scaleH = b.create(loc, b.getI64Type(), scaleH); - scaleW = b.create(loc, b.getI64Type(), scaleW); - - Value roiSampleHeight = - b.create(loc, pooledH, scaleH); - Value roiSampleWidth = - b.create(loc, pooledW, scaleW); - - SmallVector outputSizeIntValues = {roiSampleHeight, - roiSampleWidth}; - SmallVector dims = - getTensorSizesUntilDim(b, loc, extractRoi, 1); - for (unsigned i = 2; i < inputRank; i++) { - dims.push_back( - castIntToIndex(b, loc, outputSizeIntValues[i - 2])); - } - SmallVector inputSizes; - auto inputType = cast(extractRoi.getType()); - auto inputRank = inputType.getRank(); - for (unsigned i = 2; i < inputRank; i++) { - Value inputSize = getDimOp(b, loc, extractRoi, i); - inputSizes.push_back(b.create( - loc, b.getIntegerType(64), roiSizeVals[i])); - } - Value outTensor = b.create( - loc, getAsOpFoldResult(dims), inputType.getElementType()); - AffineMap idMap = b.getMultiDimIdentityMap(inputRank); - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - Value bilinearInterpolatedRoi = - b.create( - loc, outTensor.getType(), ValueRange{}, outTensor, - /*indexingMaps=*/idMap, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value retVal = bilinearInterpolate( - b, op, loc, outputSizeIntValues, extractRoi, - inputSizes, ValueRange{}, "bilinear"); - b.create(loc, retVal); - }) - .getResult(0); - - // Step 4: Sum pool over interpolated values - Value sumPool, paddedInput; - SmallVector kernelSizeIntValues = {oneAttr, oneAttr, - scaleH, scaleW}; - SmallVector strideInts = {scaleH, scaleW}; - SmallVector paddingInts = {zeroAttr, zeroAttr}; - SmallVector dilationInts(oneAttr, 2); - SmallVector outTensorShape; - if (failed(createPoolingOp( - op, b, self, /*supportNonFPInput=*/true, false, - /*dimensionality=*/2, kernelSizeIntValues, strideInts, - paddingInts, dilationInts, - b.getZeroAttr(resultElementType), outTensorShape, - paddedInput, sumPool))) - return b.notifyMatchFailure(op, "unable to compute sumpool"); - - // Step 5: elementwise division by number of sampling points - // to compute avg pool - Value outputTensor = b.create( - loc, getAsOpFoldResult(outTensorShape), resultElementType); - Value divisor = b.create(loc, scaleH, scaleW); - Value avgPool = - b.create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], - divisor); - else if (isa(resultElementType)) - avg = - b.create(loc, args[0], divisor); - b.create(loc, avg); - }) - .getResult(0); - - SmallVector finalStrides(inputRank, oneAttr); - SmallVector finalOffsets = { - getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, - zeroAttr}; - SmallVector finalSizes = { - oneAttr, oneAttr, getAsOpFoldResult(pooledH), - getAsOpFoldResult(pooledW)}; - SmallVector diagStrides(inputRank, oneAttr); - finalOutputTensor = b.create( - loc, finalOutputTensor, avgPool, finalOffsets, finalSizes, - finalStrides); - }); - }); - - Type resultType = typeConverter->convertType(op.getType()); - b.replaceOp(op, finalOutputTensor); - return success(); - } -}; -} // namespace -namespace { template -class ConvertRoiAlignOp : public OpConversionPattern { -public: +struct ConvertRoiAlignOp : final OpConversionPattern { using OpConversionPattern::OpConversionPattern; static SmallVector coordinateTransform( @@ -2001,7 +1782,7 @@ class ConvertRoiAlignOp : public OpConversionPattern { // Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)), // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. // We interpolate via the weighted average of pij by weights Aij - // the formula is retval = Sum(pij*Aij for i and j in range(2)) + // the formula is retval = Sum(pij*Aij for i and j in range(2)). // Note: we do not need to divide by total rect area == 1 // lengths : Aij == dyi*dxj @@ -2042,16 +1823,28 @@ class ConvertRoiAlignOp : public OpConversionPattern { cast(op.getSpatialScale().getDefiningOp()).getValue(); Value rois = op.getRois(); Value input = op.getInput(); - unsigned inputRank = cast(input.getType()).getRank(); + RankedTensorType inputType = dyn_cast_or_null( + this->getTypeConverter()->convertType(input.getType())); + llvm::dbgs() << "input"; + if (inputType == nullptr) { + op.emitError("Cannot determine input shape"); + } + + unsigned inputRank = inputType.getRank(); Value offset = rewriter.create(loc, rewriter.getF32FloatAttr(0.0)); - RankedTensorType resultType = cast(result.getType()); + RankedTensorType resultType = dyn_cast_or_null( + this->getTypeConverter()->convertType(result.getType())); + if (resultType == nullptr) { + op.emitError("Cannot determine result shape"); + } + llvm::dbgs() << "that\n"; Type resultElementType = resultType.getElementType(); if (!op.getAligned()) { offset = rewriter.create( loc, rewriter.getF32FloatAttr(0.5)); } - + llvm::dbgs() << "1\n"; Value lb = rewriter.create(loc, 0); Value ub0 = rewriter.create(loc, rois, 0); Value ub1 = rewriter.create(loc, input, 1); @@ -2065,68 +1858,79 @@ class ConvertRoiAlignOp : public OpConversionPattern { b.create( loc, lb, ub1, step, ValueRange{}, [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) { - // Step 1: Extract bounds for region of interest (roi) + llvm::dbgs() << "2\n"; + // Step 1: Extract bounds for region of interest (roi). OpFoldResult zeroAttr = b.getI64IntegerAttr(0); OpFoldResult oneAttr = b.getI64IntegerAttr(1); - + llvm::dbgs() << "2.1\n"; Value cstZero = rewriter.create(loc, 0); Value cstOne = rewriter.create(loc, 1); Value cstTwo = rewriter.create(loc, 2); Value cstThree = rewriter.create(loc, 3); Value cstFour = rewriter.create(loc, 4); + llvm::dbgs() << "2.2\n"; SmallVector strideVals{oneAttr, oneAttr, oneAttr, oneAttr}; + llvm::dbgs() << "2.21\n"; SmallVector lowYIndices = {iv0, cstOne}; + llvm::dbgs() << "2.211\n"; + llvm::dbgs() << rois << "\n"; Value lowY = b.create(loc, rois, lowYIndices); + // Value lowY = b.create(loc, b.getF32FloatAttr(0.0)); + llvm::dbgs() << "2.212\n"; SmallVector lowXIndices = {iv0, cstTwo}; + llvm::dbgs() << "2.213\n"; Value lowX = b.create(loc, rois, lowXIndices); + llvm::dbgs() << "2.214\n"; SmallVector highYIndices = {iv0, cstThree}; + llvm::dbgs() << "2.22\n"; Value highY = b.create(loc, rois, highYIndices); SmallVector highXIndices = {iv0, cstFour}; + llvm::dbgs() << "2.23\n"; Value highX = b.create(loc, rois, highXIndices); - + llvm::dbgs() << "2.5\n"; lowY = b.create(loc, lowY, spatialScaleVal); lowX = b.create(loc, lowX, spatialScaleVal); highY = b.create(loc, highY, spatialScaleVal); highX = b.create(loc, highX, spatialScaleVal); - + llvm::dbgs() << "3\n"; lowY = b.create(loc, lowY, offset); lowX = b.create(loc, lowX, offset); highY = b.create(loc, highY, offset); highX = b.create(loc, highX, offset); // Step 2: Extract region of interest using bounds - Value lowY_int = b.create(loc, lowY); - Value lowX_int = b.create(loc, lowX); - Value highY_int = b.create(loc, highY); - Value highX_int = b.create(loc, highX); - lowY_int = - b.create(loc, b.getI64Type(), lowY_int); - lowX_int = - b.create(loc, b.getI64Type(), lowX_int); - highY_int = - b.create(loc, b.getI64Type(), highY_int); - highX_int = - b.create(loc, b.getI64Type(), highX_int); + Value lowYInt = b.create(loc, lowY); + Value lowXInt = b.create(loc, lowX); + Value highYInt = b.create(loc, highY); + Value highXInt = b.create(loc, highX); + lowYInt = + b.create(loc, b.getI64Type(), lowYInt); + lowXInt = + b.create(loc, b.getI64Type(), lowXInt); + highYInt = + b.create(loc, b.getI64Type(), highYInt); + highXInt = + b.create(loc, b.getI64Type(), highXInt); Value roiHeight = - b.create(loc, highY_int, lowY_int); + b.create(loc, highYInt, lowYInt); Value roiWidth = - b.create(loc, highX_int, lowX_int); + b.create(loc, highXInt, lowXInt); SmallVector roiOffsetVals = { getAsOpFoldResult(cstZero), getAsOpFoldResult(iv1), - getAsOpFoldResult(lowY_int), getAsOpFoldResult(lowX_int)}; + getAsOpFoldResult(lowYInt), getAsOpFoldResult(lowXInt)}; SmallVector roiSizeVals = {cstOne, cstOne, roiHeight, roiWidth}; Value extractRoi = b.create( - loc, input, ValueRange{cstZero, iv1, lowY_int, lowX_int}, + loc, input, ValueRange{cstZero, iv1, lowYInt, lowXInt}, ValueRange{cstOne, cstOne, roiHeight, roiWidth}, ValueRange{cstOne, cstOne, cstOne, cstOne}); @@ -2160,8 +1964,6 @@ class ConvertRoiAlignOp : public OpConversionPattern { castIntToIndex(b, loc, outputSizeIntValues[i - 2])); } SmallVector inputSizes; - auto inputType = cast(extractRoi.getType()); - auto inputRank = inputType.getRank(); for (unsigned i = 2; i < inputRank; i++) { inputSizes.push_back(b.create( loc, b.getIntegerType(64), roiSizeVals[i])); @@ -2223,7 +2025,7 @@ class ConvertRoiAlignOp : public OpConversionPattern { b.create(loc, avg); }) .getResult(0); - + llvm::dbgs() << "4\n"; SmallVector finalStrides(inputRank, oneAttr); SmallVector finalOffsets = { getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, @@ -2237,6 +2039,7 @@ class ConvertRoiAlignOp : public OpConversionPattern { finalStrides); }); }); + llvm::dbgs() << "5\n"; rewriter.replaceOp(op, finalOutputTensor); return success(); } From ca5df5522f9d5ade566634534911f94f63758d30 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Mon, 10 Mar 2025 19:28:14 -0500 Subject: [PATCH 4/6] removed debug statements addressed comments --- lib/Conversion/TorchToLinalg/Pooling.cpp | 161 +++++++++++++---------- 1 file changed, 95 insertions(+), 66 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index ea9c2e07e5fd..7594b82b5465 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -153,26 +153,30 @@ static LogicalResult createPoolingOp( SmallVectorImpl &dilationInts, Attribute initValueAttr, SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { Location loc = op->getLoc(); + + Type elementType = cast(self.getType()).getElementType(); if (!isa(elementType) && !supportNonFPInput) return op->emitError("unimplemented: non-floating point type"); - + Value initValue = rewriter.create(loc, cast(initValueAttr)); paddedInput = padInputTensor(op, rewriter, self, ceilMode, dimensionality, strideInts, paddingInts, initValue); - + auto outTensorInitialized = computeOutputTensor( op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts, dilationInts, kernelSizeIntValues, outTensorShape, initValue); - + + auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); auto shape = castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues); + Value windowTensor = rewriter.create( loc, getAsOpFoldResult(shape), elementType); - + Value permutedInput = paddedInput, permutedOutput = outTensorInitialized; if (dimensionality == 3) { // Permute input and output tensor as follows: @@ -190,7 +194,7 @@ static LogicalResult createPoolingOp( return rewriter.notifyMatchFailure( op, "failed to perform permutation of tensor"); } - + Value poolingResult = rewriter .create(loc, permutedOutput.getType(), @@ -1618,15 +1622,17 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { } }; -template -struct ConvertRoiAlignOp : final OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ConvertRoiAlignOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - static SmallVector coordinateTransform( - OpBuilder &b, OpTy op, Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes, - SmallVector scaleValues, std::string coordStr, - bool alignCornersBool, SmallVector indices, bool clip) { + static SmallVector + coordinateTransform(OpBuilder &b, Torch::TorchvisionRoiAlignOp op, + Location loc, SmallVector outputSizes, Value input, + SmallVector inputSizes, + SmallVector scaleValues, std::string coordStr, + bool alignCornersBool, SmallVector indices, + bool clip) { unsigned dimOffset = 2; auto inputType = cast(input.getType()); @@ -1647,6 +1653,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { b.create(loc, b.getF32Type(), outputSizes[i]); // scale = length_resized/length_original Value scale; + if (alignCornersBool) { // x_original = x_resized * (length_original - 1) / (length_resized - 1) Value inputSubOne = b.create(loc, inputFP, cstOneFloat); @@ -1695,32 +1702,43 @@ struct ConvertRoiAlignOp : final OpConversionPattern { Value offset = b.create(loc, center, oneMAdjustment); preClip = b.create(loc, offset, preClip); } + // for pytorch half pixel , special case for length_resized == 1: if (coordStr == "_pytorch_half_pixel") { + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, outputSizeFP, cstOneFloat); + preClip = b.create(loc, cmp, zero, preClip); } if (clip) { // preClip is the fp position inside the input image to extract from. // clip to [0,inf) + Value max = b.create(loc, preClip, zero); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); // clip to [0,length_original - 1]. // proj is properly within the input image. + proj.push_back(b.create(loc, max, inputSubOne)); + } else { + proj.push_back(preClip); } } + return proj; } - static Value bilinearInterpolate(OpBuilder &b, OpTy op, Location loc, - SmallVector outputSizes, Value input, - SmallVector inputSizes, + static Value bilinearInterpolate(OpBuilder &b, + Torch::TorchvisionRoiAlignOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, SmallVector scaleValues, std::string coordStr) { + unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -1729,21 +1747,22 @@ struct ConvertRoiAlignOp : final OpConversionPattern { b.create(loc, b.getF32FloatAttr(1.0)); SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { + for (unsigned i = 0; i < inputRank; ++i) { indices.push_back(b.create(loc, i)); } SmallVector proj, high, low, highFP, lowFP; + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, scaleValues, coordStr, false, indices, true); - for (unsigned i = 0; i < inputRank - dimOffset; i++) { + for (unsigned i = 0; i < inputRank - dimOffset; ++i) { // length_original Value inputFP = b.create(loc, b.getF32Type(), inputSizes[i]); Value inputSubOne = b.create(loc, inputFP, cstOneFloat); // for bilinear interpolation, we look for the nearest indices below and - // above proj + // above proj. lowFP.push_back(b.create(loc, proj[i])); Value projPlusOne = b.create(loc, cstOneFloat, proj[i]); highFP.push_back(b.create(loc, projPlusOne)); @@ -1759,6 +1778,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { Value highExtract = b.create(loc, projPlusOne, inputSubOne); highExtract = b.create(loc, b.getI64Type(), highExtract); + high.push_back( b.create(loc, b.getIndexType(), highExtract)); } @@ -1783,7 +1803,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. // We interpolate via the weighted average of pij by weights Aij // the formula is retval = Sum(pij*Aij for i and j in range(2)). - // Note: we do not need to divide by total rect area == 1 + // Note: we do not need to divide by total rect area == 1. // lengths : Aij == dyi*dxj Value dy0 = b.create(loc, highFP[0], proj[0]); @@ -1797,6 +1817,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { Value sum = b.create(loc, dx0p00, dx1p01); Value left = b.create(loc, dy0, sum); // right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11) + Value dx0p10 = b.create(loc, dx0, p10); Value dx1p11 = b.create(loc, dx1, p11); sum = b.create(loc, dx0p10, dx1p11); @@ -1805,7 +1826,8 @@ struct ConvertRoiAlignOp : final OpConversionPattern { return b.create(loc, left, right); } LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + matchAndRewrite(Torch::TorchvisionRoiAlignOp op, + typename Torch::TorchvisionRoiAlignOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); @@ -1818,33 +1840,33 @@ struct ConvertRoiAlignOp : final OpConversionPattern { int64_t samplingRatioInt = static_cast(samplingRatio); Value pooledH = op.getPooledHeight(); Value pooledW = op.getPooledWidth(); - Value spatialScaleVal = op.getSpatialScale(); + Value spatialScaleVal = adaptor.getSpatialScale(); llvm::APFloat spatialScale = cast(op.getSpatialScale().getDefiningOp()).getValue(); - Value rois = op.getRois(); - Value input = op.getInput(); + Value rois = adaptor.getRois(); + Value input = adaptor.getInput(); RankedTensorType inputType = dyn_cast_or_null( - this->getTypeConverter()->convertType(input.getType())); - llvm::dbgs() << "input"; + this->getTypeConverter()->convertType(op.getInput().getType())); + if (inputType == nullptr) { op.emitError("Cannot determine input shape"); } - + unsigned inputRank = inputType.getRank(); Value offset = rewriter.create(loc, rewriter.getF32FloatAttr(0.0)); RankedTensorType resultType = dyn_cast_or_null( - this->getTypeConverter()->convertType(result.getType())); + this->getTypeConverter()->convertType(result.getType())); if (resultType == nullptr) { op.emitError("Cannot determine result shape"); } - llvm::dbgs() << "that\n"; + Type resultElementType = resultType.getElementType(); if (!op.getAligned()) { offset = rewriter.create( loc, rewriter.getF32FloatAttr(0.5)); } - llvm::dbgs() << "1\n"; + Value lb = rewriter.create(loc, 0); Value ub0 = rewriter.create(loc, rois, 0); Value ub1 = rewriter.create(loc, input, 1); @@ -1858,47 +1880,43 @@ struct ConvertRoiAlignOp : final OpConversionPattern { b.create( loc, lb, ub1, step, ValueRange{}, [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) { - llvm::dbgs() << "2\n"; // Step 1: Extract bounds for region of interest (roi). OpFoldResult zeroAttr = b.getI64IntegerAttr(0); OpFoldResult oneAttr = b.getI64IntegerAttr(1); - llvm::dbgs() << "2.1\n"; + Value cstZero = rewriter.create(loc, 0); Value cstOne = rewriter.create(loc, 1); Value cstTwo = rewriter.create(loc, 2); Value cstThree = rewriter.create(loc, 3); Value cstFour = rewriter.create(loc, 4); - llvm::dbgs() << "2.2\n"; + SmallVector strideVals{oneAttr, oneAttr, oneAttr, oneAttr}; - llvm::dbgs() << "2.21\n"; + SmallVector lowYIndices = {iv0, cstOne}; - llvm::dbgs() << "2.211\n"; - llvm::dbgs() << rois << "\n"; - Value lowY = - b.create(loc, rois, lowYIndices); - // Value lowY = b.create(loc, b.getF32FloatAttr(0.0)); - llvm::dbgs() << "2.212\n"; + Value lowY = b.create(loc, b.getF32Type(), + rois, lowYIndices); + SmallVector lowXIndices = {iv0, cstTwo}; - llvm::dbgs() << "2.213\n"; - Value lowX = - b.create(loc, rois, lowXIndices); - llvm::dbgs() << "2.214\n"; + + Value lowX = b.create(loc, b.getF32Type(), + rois, lowXIndices); + SmallVector highYIndices = {iv0, cstThree}; - llvm::dbgs() << "2.22\n"; - Value highY = - b.create(loc, rois, highYIndices); + + Value highY = b.create(loc, b.getF32Type(), + rois, highYIndices); SmallVector highXIndices = {iv0, cstFour}; - llvm::dbgs() << "2.23\n"; - Value highX = - b.create(loc, rois, highXIndices); - llvm::dbgs() << "2.5\n"; + + Value highX = b.create(loc, b.getF32Type(), + rois, highXIndices); + lowY = b.create(loc, lowY, spatialScaleVal); lowX = b.create(loc, lowX, spatialScaleVal); highY = b.create(loc, highY, spatialScaleVal); highX = b.create(loc, highX, spatialScaleVal); - llvm::dbgs() << "3\n"; + lowY = b.create(loc, lowY, offset); lowX = b.create(loc, lowX, offset); highY = b.create(loc, highY, offset); @@ -1934,7 +1952,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { ValueRange{cstOne, cstOne, roiHeight, roiWidth}, ValueRange{cstOne, cstOne, cstOne, cstOne}); - // Step 3: Perform bilinear interpolation over roi + // Step 3: Perform bilinear interpolation over roi. Value roiBinH = b.create(loc, highY, lowY); Value roiBinW = b.create(loc, highX, lowX); Value scaleH = b.create(loc, roiBinH, pooledH); @@ -1943,6 +1961,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { scaleW = b.create(loc, scaleW); scaleH = b.create(loc, b.getI64Type(), scaleH); scaleW = b.create(loc, b.getI64Type(), scaleW); + if (samplingRatio > 0) { scaleH = b.create( loc, rewriter.getI64IntegerAttr(samplingRatio)); @@ -1959,20 +1978,25 @@ struct ConvertRoiAlignOp : final OpConversionPattern { roiSampleWidth}; SmallVector dims = getTensorSizesUntilDim(b, loc, extractRoi, 1); - for (unsigned i = 2; i < inputRank; i++) { - dims.push_back( - castIntToIndex(b, loc, outputSizeIntValues[i - 2])); + + for (unsigned i = 2; i < inputRank; ++i) { + auto dim = b.create( + loc, b.getIndexType(), outputSizeIntValues[i - 2]); + dims.push_back(dim); } + SmallVector inputSizes; - for (unsigned i = 2; i < inputRank; i++) { + for (unsigned i = 2; i < inputRank; ++i) { inputSizes.push_back(b.create( loc, b.getIntegerType(64), roiSizeVals[i])); } + Value outTensor = b.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); AffineMap idMap = b.getMultiDimIdentityMap(inputRank); SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); + Value bilinearInterpolatedRoi = b.create( loc, outTensor.getType(), ValueRange{}, outTensor, @@ -1981,20 +2005,25 @@ struct ConvertRoiAlignOp : final OpConversionPattern { [&](OpBuilder &b, Location loc, ValueRange args) { Value retVal = bilinearInterpolate( b, op, loc, outputSizeIntValues, extractRoi, - inputSizes, ValueRange{}, "bilinear"); + inputSizes, ValueRange{}, ""); + b.create(loc, retVal); }) .getResult(0); - // Step 4: Sum pool over interpolated values + // Step 4: Sum pool over interpolated values. + Value sumPool, paddedInput; - SmallVector kernelSizeIntValues = {cstOne, cstOne, + Value oneInt = + b.create(loc, b.getI64IntegerAttr(1)); + SmallVector kernelSizeIntValues = {oneInt, oneInt, scaleH, scaleW}; SmallVector strideInts = {samplingRatioInt, samplingRatioInt}; SmallVector paddingInts = {0, 0}; - SmallVector dilationInts(2, 1); + SmallVector dilationInts = {1, 1}; SmallVector outTensorShape; + if (failed(createPoolingOp( op, rewriter, bilinearInterpolatedRoi, /*supportNonFPInput=*/true, false, @@ -2005,10 +2034,11 @@ struct ConvertRoiAlignOp : final OpConversionPattern { op.emitError("unable to compute sumpool"); // Step 5: elementwise division by number of sampling points - // to compute avg pool + // to compute avg pool. Value outputTensor = b.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); Value divisor = b.create(loc, scaleH, scaleW); + Value avgPool = b.create( loc, outputTensor.getType(), sumPool, outputTensor, @@ -2025,7 +2055,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { b.create(loc, avg); }) .getResult(0); - llvm::dbgs() << "4\n"; + SmallVector finalStrides(inputRank, oneAttr); SmallVector finalOffsets = { getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, @@ -2039,7 +2069,7 @@ struct ConvertRoiAlignOp : final OpConversionPattern { finalStrides); }); }); - llvm::dbgs() << "5\n"; + rewriter.replaceOp(op, finalOutputTensor); return success(); } @@ -2095,6 +2125,5 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( typeConverter, context); patterns.add>( typeConverter, context); - patterns.add>(typeConverter, - context); + patterns.add(typeConverter, context); } From e4a50e5bab6394c923744ce38d00bc06e68e5d37 Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Wed, 12 Mar 2025 01:09:17 -0500 Subject: [PATCH 5/6] not ready for review, fixing bugs, will drop commit later --- lib/Conversion/TorchToLinalg/Pooling.cpp | 121 ++++++++++++++--------- 1 file changed, 74 insertions(+), 47 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 7594b82b5465..421ee91f9e3f 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -168,7 +168,7 @@ static LogicalResult createPoolingOp( auto outTensorInitialized = computeOutputTensor( op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts, dilationInts, kernelSizeIntValues, outTensorShape, initValue); - + llvm::dbgs() << outTensorInitialized << " [][][][][][][][][]\n"; auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); @@ -201,6 +201,7 @@ static LogicalResult createPoolingOp( ValueRange{permutedInput, windowTensor}, permutedOutput, stridesAttr, dilationAttr) .getResult(0); + llvm::dbgs() << poolingResult << "{}{}{}{}{}{}\n"; result = poolingResult; if (dimensionality == 3) { @@ -1671,6 +1672,7 @@ struct ConvertRoiAlignOp final // y_resized Value outInt = b.create(loc, b.getI64Type(), indices[i + dimOffset]); +llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; Value outFP = b.create(loc, b.getF32Type(), outInt); Value preClip; if (coordStr == "_align_corners") { @@ -1738,7 +1740,7 @@ struct ConvertRoiAlignOp final Value input, SmallVector inputSizes, SmallVector scaleValues, std::string coordStr) { - + llvm::dbgs() << "12A\n"; unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -1750,11 +1752,12 @@ struct ConvertRoiAlignOp final for (unsigned i = 0; i < inputRank; ++i) { indices.push_back(b.create(loc, i)); } - + llvm::dbgs() << "12A1\n"; SmallVector proj, high, low, highFP, lowFP; proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, scaleValues, coordStr, false, indices, true); + llvm::dbgs() << "12B\n"; for (unsigned i = 0; i < inputRank - dimOffset; ++i) { // length_original Value inputFP = @@ -1782,7 +1785,7 @@ struct ConvertRoiAlignOp final high.push_back( b.create(loc, b.getIndexType(), highExtract)); } - + llvm::dbgs() << "12B1\n"; indices[dimOffset] = low[0]; indices[dimOffset + 1] = low[1]; Value p00 = b.create(loc, input, indices); @@ -1823,12 +1826,14 @@ struct ConvertRoiAlignOp final sum = b.create(loc, dx0p10, dx1p11); Value right = b.create(loc, dy1, sum); + llvm::dbgs() << "12C\n"; return b.create(loc, left, right); } LogicalResult matchAndRewrite(Torch::TorchvisionRoiAlignOp op, typename Torch::TorchvisionRoiAlignOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); @@ -1838,11 +1843,16 @@ struct ConvertRoiAlignOp final uint64_t samplingRatio = cast(op.getSamplingRatio().getDefiningOp()).getValue(); int64_t samplingRatioInt = static_cast(samplingRatio); - Value pooledH = op.getPooledHeight(); - Value pooledW = op.getPooledWidth(); - Value spatialScaleVal = adaptor.getSpatialScale(); + Value pooledH = adaptor.getPooledHeight(); + Value pooledW = adaptor.getPooledWidth(); + Value pooledHFp = rewriter.create(loc, rewriter.getF32Type(), pooledH); + Value pooledWFp = rewriter.create(loc, rewriter.getF32Type(), pooledW); + + // Value spatialScaleVal = adaptor.getSpatialScale(); llvm::APFloat spatialScale = cast(op.getSpatialScale().getDefiningOp()).getValue(); + Value spatialScaleVal = rewriter.create( + loc, rewriter.getF32FloatAttr(spatialScale.convertToDouble())); Value rois = adaptor.getRois(); Value input = adaptor.getInput(); RankedTensorType inputType = dyn_cast_or_null( @@ -1871,7 +1881,11 @@ struct ConvertRoiAlignOp final Value ub0 = rewriter.create(loc, rois, 0); Value ub1 = rewriter.create(loc, input, 1); Value step = rewriter.create(loc, 1); - SmallVector finalOutputShape = {ub0, ub1, pooledH, pooledW}; + auto pooledHIdx = rewriter.create( + loc, rewriter.getIndexType(), pooledH); + auto pooledWIdx = rewriter.create( + loc, rewriter.getIndexType(), pooledW); + SmallVector finalOutputShape = {ub0, ub1, pooledHIdx, pooledWIdx}; Value finalOutputTensor = rewriter.create( loc, getAsOpFoldResult(finalOutputShape), resultElementType); rewriter.create( @@ -1883,9 +1897,12 @@ struct ConvertRoiAlignOp final // Step 1: Extract bounds for region of interest (roi). OpFoldResult zeroAttr = b.getI64IntegerAttr(0); OpFoldResult oneAttr = b.getI64IntegerAttr(1); - - Value cstZero = rewriter.create(loc, 0); - Value cstOne = rewriter.create(loc, 1); + Value intOne = + b.create(loc, b.getI64IntegerAttr(1)); + // Value intZero = + // b.create(loc, b.getI64IntegerAttr(0)); + Value idxZero = rewriter.create(loc, 0); + Value idxOne = rewriter.create(loc, 1); Value cstTwo = rewriter.create(loc, 2); Value cstThree = rewriter.create(loc, 3); @@ -1894,7 +1911,7 @@ struct ConvertRoiAlignOp final SmallVector strideVals{oneAttr, oneAttr, oneAttr, oneAttr}; - SmallVector lowYIndices = {iv0, cstOne}; + SmallVector lowYIndices = {iv0, idxOne}; Value lowY = b.create(loc, b.getF32Type(), rois, lowYIndices); @@ -1911,22 +1928,27 @@ struct ConvertRoiAlignOp final Value highX = b.create(loc, b.getF32Type(), rois, highXIndices); - + llvm::dbgs() << "7A" << "\n"; + llvm::dbgs() << " LOL..\n"; lowY = b.create(loc, lowY, spatialScaleVal); + llvm::dbgs() << " LOLA\n"; lowX = b.create(loc, lowX, spatialScaleVal); + llvm::dbgs() << " LOLB\n"; highY = b.create(loc, highY, spatialScaleVal); + llvm::dbgs() << " LOLC\n"; highX = b.create(loc, highX, spatialScaleVal); - + llvm::dbgs() << " LOL1\n"; lowY = b.create(loc, lowY, offset); lowX = b.create(loc, lowX, offset); highY = b.create(loc, highY, offset); highX = b.create(loc, highX, offset); - + llvm::dbgs() << " LOL2\n"; // Step 2: Extract region of interest using bounds Value lowYInt = b.create(loc, lowY); Value lowXInt = b.create(loc, lowX); Value highYInt = b.create(loc, highY); Value highXInt = b.create(loc, highX); + llvm::dbgs() << " LOL3\n"; lowYInt = b.create(loc, b.getI64Type(), lowYInt); lowXInt = @@ -1935,28 +1957,33 @@ struct ConvertRoiAlignOp final b.create(loc, b.getI64Type(), highYInt); highXInt = b.create(loc, b.getI64Type(), highXInt); - + Value lowYIdx = b.create(loc, b.getIndexType(), lowYInt); + Value lowXIdx = b.create(loc, b.getIndexType(), lowXInt); + llvm::dbgs() << " LOL4\n"; + llvm::dbgs() << lowYIdx << "\n^^^\n\n"; Value roiHeight = b.create(loc, highYInt, lowYInt); Value roiWidth = b.create(loc, highXInt, lowXInt); + Value roiHIdx = b.create(loc, b.getIndexType(), roiHeight); + Value roiWIdx = b.create(loc, b.getIndexType(), roiWidth); SmallVector roiOffsetVals = { - getAsOpFoldResult(cstZero), getAsOpFoldResult(iv1), + getAsOpFoldResult(idxZero), getAsOpFoldResult(iv1), getAsOpFoldResult(lowYInt), getAsOpFoldResult(lowXInt)}; - SmallVector roiSizeVals = {cstOne, cstOne, roiHeight, + SmallVector roiSizeVals = {intOne, intOne, roiHeight, roiWidth}; Value extractRoi = b.create( - loc, input, ValueRange{cstZero, iv1, lowYInt, lowXInt}, - ValueRange{cstOne, cstOne, roiHeight, roiWidth}, - ValueRange{cstOne, cstOne, cstOne, cstOne}); + loc, input, ValueRange{idxZero, iv1, lowYIdx, lowXIdx}, + ValueRange{idxOne, idxOne, roiHIdx, roiWIdx}, + ValueRange{idxOne, idxOne, idxOne, idxOne}); // Step 3: Perform bilinear interpolation over roi. Value roiBinH = b.create(loc, highY, lowY); Value roiBinW = b.create(loc, highX, lowX); - Value scaleH = b.create(loc, roiBinH, pooledH); - Value scaleW = b.create(loc, roiBinW, pooledW); + Value scaleH = b.create(loc, roiBinH, pooledHFp); + Value scaleW = b.create(loc, roiBinW, pooledWFp); scaleH = b.create(loc, scaleH); scaleW = b.create(loc, scaleW); scaleH = b.create(loc, b.getI64Type(), scaleH); @@ -1987,22 +2014,28 @@ struct ConvertRoiAlignOp final SmallVector inputSizes; for (unsigned i = 2; i < inputRank; ++i) { - inputSizes.push_back(b.create( - loc, b.getIntegerType(64), roiSizeVals[i])); + inputSizes.push_back(roiSizeVals[i]); } Value outTensor = b.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); - AffineMap idMap = b.getMultiDimIdentityMap(inputRank); - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - + auto iteratorTypes = + SmallVector(inputRank, utils::IteratorType::parallel); + iteratorTypes.append(inputRank, utils::IteratorType::parallel); + SmallVector idMap(2, b.getMultiDimIdentityMap(inputRank)); + //AffineMap idMap = b.getMultiDimIdentityMap(inputRank); + // SmallVector iteratorTypes( + // inputRank, utils::IteratorType::parallel); + + llvm::dbgs() << "5A" << "\n"; + llvm::dbgs() << "4.99A" << "\n"; Value bilinearInterpolatedRoi = b.create( - loc, outTensor.getType(), ValueRange{}, outTensor, + loc, outTensor.getType(), extractRoi, outTensor, /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + llvm::dbgs() << "4.9A" << "\n"; Value retVal = bilinearInterpolate( b, op, loc, outputSizeIntValues, extractRoi, inputSizes, ValueRange{}, ""); @@ -2010,20 +2043,18 @@ struct ConvertRoiAlignOp final b.create(loc, retVal); }) .getResult(0); - + llvm::dbgs() << "4A" << "\n"; // Step 4: Sum pool over interpolated values. - Value sumPool, paddedInput; - Value oneInt = - b.create(loc, b.getI64IntegerAttr(1)); - SmallVector kernelSizeIntValues = {oneInt, oneInt, + + SmallVector kernelSizeIntValues = {/*intOne, intOne,*/ scaleH, scaleW}; SmallVector strideInts = {samplingRatioInt, samplingRatioInt}; SmallVector paddingInts = {0, 0}; SmallVector dilationInts = {1, 1}; SmallVector outTensorShape; - + llvm::dbgs() << "3A" << "\n"; if (failed(createPoolingOp( op, rewriter, bilinearInterpolatedRoi, /*supportNonFPInput=*/true, false, @@ -2038,24 +2069,20 @@ struct ConvertRoiAlignOp final Value outputTensor = b.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); Value divisor = b.create(loc, scaleH, scaleW); - + divisor = rewriter.create(loc, rewriter.getF32Type(), divisor); + llvm::dbgs() << "2A" << "\n"; Value avgPool = b.create( loc, outputTensor.getType(), sumPool, outputTensor, /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], - divisor); - else if (isa(resultElementType)) - avg = - b.create(loc, args[0], divisor); - b.create(loc, avg); + Value res = b.create(loc, args[0], divisor); + b.create(loc, res); }) .getResult(0); - + llvm::dbgs() << avgPool << " <------------------ \n"; + llvm::dbgs() << "1" << "\n"; SmallVector finalStrides(inputRank, oneAttr); SmallVector finalOffsets = { getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, @@ -2069,7 +2096,7 @@ struct ConvertRoiAlignOp final finalStrides); }); }); - + llvm::dbgs() << "0" << "\n"; rewriter.replaceOp(op, finalOutputTensor); return success(); } From d035744812b33aa082b0973448d169d14902b39f Mon Sep 17 00:00:00 2001 From: Muzammiluddin Syed Date: Wed, 12 Mar 2025 13:59:21 -0500 Subject: [PATCH 6/6] last commit before hiatus --- lib/Conversion/TorchToLinalg/Pooling.cpp | 58 ++++++------------------ 1 file changed, 13 insertions(+), 45 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 421ee91f9e3f..60fbbf675a28 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -168,7 +168,6 @@ static LogicalResult createPoolingOp( auto outTensorInitialized = computeOutputTensor( op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts, dilationInts, kernelSizeIntValues, outTensorShape, initValue); - llvm::dbgs() << outTensorInitialized << " [][][][][][][][][]\n"; auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); @@ -201,7 +200,6 @@ static LogicalResult createPoolingOp( ValueRange{permutedInput, windowTensor}, permutedOutput, stridesAttr, dilationAttr) .getResult(0); - llvm::dbgs() << poolingResult << "{}{}{}{}{}{}\n"; result = poolingResult; if (dimensionality == 3) { @@ -1672,7 +1670,6 @@ struct ConvertRoiAlignOp final // y_resized Value outInt = b.create(loc, b.getI64Type(), indices[i + dimOffset]); -llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; Value outFP = b.create(loc, b.getF32Type(), outInt); Value preClip; if (coordStr == "_align_corners") { @@ -1740,7 +1737,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; Value input, SmallVector inputSizes, SmallVector scaleValues, std::string coordStr) { - llvm::dbgs() << "12A\n"; unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -1752,12 +1748,10 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; for (unsigned i = 0; i < inputRank; ++i) { indices.push_back(b.create(loc, i)); } - llvm::dbgs() << "12A1\n"; SmallVector proj, high, low, highFP, lowFP; proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, scaleValues, coordStr, false, indices, true); - llvm::dbgs() << "12B\n"; for (unsigned i = 0; i < inputRank - dimOffset; ++i) { // length_original Value inputFP = @@ -1785,7 +1779,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; high.push_back( b.create(loc, b.getIndexType(), highExtract)); } - llvm::dbgs() << "12B1\n"; indices[dimOffset] = low[0]; indices[dimOffset + 1] = low[1]; Value p00 = b.create(loc, input, indices); @@ -1826,7 +1819,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; sum = b.create(loc, dx0p10, dx1p11); Value right = b.create(loc, dy1, sum); - llvm::dbgs() << "12C\n"; return b.create(loc, left, right); } LogicalResult @@ -1888,19 +1880,17 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; SmallVector finalOutputShape = {ub0, ub1, pooledHIdx, pooledWIdx}; Value finalOutputTensor = rewriter.create( loc, getAsOpFoldResult(finalOutputShape), resultElementType); - rewriter.create( - loc, lb, ub0, step, ValueRange{}, - [&](OpBuilder &b, Location loc, Value iv0, ValueRange args) { - b.create( - loc, lb, ub1, step, ValueRange{}, + auto resOut = rewriter.create( + loc, lb, ub0, step, ValueRange{finalOutputTensor}, + [&](OpBuilder &b, Location loc, Value iv0, ValueRange args0) { + auto res = b.create( + loc, lb, ub1, step, ValueRange{args0[0]}, [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) { // Step 1: Extract bounds for region of interest (roi). OpFoldResult zeroAttr = b.getI64IntegerAttr(0); OpFoldResult oneAttr = b.getI64IntegerAttr(1); Value intOne = b.create(loc, b.getI64IntegerAttr(1)); - // Value intZero = - // b.create(loc, b.getI64IntegerAttr(0)); Value idxZero = rewriter.create(loc, 0); Value idxOne = rewriter.create(loc, 1); Value cstTwo = rewriter.create(loc, 2); @@ -1928,27 +1918,20 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; Value highX = b.create(loc, b.getF32Type(), rois, highXIndices); - llvm::dbgs() << "7A" << "\n"; - llvm::dbgs() << " LOL..\n"; lowY = b.create(loc, lowY, spatialScaleVal); - llvm::dbgs() << " LOLA\n"; lowX = b.create(loc, lowX, spatialScaleVal); - llvm::dbgs() << " LOLB\n"; highY = b.create(loc, highY, spatialScaleVal); - llvm::dbgs() << " LOLC\n"; highX = b.create(loc, highX, spatialScaleVal); - llvm::dbgs() << " LOL1\n"; lowY = b.create(loc, lowY, offset); lowX = b.create(loc, lowX, offset); highY = b.create(loc, highY, offset); highX = b.create(loc, highX, offset); - llvm::dbgs() << " LOL2\n"; + // Step 2: Extract region of interest using bounds Value lowYInt = b.create(loc, lowY); Value lowXInt = b.create(loc, lowX); Value highYInt = b.create(loc, highY); Value highXInt = b.create(loc, highX); - llvm::dbgs() << " LOL3\n"; lowYInt = b.create(loc, b.getI64Type(), lowYInt); lowXInt = @@ -1959,8 +1942,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; b.create(loc, b.getI64Type(), highXInt); Value lowYIdx = b.create(loc, b.getIndexType(), lowYInt); Value lowXIdx = b.create(loc, b.getIndexType(), lowXInt); - llvm::dbgs() << " LOL4\n"; - llvm::dbgs() << lowYIdx << "\n^^^\n\n"; Value roiHeight = b.create(loc, highYInt, lowYInt); Value roiWidth = @@ -2019,23 +2000,15 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; Value outTensor = b.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); - auto iteratorTypes = + auto iteratorTypes = SmallVector(inputRank, utils::IteratorType::parallel); - iteratorTypes.append(inputRank, utils::IteratorType::parallel); SmallVector idMap(2, b.getMultiDimIdentityMap(inputRank)); - //AffineMap idMap = b.getMultiDimIdentityMap(inputRank); - // SmallVector iteratorTypes( - // inputRank, utils::IteratorType::parallel); - - llvm::dbgs() << "5A" << "\n"; - llvm::dbgs() << "4.99A" << "\n"; Value bilinearInterpolatedRoi = b.create( loc, outTensor.getType(), extractRoi, outTensor, /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - llvm::dbgs() << "4.9A" << "\n"; Value retVal = bilinearInterpolate( b, op, loc, outputSizeIntValues, extractRoi, inputSizes, ValueRange{}, ""); @@ -2043,7 +2016,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; b.create(loc, retVal); }) .getResult(0); - llvm::dbgs() << "4A" << "\n"; // Step 4: Sum pool over interpolated values. Value sumPool, paddedInput; @@ -2054,7 +2026,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; SmallVector paddingInts = {0, 0}; SmallVector dilationInts = {1, 1}; SmallVector outTensorShape; - llvm::dbgs() << "3A" << "\n"; if (failed(createPoolingOp( op, rewriter, bilinearInterpolatedRoi, /*supportNonFPInput=*/true, false, @@ -2070,7 +2041,6 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; loc, getAsOpFoldResult(outTensorShape), resultElementType); Value divisor = b.create(loc, scaleH, scaleW); divisor = rewriter.create(loc, rewriter.getF32Type(), divisor); - llvm::dbgs() << "2A" << "\n"; Value avgPool = b.create( loc, outputTensor.getType(), sumPool, outputTensor, @@ -2081,23 +2051,21 @@ llvm::dbgs() << outInt << " HERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<\n"; b.create(loc, res); }) .getResult(0); - llvm::dbgs() << avgPool << " <------------------ \n"; - llvm::dbgs() << "1" << "\n"; SmallVector finalStrides(inputRank, oneAttr); SmallVector finalOffsets = { getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr, zeroAttr}; SmallVector finalSizes = { - oneAttr, oneAttr, getAsOpFoldResult(pooledH), - getAsOpFoldResult(pooledW)}; + idxOne, idxOne, getAsOpFoldResult(pooledHIdx), getAsOpFoldResult(pooledWIdx)}; SmallVector diagStrides(inputRank, oneAttr); - finalOutputTensor = b.create( - loc, finalOutputTensor, avgPool, finalOffsets, finalSizes, + auto insert = b.create( + loc, avgPool, args[0], finalOffsets, finalSizes, finalStrides); + b.create(loc, insert.getResult()); }); + b.create(loc, res.getResult(0)); }); - llvm::dbgs() << "0" << "\n"; - rewriter.replaceOp(op, finalOutputTensor); + rewriter.replaceOpWithNewOp(op, resultType, resOut.getResult(0)); return success(); } };