diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index df3516a9e4a6..ebe4347a2aca 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10214,6 +10214,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
   }];
 }
 
+def Torch_AtenReplicationPad1dOp : Torch_Op<"aten.replication_pad1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::replication_pad1d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReplicationPad1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReplicationPad1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 02853b14072a..132daafa5afb 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -116,6 +116,83 @@ class ConvertAtenConstantPadNdOp
 
 namespace {
 
+class ConvertAtenReplicationPad1dOp
+    : public OpConversionPattern<AtenReplicationPad1dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenReplicationPad1dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+
+    if (inputRank < 2)
+      return rewriter.notifyMatchFailure(op, "input rank must be at least 2");
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    if (padInts.size() != 2)
+      return rewriter.notifyMatchFailure(
+          op, "pad range must have exactly two values");
+
+    int64_t leftPad = padInts[0];
+    int64_t rightPad = padInts[1];
+
+    int64_t dimToPad = inputRank - 1;
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+    Value widthSize = inputShape[dimToPad];
+    Value widthMinusOne = rewriter.create<arith::SubIOp>(loc, widthSize, one);
+
+    // Build offset and size arrays for slicing
+    SmallVector<OpFoldResult> allOneStrides(inputRank,
+                                            rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> leftOffsets(inputRank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> rightOffsets(inputRank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
+    for (int i = 0; i < inputRank; ++i)
+      sizes[i] = (i == dimToPad) ? rewriter.getIndexAttr(1)
+                                 : getAsOpFoldResult(inputShape[i]);
+
+    rightOffsets[dimToPad] = getAsOpFoldResult(widthMinusOne);
+
+    // Extract leftmost and rightmost slices
+    Value leftSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, leftOffsets, sizes, allOneStrides);
+    Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, rightOffsets, sizes, allOneStrides);
+
+    // Aggregate slices to concat together
+    SmallVector<Value> resultParts;
+    resultParts.reserve(leftPad + rightPad + 1);
+
+    resultParts.append(leftPad, leftSlice);
+    resultParts.push_back(input);
+    resultParts.append(rightPad, rightSlice);
+
+    Value result =
+        rewriter.create<tensor::ConcatOp>(loc, dimToPad, resultParts);
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+
 // Lower aten.replication_pad2d operator into a sequence of
 // tensor.extract_slice and tensor.concat operations.
 
@@ -621,6 +698,8 @@ void mlir::torch::torch_to_linalg::
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenConstantPadNdOp>();
   patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
+  target.addIllegalOp<AtenReplicationPad1dOp>();
+  patterns.add<ConvertAtenReplicationPad1dOp>(typeConverter, context);
   target.addIllegalOp<AtenReplicationPad2dOp>();
   patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 9d10867c095a..fc65f7f1653a 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10830,6 +10830,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    } : (!torch.int, !torch.bool) -> ()\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.replication_pad1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
+"    %none = torch.constant.none\n"
+"    %str_0 = torch.constant.str \"AssertionError: \"\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %false = torch.constant.bool false\n"
 "    %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
@@ -10856,6 +10881,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %4 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
!torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index bd15f9b6f1cb..16b8ee2ebca5 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8490,12 +8490,6 @@ class DecomposeAtenPadOp : public OpRewritePattern { } } - // we don't have support for 1-D replicate pad, so pass it as 2d if - // possible. - // TODO: add support for AtenReplicatePad1dOp and remove this. - if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4) - usefulPadIndexEnd = 4; - // make a new list of padding ints if dimensionality reduction can be // performed if (usefulPadIndexEnd < padValues.size()) { @@ -8533,11 +8527,20 @@ class DecomposeAtenPadOp : public OpRewritePattern { } if (mode == "replicate") { - // only support for replication pad 2d - if (numPadDims != 2) - return failure(); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), usefulPads); + switch (numPadDims) { + case 1: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + case 2: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + default: + return rewriter.notifyMatchFailure( + op, "unsupported number of dims for 'reflect' mode: " + + std::to_string(numPadDims)); + } return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c9df40559e3d..e7833fd9ac33 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -840,6 +840,8 @@ "ReflectionPad3dModuleRight_basic", "ReflectionPad3dModuleFront_basic", "ReflectionPad3dModuleBack_basic", + "ReplicationPad1dModule_2DInput_basic", + "ReplicationPad1dModule_3DInput_basic", "ReplicationPad2dModule_basic", "ReplicationPad2dModule_bottom0", "ReplicationPad2dModule_left0", @@ -3927,6 +3929,8 @@ "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "ReplicationPad1dModule_2DInput_basic", + "ReplicationPad1dModule_3DInput_basic", } ONNX_TOSA_CRASHING_SET = { @@ -4766,6 +4770,8 @@ "RMSNormWithoutWeightModule_basic", "RMSNormAllNormalizeModule_basic", "RMSNormDynamicModule_basic", + "ReplicationPad1dModule_2DInput_basic", + "ReplicationPad1dModule_3DInput_basic", "RollModule_basic", "RsubIntModule_noalpha_basic", "ScalarConstantTupleModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index cee28fbc072f..50ea52abdba9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2250,11 +2250,20 @@ def pad_shape_fn(input: List[int], pad: List[int], validate_pad : bool = False): def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float = 0) -> List[int]: return pad_shape_fn(self, pad) +def aten〇replication_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + assert 
len(padding) == 2, 'padding size expected to be 2' + return pad_shape_fn(self, padding) + def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 assert len(padding) == 4, 'padding size expected to be 4' return pad_shape_fn(self, padding) +def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index a8799cc72522..6a173877b0b0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -805,6 +805,7 @@ def emit_with_mutating_variants(key, **kwargs): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index b9c58551d657..29578a59bc65 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -13,6 +13,52 @@ # ============================================================================== +class ReplicationPad1dModule_3DInput(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.replication_pad1d(x, [3, 5]) + + +@register_test_case(module_factory=lambda: ReplicationPad1dModule_3DInput()) +def ReplicationPad1dModule_3DInput_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 15, 20, low=-1)) + + +# ============================================================================== + + +class ReplicationPad1dModule_2DInput(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.replication_pad1d(x, [2, 3]) + + +@register_test_case(module_factory=lambda: ReplicationPad1dModule_2DInput()) +def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils): + module.forward(tu.rand(7, 12, low=-1)) + + +# ============================================================================== + + class ReflectionPad2dModule(torch.nn.Module): def __init__(self): super().__init__()