From 3f1b9fec7d1f67f55ad938a59faa98ebd3c1bcea Mon Sep 17 00:00:00 2001
From: Praveen G
Date: Thu, 23 Jan 2025 15:54:02 +0000
Subject: [PATCH] Support batch and classes for nms lowering

---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 332 +++++++++++-------
 1 file changed, 214 insertions(+), 118 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 3db33aee1f1c..1f28280c1d64 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3701,63 +3701,51 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           return rewriter.notifyMatchFailure(
               binder.op, "expected center_point_box attribute to be 0 or 1");
 
-        // TODO: Support multiple batches and classes
-        // Squeeze the boxes and scores tensor.
-        // In Onnx, the shape of boxes is [BxNx4] while the
-        // torchvision expects it to be of shape [Nx4]. Similarly, for
-        // the scores tensor shape in Onnx is [BxCxN] while the
-        // torchvision expects it to be of shape [N].
+        Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+        Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
+        Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
+        Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
+        Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
+        Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(2.0));
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(true));
+
+        // In Onnx, the shape of boxes is [BxNx4] and that of scores is [BxCxN].
         Value boxes = operands[0], scores = operands[1];
-        FailureOr<Value> squeezedBoxes =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
-        if (failed(squeezedBoxes))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze boxes tensor");
-        FailureOr<Value> squeezedScores =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
-                                              squeezedScores.value());
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        boxes = squeezedBoxes.value();
-        scores = squeezedScores.value();
+
+        auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
+        auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
+        auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
+        auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
+
         if (centerPointBox == 1) {
           // When center_point_box is 1, the box data is supplied as
           // [[x_center, y_center, width, height], ...]. Slice it to
           // [[x_center, y_center], ...] and [[width, height], ...],
           // calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
           // to [[x1, y1, x2, y2], ...]
-          auto boxesTensorType =
-              dyn_cast<Torch::ValueTensorType>(boxes.getType());
-          Value const0 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(0));
-          Value const1 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(1));
-          Value const2 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(2));
-          Value const4 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(4));
-          Value const2F = rewriter.create<Torch::ConstantFloatOp>(
-              loc, rewriter.getF64FloatAttr(2.0));
           // extract scaled ranges for regions of interest
-          auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
+          auto sliceShape =
+              SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize, 2};
           auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
               sliceShape, boxesTensorType.getDtype());
+
+          // Boxes have shape [BxNx4]
           Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, sliceTensorType, boxes, const1, const0, const2, const1);
+              loc, sliceTensorType, boxes, cst2, cst0, cst2, cst1);
           Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, sliceTensorType, boxes, const1, const2, const4, const1);
+              loc, sliceTensorType, boxes, cst2, cst2, cst4, cst1);
           Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
-              loc, sizes.getType(), sizes, const2F);
+              loc, sizes.getType(), sizes, cst2F);
           Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
-              loc, centers.getType(), centers, halfSizes, const1);
+              loc, centers.getType(), centers, halfSizes, cst1);
           Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
-              loc, centers.getType(), centers, halfSizes, const1);
+              loc, centers.getType(), centers, halfSizes, cst1);
 
           Type listElemType = boxesTensorType.getWithSizesAndDtype(
               /*optionalSizes=*/std::nullopt,
@@ -3766,7 +3754,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
               loc, listType, SmallVector<Value>{x1y1s, x2y2s});
           boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
-                                                    tensorList, const1);
+                                                    tensorList, cst2);
         }
 
         // TODO: Support score_threshold input
@@ -3792,10 +3780,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
 
         // Get max_output_boxes_per_class and iou_threshold
-        Value cst0 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        Value cst1 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(1));
         Value maxOutputBoxesPerClass = cst0;
         Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(0.0));
@@ -3810,87 +3794,199 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               loc, rewriter.getType<Torch::IntType>(), operands[2]);
         }
 
+        // In Onnx the boxes tensor has shape [BxNx4] while torchvision's nms
+        // expects [Nx4], so loop over the batch dimension. Similarly, the
+        // scores tensor has shape [BxCxN] while torchvision expects [N], so
+        // loop over the class dimension too.
+        auto numBatches =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
+        auto numClasses =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
+
+        // Create an empty tensor of shape (B*C*N, 3) to store the final result.
+        // We slice this down to the required elements at the end.
+
+        Value numResults = rewriter.create<Torch::AtenMulIntOp>(
+            loc, numClasses.getType(), numBatches, numClasses);
+        numResults = rewriter.create<Torch::AtenMulIntOp>(
+            loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
+
+        auto intTy = rewriter.getType<Torch::IntType>();
+        auto intListTy = rewriter.getType<Torch::ListType>(intTy);
+
+        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, intListTy, SmallVector<Value>{numResults, cst3});
+        Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            loc, resultType, resultShapeList, /*dtype=*/cst4,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
         auto nmsTy = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{-1},
             rewriter.getIntegerType(64, /*signed=*/true));
