From e8a7ddf1ecec4bcf8e0b3321cec6f9df667b154f Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Mon, 2 Jun 2025 13:15:09 +0530
Subject: [PATCH 1/5] Lower to torch dialect without expansion

Signed-off-by: Zahid Wakeel
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 1a730a0475ed..06f79a1b35db 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1606,6 +1606,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             /* cudnn enabled */ boolFalse);
         return success();
       });
+  patterns.onOp(
+      "MeanVarianceNormalization", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        SmallVector<int64_t> axes;
+
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerArrayAttr(axes, "axes",
+                                       llvm::SmallVector<int64_t>({0, 2, 3})) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        Location loc = binder.getLoc();
+        Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+        Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+
+        ArrayRef<int64_t> input_shape = resultType.getSizes();
+        SmallVector<int64_t> reduced_shape(input_shape);
+        for (int64_t i : axes) {
+          reduced_shape[i] = 1;
+        }
+
+        Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
+            resultType.getContext(), reduced_shape, resultType.getDtype());
+        SmallVector<Value> cstAxes;
+        for (int64_t i : axes) {
+          cstAxes.push_back(rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(i)));
+        }
+        Value axes_list = rewriter.create<Torch::PrimListConstructOp>(
+            loc,
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstAxes);
+        Value mean = rewriter.create<Torch::AtenMeanDimOp>(
+            loc, meanOutTy, input, axes_list, keepDim, none);
+
+        Value variance = rewriter.create<Torch::AtenVarDimOp>(
+            loc, meanOutTy, input, axes_list, unBiased, keepDim);
+
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1e-9));
+        variance = rewriter.create<Torch::AtenAddScalarOp>(
+            loc, meanOutTy, variance, cstEps, cstOne);
+
+        Value sqrt =
+            rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
+
+        Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, resultType, input, mean, cstOne);
+
+        Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
+            loc, resultType, subValue, sqrt);
+
+        rewriter.replaceOp(binder.op, meanVarNorm);
+
+        return success();
+      });
   patterns.onOp(
       "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;

From 75bbd1ca892cb2a24c779f7a80345dd36be3ab81 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Mon, 2 Jun 2025 15:32:50 +0530
Subject: [PATCH 2/5] Remove MeanVarNorm from allowlists to avoid expansion

Signed-off-by: Zahid Wakeel
---
 python/torch_mlir/extras/onnx_importer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index 9e5bc373ebd3..defbb9bdda2a 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -103,9 +103,7 @@ class Config:
     function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(
         default_factory=lambda: {
             # Default domain (ONNX built-in ops)
-            "": {
-                "MeanVarianceNormalization",
-            }
+            "": {}
         }
     )
 
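Note: the lowering in PATCH 1/5 decomposes ONNX MeanVarianceNormalization as
(X - mean(X)) / sqrt(var(X) + eps), reducing over axes {0, 2, 3} by default and
using the biased (population) variance, which is what the `unBiased = false`
constant selects on torch.aten.var.dim. A minimal NumPy sketch of the same
computation, as a reference while reading the lit tests in PATCH 3/5; the
helper name is illustrative and the 1e-9 epsilon mirrors the constant added in
the lowering:

    import numpy as np

    def mean_variance_normalization(x, axes=(0, 2, 3), eps=1e-9):
        # Biased (population) variance: np.var defaults to ddof=0,
        # matching unBiased = false in the lowering above.
        mean = np.mean(x, axis=axes, keepdims=True)
        var = np.var(x, axis=axes, keepdims=True)
        # Same order of operations as the emitted ops: sqrt(var + eps).
        return (x - mean) / np.sqrt(var + eps)

    x = np.random.rand(3, 5, 2, 2).astype(np.float32)
    y = mean_variance_normalization(x)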
From fb80dc1e010b62cb0f916cf026e4ea6b618f09ed Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Mon, 2 Jun 2025 15:52:18 +0530
Subject: [PATCH 3/5] Add MeanVarNorm lit tests

Signed-off-by: Zahid Wakeel
---
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 51 +++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index a336d78f55dd..63f3c9b4ccaa 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1595,6 +1595,57 @@ func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: func.func @test_meanvarnorm(
+func.func @test_meanvarnorm(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+  // CHECK: %[[VAL_2:.*]] = torch.constant.none
+  // CHECK: %[[VAL_3:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 3
+  // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_7:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_6]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_8:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_6]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e-09
+  // CHECK: %[[VAL_11:.*]] = torch.aten.add.Scalar %[[VAL_8]], %[[VAL_10]], %[[VAL_9]] : !torch.vtensor<[1,5,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.sqrt %[[VAL_11]] : !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[1,5,1,1],f32>
+  // CHECK: %[[VAL_13:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: %[[VAL_14:.*]] = torch.aten.div.Tensor %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: return %[[VAL_14]] : !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_meanvarnorm_axes(
+func.func @test_meanvarnorm_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+  // CHECK: %[[VAL_2:.*]] = torch.constant.none
+  // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+  // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_8:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
+  // CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [1 : si64, 3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
From c9fe057f18411e346bf1a8d73c7d267495d37aac Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Tue, 1 Jul 2025 12:35:01 +0530
Subject: [PATCH 4/5] Add type checks & allow MVN expansion by default

Signed-off-by: Zahid Wakeel
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 41 +++++++++++--------
 python/torch_mlir/extras/onnx_importer.py     |  4 +-
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 06f79a1b35db..52e92a705305 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1619,18 +1619,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.tensorResultType(resultType)) {
           return failure();
         }
+        if (!resultType.hasSizes() || !resultType.hasDtype()) {
+          return failure();
+        }
+        auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTy || !inputTy.hasSizes()) {
+          return failure();
+        }
+        int64_t inputRank = inputTy.getSizes().size();
+
         Location loc = binder.getLoc();
         Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
         Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
         Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
 
-        ArrayRef<int64_t> input_shape = resultType.getSizes();
-        SmallVector<int64_t> reduced_shape(input_shape);
+        ArrayRef<int64_t> output_shape = resultType.getSizes();
+        SmallVector<int64_t> reduced_shape(output_shape);
         for (int64_t i : axes) {
+          int64_t dim = Torch::toPositiveDim(i, inputRank);
+          if (!Torch::isValidDim(dim, inputRank)) {
+            return failure();
+          }
           reduced_shape[i] = 1;
         }
 
-        Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
+        Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
             resultType.getContext(), reduced_shape, resultType.getDtype());
         SmallVector<Value> cstAxes;
         for (int64_t i : axes) {
@@ -1642,29 +1655,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstAxes);
         Value mean = rewriter.create<Torch::AtenMeanDimOp>(
-            loc, meanOutTy, input, axes_list, keepDim, none);
-
+            loc, reducedOutTy, input, axes_list, keepDim, none);
         Value variance = rewriter.create<Torch::AtenVarDimOp>(
-            loc, meanOutTy, input, axes_list, unBiased, keepDim);
-
+            loc, reducedOutTy, input, axes_list, unBiased, keepDim);
         Value cstOne = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(1));
         Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
            loc, rewriter.getF64FloatAttr(1e-9));
         variance = rewriter.create<Torch::AtenAddScalarOp>(
-            loc, meanOutTy, variance, cstEps, cstOne);
-
-        Value sqrt =
-            rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
-
-        Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, reducedOutTy, variance, cstEps, cstOne);
+        Value sqrtVar =
+            rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
+        Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
             loc, resultType, input, mean, cstOne);
-
-        Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
-            loc, resultType, subValue, sqrt);
+        Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
+            loc, resultType, inputMinusMean, sqrtVar);
 
         rewriter.replaceOp(binder.op, meanVarNorm);
-
         return success();
       });
   patterns.onOp(
diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index defbb9bdda2a..9e5bc373ebd3 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -103,7 +103,9 @@ class Config:
     function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(
         default_factory=lambda: {
             # Default domain (ONNX built-in ops)
-            "": {}
+            "": {
+                "MeanVarianceNormalization",
+            }
         }
     )
 
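Note: the `Torch::toPositiveDim` / `Torch::isValidDim` checks added in
PATCH 4/5 follow the usual PyTorch convention for negative axes; roughly, in
Python terms (a sketch of the convention only, not the torch-mlir helpers
themselves):

    def to_positive_dim(dim, rank):
        # Negative axes count from the end: -1 is the last dimension.
        return dim + rank if dim < 0 else dim

    def is_valid_dim(dim, rank):
        return 0 <= dim < rank

    # For the rank-4 inputs in the tests, axes [-1, -3] normalize to [3, 1].
    axes = [to_positive_dim(a, 4) for a in (-1, -3)]
    assert all(is_valid_dim(a, 4) for a in axes) and axes == [3, 1]

At this point the loop still writes `reduced_shape[i] = 1` with the original,
possibly negative index; PATCH 5/5 below switches it to the normalized `dim`.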
From 801ba0d0a454d275d24b0f11a73219f98e0eb520 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Tue, 15 Jul 2025 22:52:44 +0530
Subject: [PATCH 5/5] Minor fix for negative axes

Add negative axes lit tests for MVN

Signed-off-by: Zahid Wakeel
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    |  2 +-
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 25 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 52e92a705305..e6be0304f64d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1641,7 +1641,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           if (!Torch::isValidDim(dim, inputRank)) {
             return failure();
           }
-          reduced_shape[i] = 1;
+          reduced_shape[dim] = 1;
         }
         Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
             resultType.getContext(), reduced_shape, resultType.getDtype());
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 63f3c9b4ccaa..8ff895bd71cc 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1646,6 +1646,31 @@ func.func @test_meanvarnorm_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch
 
 // -----
 
+// CHECK-LABEL: func.func @test_meanvarnorm_neg_axes(
+func.func @test_meanvarnorm_neg_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_0:.*]] = torch.constant.bool true
+  // CHECK: %[[VAL_1:.*]] = torch.constant.bool false
+  // CHECK: %[[VAL_2:.*]] = torch.constant.none
+  // CHECK: %[[VAL_3:.*]] = torch.constant.int -1
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -3
+  // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_8:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
+  // CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
+  // CHECK: }
+  %0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [-1 : si64, -3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
+  return %0 : !torch.vtensor<[3,5,2,2],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_not_2d
 func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>
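Note: the allowlist toggled in PATCH 2/5 and restored in PATCH 4/5 controls
whether the importer expands MeanVarianceNormalization into its ONNX
function-body decomposition before conversion. With the final state of the
series, a caller wanting the single onnx.MeanVarianceNormalization op that the
new pattern matches could drop it from the default-domain set; a hedged usage
sketch against the Config class shown in the diff (surrounding importer setup
elided):

    from torch_mlir.extras.onnx_importer import Config

    config = Config()
    # Keep MVN as one torch.operator instead of letting the importer
    # expand it into the ONNX function body during import.
    config.function_expansion_allowlists_by_domain[""].discard(
        "MeanVarianceNormalization"
    )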