
Commit 089b217

[TORCH] Add support for aten.rms_norm Op (#4207)
- Decomposed the rms_norm op into ATen ops.
- The decomposition follows the formula: rms(x) = sqrt(eps + mean(x^2)).
- Added test cases to the e2e suite.

Signed-off-by: sharavana20 <sharavana.kumar@multicorewareinc.com>
1 parent 866786c commit 089b217
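For orientation, this is how the new op is invoked from PyTorch (a minimal sketch mirroring the e2e tests added below; it assumes a PyTorch build that exposes `aten::rms_norm`, and the shapes are illustrative):

import torch

x = torch.rand(8, 9, 1, 2, 4)
weight = torch.rand(1, 2, 4)      # optional scale, shaped like normalized_shape
normalized_shape = [1, 2, 4]      # trailing dims the RMS is computed over

# aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)
y = torch.ops.aten.rms_norm(x, normalized_shape, weight, 0.5)    # eps = 0.5
y_no_eps = torch.ops.aten.rms_norm(x, normalized_shape, weight)  # eps omitted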

File tree: 8 files changed, +260 -0 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -7502,6 +7502,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   }];
 }
 
+def Torch_AtenRmsNormOp : Torch_Op<"aten.rms_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchListOfTorchIntType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalFloatType:$eps
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRmsNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenRmsNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 18 additions & 0 deletions
@@ -7330,6 +7330,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rms_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12757,6 +12761,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rms_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<float>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 80 additions & 0 deletions
@@ -7582,6 +7582,85 @@ class DecomposeAtenNativeLayerNormOp
 };
 } // namespace
 
+// RMS normalization:
+//   rms(x) = sqrt(eps + mean(x^2))
+//   output = (x / rms(x)) * weight
+namespace {
+class DecomposeAtenRMSLayerNormOp : public OpRewritePattern<AtenRmsNormOp> {
+  using OpRewritePattern<AtenRmsNormOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenRmsNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+    auto input = op.getInput();
+    auto inputTy = dyn_cast<ValueTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasSizes() || !inputTy.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "Expected input to be a tensor with sizes and a dtype");
+
+    auto outputTy = dyn_cast<ValueTensorType>(op.getType());
+    if (!outputTy.hasDtype())
+      return rewriter.notifyMatchFailure(op, "output should have a dtype.");
+
+    int64_t inputRank = inputTy.getSizes().size();
+    Value normalizedShape = op.getNormalizedShape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    if (!getListConstructElements(normalizedShape,
+                                  normalizedShapeSizesTorchInt))
+      return rewriter.notifyMatchFailure(op,
+                                         "should have constant shape values.");
+
+    int64_t normalize_from_idx =
+        inputRank - normalizedShapeSizesTorchInt.size();
+    auto reduceDimInts =
+        llvm::to_vector<4>(llvm::seq<int64_t>(normalize_from_idx, inputRank));
+    auto sizeListType = ListType::get(IntType::get(context));
+
+    SmallVector<Value> reduceDimVals;
+    for (int64_t dim : reduceDimInts)
+      reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dim)));
+    Value reduceDimList =
+        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
+
+    auto inputShape = inputTy.getSizes();
+    SmallVector<int64_t> reducedShape(inputShape.begin(), inputShape.end());
+    for (int64_t i : reduceDimInts)
+      reducedShape[i] = 1;
+    auto reducedTy =
+        ValueTensorType::get(context, reducedShape, inputTy.getDtype());
+    // x^2
+    Value inputSquared = rewriter.create<AtenSquareOp>(loc, inputTy, input);
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+    // mean(x^2)
+    Value mean = rewriter.create<AtenMeanDimOp>(loc, reducedTy, inputSquared,
+                                                reduceDimList, cstTrue, none);
+    // mean(x^2) + eps: Add eps if provided
+    if (!isa<Torch::NoneType>(op.getEps().getType())) {
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      mean = rewriter.create<AtenAddScalarOp>(loc, reducedTy, mean, op.getEps(),
+                                              one);
+    }
+    // rsqrt(mean(x^2) + eps)
+    Value invRMS = rewriter.create<AtenRsqrtOp>(loc, reducedTy, mean);
+    // rsqrt(mean(x^2) + eps) * x
+    Value normalized =
+        rewriter.create<AtenMulTensorOp>(loc, inputTy, input, invRMS);
+    // Optionally multiply by weight if provided
+    Value weight = op.getWeight();
+    if (!isa<Torch::NoneType>(weight.getType())) {
+      normalized =
+          rewriter.create<AtenMulTensorOp>(loc, outputTy, normalized, weight);
+    }
+    rewriter.replaceOp(op, normalized);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -12218,6 +12297,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRMSLayerNormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenInstanceNormOp>();
   target.addIllegalOp<AtenLayerNormOp>();
   target.addIllegalOp<AtenNativeLayerNormOp>();
+  target.addIllegalOp<AtenRmsNormOp>();
   target.addIllegalOp<AtenGroupNormOp>();
   target.addIllegalOp<AtenNativeGroupNormOp>();
   target.addIllegalOp<AtenNativeBatchNormOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 18 additions & 0 deletions
@@ -1473,6 +1473,10 @@
     "Rot90MultipleRotationsModule_basic",
     "Rot90NegativeEvenRotationsModule_basic",
     "Rot90NegativeOddRotationsModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
@@ -2331,6 +2335,10 @@
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
     "LeakyReluBackwardModule_basic",
     "LeakyReluBackwardStaticModule_basic",
     "LiftFreshCopyModule_basic",
@@ -3043,6 +3051,11 @@
     "NativeGroupNormBackwardModule_basic",
     "NativeGroupNormModule_basic",
     "NativeLayerNormDynamicModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
+    "RMSNormDynamicModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
@@ -4748,6 +4761,11 @@
     "ReshapeCollapseModule_basic",
     "ReshapeDynamicModule_basic",
     "ReshapeExpandModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
+    "RMSNormDynamicModule_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 10 additions & 0 deletions
@@ -667,6 +667,9 @@ def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_gr
 def aten〇layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> List[int]:
     return upstream_shape_functions.unary(input)
 
+def aten〇rms_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, eps: Optional[float] = None) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
 def aten〇_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
     return upstream_shape_functions.unary(output)
 
@@ -3440,6 +3443,13 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap
     assert not is_integer_dtype(input_dtype)
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1]))
+def aten〇rms_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, eps: Optional[float] = None) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype
+
 @check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False))
 def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int:
     grad_output_rank, grad_output_dtype = grad_output_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -643,6 +643,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit("aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)")
     emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
     emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
     emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py

Lines changed: 106 additions & 0 deletions
@@ -635,6 +635,112 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
 
 
+# ==============================================================================
+class RMSNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 9, 1, 2, 4], torch.float32, True),
+            ([1, 2, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [1, 2, 4]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: RMSNormModule())
+def RMSNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 9, 1, 2, 4), tu.rand(1, 2, 4))
+
+
+class RMSNormWithoutEpsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 5, 2, 2, 3], torch.float32, True),
+            ([2, 2, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [2, 2, 3]
+        return torch.ops.aten.rms_norm(x, list, weight)
+
+
+@register_test_case(module_factory=lambda: RMSNormWithoutEpsModule())
+def RMSNormWithoutEpsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3))
+
+
+class RMSNormWithoutWeightModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 2, 3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        list = [4]
+        return torch.ops.aten.rms_norm(x, list, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: RMSNormWithoutWeightModule())
+def RMSNormWithoutWeightModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3, 4))
+
+
+class RMSNormAllNormalizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([5, 6, 3], torch.float32, True), ([5, 6, 3], torch.float32, True)]
+    )
+    def forward(self, x, weight):
+        list = [5, 6, 3]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.7)
+
+
+@register_test_case(module_factory=lambda: RMSNormAllNormalizeModule())
+def RMSNormAllNormalizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 6, 3), tu.rand(5, 6, 3))
+
+
+class RMSNormDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [2, 3, 4]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.8)
+
+
+@register_test_case(module_factory=lambda: RMSNormDynamicModule())
+def RMSNormDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3, 4), tu.rand(2, 3, 4))
+
+
 # ==============================================================================
 class RenormModuleFloat32(torch.nn.Module):
     def __init__(self):