From a24d531d1b21899ef65f5a6b9b542e8ceb2f78a9 Mon Sep 17 00:00:00 2001
From: sharavana20
Date: Mon, 26 May 2025 15:33:45 +0530
Subject: [PATCH] add the code for rms_norm op

Signed-off-by: sharavana20
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  26 +++++
 .../Transforms/AbstractInterpLibrary.cpp      |  18 +++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  80 +++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |   1 +
 projects/pt1/e2e_testing/xfail_sets.py        |  18 +++
 .../build_tools/abstract_interp_lib_gen.py    |  10 ++
 .../build_tools/torch_ods_gen.py              |   1 +
 .../test_suite/norm_like.py                   | 106 ++++++++++++++++++
 8 files changed, 260 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4f9bfc99a8ed..09e4eecb1a14 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7454,6 +7454,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   }];
 }
 
+def Torch_AtenRmsNormOp : Torch_Op<"aten.rms_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchListOfTorchIntType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalFloatType:$eps
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRmsNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenRmsNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index f1fa5d8a11e1..afeb0fd731bd 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7326,6 +7326,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rms_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12732,6 +12736,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rms_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<float>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index ca66881d6eab..9f684046eb00 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7485,6 +7485,85 @@ class DecomposeAtenNativeLayerNormOp
 };
 } // namespace
 
+// RMS normalization:
+// rms(x) = sqrt(eps + mean(x^2))
+// output = (x / rms(x)) * weight
+namespace {
+class DecomposeAtenRMSLayerNormOp : public OpRewritePattern<AtenRmsNormOp> {
+  using OpRewritePattern<AtenRmsNormOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenRmsNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+    auto input = op.getInput();
+    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasSizes() || !inputTy.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "Expected input to be a tensor with sizes and a dtype");
+
+    auto outputTy = dyn_cast<BaseTensorType>(op.getType());
+    if (!outputTy.hasDtype())
+      return rewriter.notifyMatchFailure(op, "output should have a dtype.");
+
+    int64_t inputRank = inputTy.getSizes().size();
+    Value normalizedShape = op.getNormalizedShape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    if (!getListConstructElements(normalizedShape,
+                                  normalizedShapeSizesTorchInt))
+      return rewriter.notifyMatchFailure(op,
+                                         "should have constant shape values.");
+
+    int64_t normalize_from_idx =
+        inputRank - normalizedShapeSizesTorchInt.size();
+    auto reduceDimInts =
+        llvm::to_vector<4>(llvm::seq<int64_t>(normalize_from_idx, inputRank));
+    auto sizeListType = ListType::get(IntType::get(context));
+
+    SmallVector<Value> reduceDimVals;
+    for (int64_t dim : reduceDimInts)
+      reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dim)));
+    Value reduceDimList =
+        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
+
+    auto inputShape = inputTy.getSizes();
+    SmallVector<int64_t> reducedShape(inputShape.begin(), inputShape.end());
+    for (int64_t i : reduceDimInts)
+      reducedShape[i] = 1;
+    auto reducedTy =
+        ValueTensorType::get(context, reducedShape, inputTy.getDtype());
+    // x^2
+    Value inputSquared = rewriter.create<AtenSquareOp>(loc, inputTy, input);
+    Value cstTrue = rewriter.create<ConstantBoolOp>(loc, true);
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    // mean(x^2)
+    Value mean = rewriter.create<AtenMeanDimOp>(loc, reducedTy, inputSquared,
+                                                reduceDimList, cstTrue, none);
+    // mean(x^2) + eps: Add eps if provided
+    if (!isa<Torch::NoneType>(op.getEps().getType())) {
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      mean = rewriter.create<AtenAddScalarOp>(loc, reducedTy, mean,
+                                              op.getEps(), one);
+    }
+    // rsqrt(mean(x^2) + eps)
+    Value invRMS = rewriter.create<AtenRsqrtOp>(loc, reducedTy, mean);
+    // rsqrt(mean(x^2) + eps) * x
+    Value normalized =
+        rewriter.create<AtenMulTensorOp>(loc, inputTy, input, invRMS);
+    // Optionally multiply by weight if provided
+    Value weight = op.getWeight();
+    if (!isa<Torch::NoneType>(weight.getType())) {
+      normalized =
+          rewriter.create<AtenMulTensorOp>(loc, outputTy, normalized, weight);
+    }
+    rewriter.replaceOp(op, normalized);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -12070,6 +12149,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRMSLayerNormOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index bea4b6c935b0..d4e96390a1d1 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -435,6 +435,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenRmsNormOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index d8ab65a06891..5da2e892bb78 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1473,6 +1473,10 @@
     "Rot90MultipleRotationsModule_basic",
     "Rot90NegativeEvenRotationsModule_basic",
     "Rot90NegativeOddRotationsModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
@@ -2331,6 +2335,10 @@
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
     "LeakyReluBackwardModule_basic",
     "LeakyReluBackwardStaticModule_basic",
     "LiftFreshCopyModule_basic",
@@ -3037,6 +3045,11 @@
     "NativeGroupNormBackwardModule_basic",
     "NativeGroupNormModule_basic",
     "NativeLayerNormDynamicModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
+    "RMSNormDynamicModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
@@ -4725,6 +4738,11 @@
     "ReshapeCollapseModule_basic",
     "ReshapeDynamicModule_basic",
     "ReshapeExpandModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
+    "RMSNormDynamicModule_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 67c1cd9bc986..769c1f650018 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -664,6 +664,9 @@ def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_gr
 def aten〇layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> List[int]:
     return upstream_shape_functions.unary(input)
 
+def aten〇rms_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, eps: Optional[float] = None) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
 def aten〇_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
     return upstream_shape_functions.unary(output)
 
@@ -3420,6 +3423,13 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap
     assert not is_integer_dtype(input_dtype)
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1]))
+def aten〇rms_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, eps: Optional[float] = None) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype
+
 @check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False))
 def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int:
     grad_output_rank, grad_output_dtype = grad_output_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 85880f585613..c7d74384d025 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -640,6 +640,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit("aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)")
     emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
     emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
     emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
index 60c4ee144dfa..4f65394557d4 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -635,6 +635,112 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
 
 
+# ==============================================================================
+class RMSNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 9, 1, 2, 4], torch.float32, True),
+            ([1, 2, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [1, 2, 4]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: RMSNormModule())
+def RMSNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 9, 1, 2, 4), tu.rand(1, 2, 4))
+
+
+class RMSNormWithoutEpsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 5, 2, 2, 3], torch.float32, True),
+            ([2, 2, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [2, 2, 3]
+        return torch.ops.aten.rms_norm(x, list, weight)
+
+
+@register_test_case(module_factory=lambda: RMSNormWithoutEpsModule())
+def RMSNormWithoutEpsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3))
+
+
+class RMSNormWithoutWeightModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 2, 3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        list = [4]
+        return torch.ops.aten.rms_norm(x, list, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: RMSNormWithoutWeightModule())
+def RMSNormWithoutWeightModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3, 4))
+
+
+class RMSNormAllNormalizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([5, 6, 3], torch.float32, True), ([5, 6, 3], torch.float32, True)]
+    )
+    def forward(self, x, weight):
+        list = [5, 6, 3]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.7)
+
+
+@register_test_case(module_factory=lambda: RMSNormAllNormalizeModule())
+def RMSNormAllNormalizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 6, 3), tu.rand(5, 6, 3))
+
+
+class RMSNormDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [2, 3, 4]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.8)
+
+
+@register_test_case(module_factory=lambda: RMSNormDynamicModule())
+def RMSNormDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3, 4), tu.rand(2, 3, 4))
+
+
 # ==============================================================================
 class RenormModuleFloat32(torch.nn.Module):
     def __init__(self):
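
For review purposes only (not part of the patch), here is a minimal eager-PyTorch sketch of the semantics that DecomposeAtenRMSLayerNormOp implements: reduce mean(x^2) over the trailing normalized dimensions, add eps if given, take rsqrt, scale x, and optionally apply the weight. The helper name rms_norm_reference and the sample shapes are illustrative assumptions, not identifiers from the patch.

import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
    # Reduce over the trailing len(normalized_shape) dims, mirroring the
    # reduce-dim list built in the decomposition.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean_sq = x.pow(2).mean(dim=dims, keepdim=True)  # mean(x^2)
    if eps is not None:
        mean_sq = mean_sq + eps                      # mean(x^2) + eps
    out = x * torch.rsqrt(mean_sq)                   # x * rsqrt(mean(x^2) + eps)
    if weight is not None:
        out = out * weight                           # optional elementwise scale
    return out

# Cross-check against the ATen op, using the shapes from RMSNormModule above.
x = torch.rand(8, 9, 1, 2, 4)
w = torch.rand(1, 2, 4)
assert torch.allclose(
    rms_norm_reference(x, [1, 2, 4], w, eps=0.5),
    torch.ops.aten.rms_norm(x, [1, 2, 4], w, 0.5),
    atol=1e-6,
)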