-        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
-            loc, nmsTy, boxes, scores, iouThreshold);
-
-        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
-        Value numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
-            loc, numOutputBoxes, maxOutputBoxesPerClass);
+        auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
+            loc, TypeRange({finalResult.getType(), intTy}), numBatches,
+            cstTrue, ValueRange({finalResult, /*Index to finalResult*/ cst0}));
 
-        auto nmsResultTy = Torch::ValueTensorType::get(
-            binder.op->getContext(),
-            SmallVector<int64_t>{resultType.getSizes()[0]},
-            rewriter.getIntegerType(64, /*signed=*/true));
-        auto ifSlice = rewriter.create<Torch::PrimIfOp>(
-            loc, TypeRange({nmsResultTy}), boxesCond);
         {
+
+          // Batch loop body
           PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getThenRegion(),
-                               ifSlice.getThenRegion().begin());
+          Block *batchLoopBody = rewriter.createBlock(
+              &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
+              TypeRange({intTy, finalResult.getType(), intTy}),
+              {loc, loc, loc});
+          auto batchIV = batchLoopBody->getArgument(0);
+          auto currRes = batchLoopBody->getArgument(1);
+          auto finalResIdx = batchLoopBody->getArgument(2);
+
+          auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, boxSlicedType, boxes, cst0, batchIV);
+
+          auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
+              loc, TypeRange({finalResult.getType(), intTy}), numClasses,
+              cstTrue, ValueRange({currRes, finalResIdx}));
+
+          {
+            // Class loop body
+            PatternRewriter::InsertionGuard guard(rewriter);
+            Block *classLoopBody = rewriter.createBlock(
+                &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
+                TypeRange({intTy, finalResult.getType(), intTy}),
+                {loc, loc, loc});
+            auto classIV = classLoopBody->getArgument(0);
+            auto currRes = classLoopBody->getArgument(1);
+            auto finalResIdx = classLoopBody->getArgument(2);
+
+            auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreSlicedType, scores, cst0, batchIV);
+
+            auto scoreSelectType =
+                dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
+            assert(scoreSelectType);
+            auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
+                scoreSelectType.getSizes().slice(1),
+                scoreSelectType.getDtype());
+
+            auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreValueType, scoreSelect, cst0, classIV);
+
+            Value result = rewriter.create<Torch::TorchvisionNmsOp>(
+                loc, nmsTy, boxValue, scoreValue, iouThreshold);
+
+            // Slice the result if numOutputBoxes (N) >
+            // max_output_boxes_per_class
+            Value numOutputBoxes =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+            Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
+                loc, numOutputBoxes, maxOutputBoxesPerClass);
+
+            auto nmsResultTy = Torch::ValueTensorType::get(
+                binder.op->getContext(), SmallVector<int64_t>{-1},
+                rewriter.getIntegerType(64, /*signed=*/true));
+            auto ifSlice = rewriter.create<Torch::PrimIfOp>(
+                loc, TypeRange({nmsResultTy}), boxesCond);
+            {
+              PatternRewriter::InsertionGuard guard(rewriter);
+              rewriter.createBlock(&ifSlice.getThenRegion(),
+                                   ifSlice.getThenRegion().begin());
+
+              Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
+                  loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
+                  /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
+              rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
+            }
+            {
+              PatternRewriter::InsertionGuard guard(rewriter);
+              rewriter.createBlock(&ifSlice.getElseRegion(),
+                                   ifSlice.getElseRegion().begin());
-
-          Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
-              /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
+              rewriter.create<Torch::PrimIfYieldOp>(loc, result);
+            }
+            result = ifSlice.getResult(0);
+
+            // The result generated by torchvision.nms op is of shape [n],
+            // while the onnx expects it to be of shape [n, 3]. Hence, we
+            // unsqueeze the tensor and make it of shape [n, 1] and then
+            // concatenate it with batch and class values to make it of
+            // shape [n, 3].
+            FailureOr<Value> unsqueezedResult =
+                Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
+            if (failed(unsqueezedResult))
+              return rewriter.notifyMatchFailure(
+                  binder.op, "failed to unsqueeze result tensor");
+            result = unsqueezedResult.value();
+
+            auto resultNmsType = cast<Torch::ValueTensorType>(result.getType());
+            numOutputBoxes =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+            Value catList = rewriter.create<Torch::PrimListConstructOp>(
+                loc, intListTy, SmallVector<Value>{numOutputBoxes, cst1});
+
+            Value resBatch = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+                loc, resultNmsType, catList, /*dtype=*/cst4,
+                /*layout=*/cstNone,
+                /*device=*/cstNone, /*pinMemory=*/cstNone,
+                /*memoryFormat=*/cstNone);
+            auto batchVal = rewriter.create<Torch::AtenFillScalarOp>(
+                loc, resultNmsType, resBatch, batchIV);
+
+            Value resClass = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+                loc, resultNmsType, catList, /*dtype=*/cst4,
+                /*layout=*/cstNone,
+                /*device=*/cstNone, /*pinMemory=*/cstNone,
+                /*memoryFormat=*/cstNone);
+            auto classVal = rewriter.create<Torch::AtenFillScalarOp>(
+                loc, resultNmsType, resClass, classIV);
+
+            Type listElemType =
+                cast<Torch::ValueTensorType>(resultType)
+                    .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                          /*optionalDtype=*/nullptr);
+            Type listType = Torch::ListType::get(listElemType);
+
+            // c1 = concat (class, res), results in [n, 2] tensor
+            Value classResList = rewriter.create<Torch::PrimListConstructOp>(
+                loc, listType, SmallVector<Value>{classVal, result});
+            auto cat1Type = rewriter.getType<Torch::ValueTensorType>(
+                ArrayRef<int64_t>{-1, 2}, resultNmsType.getDtype());
+            auto cat1 = rewriter.create<Torch::AtenCatOp>(loc, cat1Type,
+                                                          classResList, cst1);
+
+            // c2 = concat (batch, c1), results in [n, 3] tensor
+            auto cat2Type = rewriter.getType<Torch::ValueTensorType>(
+                SmallVector<int64_t>{-1, 3}, resultNmsType.getDtype());
+            Value batchClassResList =
+                rewriter.create<Torch::PrimListConstructOp>(
+                    loc, listType, SmallVector<Value>{batchVal, cat1});
+            auto cat2 = rewriter.create<Torch::AtenCatOp>(
+                loc, cat2Type, batchClassResList, cst1);
+
+            // Write c2 into finalResult[finalResIdx:next] along dim 0
+            Value next = rewriter.create<Torch::AtenAddIntOp>(
+                loc, finalResIdx, numOutputBoxes);
+            Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+                binder.getLoc(), rewriter.getBoolAttr(false));
+            auto sliceFinal = rewriter.create<Torch::AtenSliceTensorOp>(
+                loc, cat2.getType(), currRes, /*dim=*/cst0,
+                /*start=*/finalResIdx, /*end=*/next, /*step=*/cst1);
+            auto resCopy = rewriter.create<Torch::AtenCopyOp>(
+                loc, cat2.getType(), sliceFinal, cat2, cstFalse);
+            auto scatterBatch = rewriter.create<Torch::AtenSliceScatterOp>(
+                loc, finalResult.getType(), currRes, resCopy, cst0,
+                finalResIdx, next, cst1);
+            rewriter.create<Torch::PrimLoopConditionOp>(
+                loc, cstTrue, ValueRange({scatterBatch, next}));
+          }
+          rewriter.create<Torch::PrimLoopConditionOp>(
+              loc, cstTrue,
+              ValueRange(
+                  {nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
         }
-        {
-          PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getElseRegion(),
-                               ifSlice.getElseRegion().begin());
-
-          Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
-              loc, nmsResultTy, result);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        result = ifSlice.getResult(0);
-
-        // The result generated by torchvision.nms op is of shape [n], while the
-        // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
-        // and make it of shape [n, 1] and then concatenate it with a zero
-        // tensor of shape [n, 2] to make it of shape [n, 3].
-        FailureOr<Value> unsqueezedResult =
-            Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
-        if (failed(unsqueezedResult))
-          return rewriter.notifyMatchFailure(
-              binder.op, "failed to unsqueeze result tensor");
-        result = unsqueezedResult.value();
-
-        numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        SmallVector<Value> zerosShapeValues{numOutputBoxes};
-        zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(2)));
-        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
-            loc,
-            rewriter.getType<Torch::ListType>(
-                rewriter.getType<Torch::IntType>()),
-            zerosShapeValues);
-        std::optional<ArrayRef<int64_t>> resultShape =
-            cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
-        if (!resultShape.has_value())
-          return rewriter.notifyMatchFailure(
-              binder.op, "expected result tensor to have shape");
-        llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
-        auto zerosTy = Torch::ValueTensorType::get(
-            resultType.getContext(), zerosShape, resultType.getOptionalDtype());
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-        Value zeros = rewriter.create<Torch::AtenZerosOp>(
-            loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
-
-        Type listElemType =
-            cast<Torch::ValueTensorType>(resultType)
-                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
-                                      /*optionalDtype=*/nullptr);
-        Type listType = Torch::ListType::get(listElemType);
-        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-            loc, listType, SmallVector<Value>{zeros, result});
-        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
-                                                      tensorList, cst1);
+
+        rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
+            binder.op, resultType, nmsBatchLoop.getResult(0));
         return success();
       });
 }