From 1c8bb2ddb2316c6466d3c3f7a72ba3fab4da2565 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Mon, 2 Jun 2025 00:53:24 +0530
Subject: [PATCH 1/6] Register replication pad 1d in torch dialect spec

Signed-off-by: Zahid Wakeel
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 24 +++++++++++++++++++
 .../build_tools/torch_ods_gen.py              |  1 +
 2 files changed, 25 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index df3516a9e4a6..ebe4347a2aca 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10214,6 +10214,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
   }];
 }
 
+def Torch_AtenReplicationPad1dOp : Torch_Op<"aten.replication_pad1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::replication_pad1d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReplicationPad1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReplicationPad1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index a8799cc72522..6a173877b0b0 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -805,6 +805,7 @@ def emit_with_mutating_variants(key, **kwargs):
     # Misc tensor ops.
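
Note on the op being registered: aten::replication_pad1d pads the last dimension of a 2-d or 3-d tensor by repeating its edge values. A minimal eager-PyTorch check of those semantics (the shape and pad values here are illustrative, not taken from the patch):

    import torch

    x = torch.arange(4.0).reshape(1, 1, 4)           # [[[0., 1., 2., 3.]]]
    y = torch.ops.aten.replication_pad1d(x, [2, 1])  # pad_left=2, pad_right=1
    print(y)  # tensor([[[0., 0., 0., 1., 2., 3., 3.]]]) -- width 4 -> 4+2+1
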
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") From 5fbd750a62e10fe34fa036f645d27827fa90f7f5 Mon Sep 17 00:00:00 2001 From: Zahid Wakeel Date: Mon, 2 Jun 2025 00:53:59 +0530 Subject: [PATCH 2/6] Shape & dtype inference Signed-off-by: Zahid Wakeel --- .../Transforms/AbstractInterpLibrary.cpp | 29 +++++++++++++++++++ .../build_tools/abstract_interp_lib_gen.py | 9 ++++++ 2 files changed, 38 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 9d10867c095a..fc65f7f1653a 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10830,6 +10830,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.replication_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %false = torch.constant.bool false\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" @@ -10856,6 +10881,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index cee28fbc072f..50ea52abdba9 100644 --- 
From 05a4ccf103b34961e23da4d34cf842421b9e4f16 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Mon, 2 Jun 2025 00:55:04 +0530
Subject: [PATCH 3/6] Lower replication pad 1d to linalg backend

Signed-off-by: Zahid Wakeel
---
 .../TorchToLinalg/TensorConstructors.cpp      | 90 +++++++++++++++++++
 .../torch_mlir_e2e_test/test_suite/basic.py   | 20 +++++
 2 files changed, 110 insertions(+)

diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 02853b14072a..28c5c0e15459 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -116,6 +116,94 @@ class ConvertAtenConstantPadNdOp
 
 namespace {
 
+class ConvertAtenReplicationPad1dOp
+    : public OpConversionPattern<AtenReplicationPad1dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenReplicationPad1dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+
+    if (inputRank < 2)
+      return rewriter.notifyMatchFailure(op, "input rank must be at least 2");
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    if (padInts.size() != 2)
+      return rewriter.notifyMatchFailure(
+          op, "pad range must have exactly two values");
+
+    int64_t leftPad = padInts[0];
+    int64_t rightPad = padInts[1];
+
+    int64_t dimToPad = inputRank - 1;
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+    Value widthSize = inputShape[dimToPad];
+    Value widthMinusOne = rewriter.create<arith::SubIOp>(loc, widthSize, one);
+
+    // Build offset and size arrays for slicing
+    SmallVector<OpFoldResult> allOneStrides(inputRank,
+                                            rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> leftOffsets(inputRank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> rightOffsets(inputRank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
+    for (int i = 0; i < inputRank; ++i)
+      sizes[i] = (i == dimToPad) ? rewriter.getIndexAttr(1)
+                                 : getAsOpFoldResult(inputShape[i]);
+
+    rightOffsets[dimToPad] = getAsOpFoldResult(widthMinusOne);
+
+    // Extract leftmost and rightmost slices
+    Value leftSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, leftOffsets, sizes, allOneStrides);
+    Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, rightOffsets, sizes, allOneStrides);
+
+    // Create repeated tiles
+    SmallVector<Value> resultParts;
+
+    if (leftPad > 0) {
+      SmallVector<Value> leftTiles(leftPad, leftSlice);
+      Value leftConcat =
+          rewriter.create<tensor::ConcatOp>(loc, dimToPad, leftTiles);
+      resultParts.push_back(leftConcat);
+    }
+
+    resultParts.push_back(input);
+
+    if (rightPad > 0) {
+      SmallVector<Value> rightTiles(rightPad, rightSlice);
+      Value rightConcat =
+          rewriter.create<tensor::ConcatOp>(loc, dimToPad, rightTiles);
+      resultParts.push_back(rightConcat);
+    }
+
+    Value result =
+        rewriter.create<tensor::ConcatOp>(loc, dimToPad, resultParts);
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
 // Lower aten.replication_pad2d operator into a sequence of
 // tensor.extract_slice and tensor.concat operations.
 
@@ -621,6 +709,8 @@ void mlir::torch::torch_to_linalg::
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenConstantPadNdOp>();
   patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
+  target.addIllegalOp<AtenReplicationPad1dOp>();
+  patterns.add<ConvertAtenReplicationPad1dOp>(typeConverter, context);
   target.addIllegalOp<AtenReplicationPad2dOp>();
   patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
   target.addIllegalOp<AtenReflectionPad1dOp>();
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 1ad698db9cc1..9674539fdefc 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -685,6 +685,26 @@ def ReplicationPad2dModule_left0(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReplicationPad1dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.replication_pad1d(x, [3, 5])
+
+
+@register_test_case(module_factory=lambda: ReplicationPad1dModule())
+def ReplicationPad1dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 15, 20, low=-1))
+
+
 class ReplicationPad2dModule_right0_module(torch.nn.Module):
     def __init__(self):
         super().__init__()
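
The lowering above builds the padded result out of tensor.extract_slice and tensor.concat only: it takes a width-1 slice at each edge of the last dimension, repeats each slice pad-many times, and concatenates everything with the input in the middle. The same construction in eager-PyTorch terms (an illustrative equivalence check, not the actual codegen):

    import torch

    def replication_pad1d_by_concat(x, left, right):
        left_slice = x[..., :1]    # leftmost column of the last dim
        right_slice = x[..., -1:]  # rightmost column of the last dim
        parts = [left_slice] * left + [x] + [right_slice] * right
        return torch.cat(parts, dim=-1)

    x = torch.rand(1, 15, 20)
    assert torch.equal(replication_pad1d_by_concat(x, 3, 5),
                       torch.ops.aten.replication_pad1d(x, [3, 5]))
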
From 667098e231ec6ff071daaf960fffd6fd6c92f542 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Mon, 2 Jun 2025 01:42:27 +0530
Subject: [PATCH 4/6] Update AtenPadOp decomposition

Signed-off-by: Zahid Wakeel
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 25 +++++++++++--------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index bd15f9b6f1cb..16b8ee2ebca5 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8490,12 +8490,6 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
       }
     }
 
-    // we don't have support for 1-D replicate pad, so pass it as 2d if
-    // possible.
-    // TODO: add support for AtenReplicatePad1dOp and remove this.
-    if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
-      usefulPadIndexEnd = 4;
-
     // make a new list of padding ints if dimensionality reduction can be
     // performed
     if (usefulPadIndexEnd < padValues.size()) {
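
With the decomposition updated (see the switch below), a generic aten.pad in "replicate" mode with a single useful (low, high) pair routes to aten.replication_pad1d instead of being widened to the 2-d op. The two paths agree in eager PyTorch, which is easy to spot-check (assuming a 3-d input; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.rand(2, 4, 8)
    # F.pad is the entry point that becomes aten.pad in the IR.
    via_pad = F.pad(x, [3, 5], mode="replicate")
    via_1d = torch.ops.aten.replication_pad1d(x, [3, 5])
    assert torch.equal(via_pad, via_1d)
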
- if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4) - usefulPadIndexEnd = 4; - // make a new list of padding ints if dimensionality reduction can be // performed if (usefulPadIndexEnd < padValues.size()) { @@ -8533,11 +8527,20 @@ class DecomposeAtenPadOp : public OpRewritePattern { } if (mode == "replicate") { - // only support for replication pad 2d - if (numPadDims != 2) - return failure(); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), usefulPads); + switch (numPadDims) { + case 1: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + case 2: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + default: + return rewriter.notifyMatchFailure( + op, "unsupported number of dims for 'reflect' mode: " + + std::to_string(numPadDims)); + } return success(); } From 4a9bc26ed0e9a21e453578cda23ab6cca8bef914 Mon Sep 17 00:00:00 2001 From: Zahid Wakeel Date: Mon, 2 Jun 2025 02:09:33 +0530 Subject: [PATCH 5/6] Add tests & update xfail_sets Signed-off-by: Zahid Wakeel --- projects/pt1/e2e_testing/xfail_sets.py | 6 +++ .../torch_mlir_e2e_test/test_suite/basic.py | 20 -------- .../torch_mlir_e2e_test/test_suite/padding.py | 46 +++++++++++++++++++ 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c9df40559e3d..e7833fd9ac33 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -840,6 +840,8 @@ "ReflectionPad3dModuleRight_basic", "ReflectionPad3dModuleFront_basic", "ReflectionPad3dModuleBack_basic", + "ReplicationPad1dModule_2DInput_basic", + "ReplicationPad1dModule_3DInput_basic", "ReplicationPad2dModule_basic", "ReplicationPad2dModule_bottom0", "ReplicationPad2dModule_left0", @@ -3927,6 +3929,8 @@ "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "ReplicationPad1dModule_2DInput_basic", + "ReplicationPad1dModule_3DInput_basic", } ONNX_TOSA_CRASHING_SET = { @@ -4766,6 +4770,8 @@ "RMSNormWithoutWeightModule_basic", "RMSNormAllNormalizeModule_basic", "RMSNormDynamicModule_basic", + "ReplicationPad1dModule_2DInput_basic", + "ReplicationPad1dModule_3DInput_basic", "RollModule_basic", "RsubIntModule_noalpha_basic", "ScalarConstantTupleModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 9674539fdefc..1ad698db9cc1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -685,26 +685,6 @@ def ReplicationPad2dModule_left0(module, tu: TestUtils): # ============================================================================== -class ReplicationPad1dModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ] - ) - def forward(self, x): - return torch.ops.aten.replication_pad1d(x, [3, 5]) - - -@register_test_case(module_factory=lambda: ReplicationPad1dModule()) -def ReplicationPad1dModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1, 15, 20, low=-1)) - - class ReplicationPad2dModule_right0_module(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py 
From 691eb589ca63b278dedd7d5166e99d823a4ab7e0 Mon Sep 17 00:00:00 2001
From: Zahid Wakeel
Date: Sat, 14 Jun 2025 16:29:06 +0530
Subject: [PATCH 6/6] Aggregate slices to concat together

Signed-off-by: Zahid Wakeel
---
 .../TorchToLinalg/TensorConstructors.cpp      | 19 ++++-------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 28c5c0e15459..132daafa5afb 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -172,24 +172,13 @@ class ConvertAtenReplicationPad1dOp
     Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
         loc, input, rightOffsets, sizes, allOneStrides);
 
-    // Create repeated tiles
+    // Aggregate slices to concat together
     SmallVector<Value> resultParts;
+    resultParts.reserve(leftPad + rightPad + 1);
 
-    if (leftPad > 0) {
-      SmallVector<Value> leftTiles(leftPad, leftSlice);
-      Value leftConcat =
-          rewriter.create<tensor::ConcatOp>(loc, dimToPad, leftTiles);
-      resultParts.push_back(leftConcat);
-    }
-
+    resultParts.append(leftPad, leftSlice);
     resultParts.push_back(input);
-
-    if (rightPad > 0) {
-      SmallVector<Value> rightTiles(rightPad, rightSlice);
-      Value rightConcat =
-          rewriter.create<tensor::ConcatOp>(loc, dimToPad, rightTiles);
-      resultParts.push_back(rightConcat);
-    }
+    resultParts.append(rightPad, rightSlice);
 
     Value result =
         rewriter.create<tensor::ConcatOp>(loc, dimToPad, resultParts);
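
The final patch's refactor is behavior-preserving: instead of concatenating the left tiles and the right tiles separately and then gluing three pieces together, all leftPad + 1 + rightPad pieces feed a single tensor.concat. In the eager sketch from patch 3 the same flattening looks like this (illustrative values):

    import torch

    x = torch.rand(2, 6)
    left, right = x[..., :1], x[..., -1:]
    # before: nested concats
    nested = torch.cat(
        [torch.cat([left] * 3, dim=-1), x, torch.cat([right] * 5, dim=-1)], dim=-1)
    # after: one flat concat over all the parts
    flat = torch.cat([left] * 3 + [x] + [right] * 5, dim=-1)
    assert torch.equal(nested, flat)
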