Skip to content

Commit 3850332

Browse files
authored
[TorchToLinalg] Support AtenReplicationPad1d with lowering to linalg backend (#4217)
This PR takes care of #4216. - Add `AtenReplicationPad1d` support in `torch` dialect. - Lower to **linalg** backend: - `tensor.extract_slice` ops for left & right slices - `tensor.concat` ops to **expand** each slice into right shape and join all slices in order - Update `AtenPadOp` decomposition to make use of **1d** replication pad instead of using 2d variant. - Add tests & update **failing** cases in `xfail` sets. --------- Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 089b217 commit 3850332

File tree

8 files changed

+208
-11
lines changed

8 files changed

+208
-11
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10214,6 +10214,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
1021410214
}];
1021510215
}
1021610216

10217+
// Auto-generated from the `aten::replication_pad1d : (Tensor, int[]) -> (Tensor)`
// registration in `torch_ods_gen.py` — regenerate rather than editing by hand.
def Torch_AtenReplicationPad1dOp : Torch_Op<"aten.replication_pad1d", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::replication_pad1d : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$padding
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenReplicationPad1dOp::parse(OpAsmParser &parser, OperationState &result) {
      // 2 operands (self, padding), 1 result.
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenReplicationPad1dOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}
10240+
1021710241
def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
1021810242
AllowsTypeRefinement,
1021910243
HasValueSemantics,

lib/Conversion/TorchToLinalg/TensorConstructors.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,83 @@ class ConvertAtenConstantPadNdOp
116116

117117
namespace {
118118

119+
class ConvertAtenReplicationPad1dOp
120+
: public OpConversionPattern<AtenReplicationPad1dOp> {
121+
public:
122+
using OpConversionPattern::OpConversionPattern;
123+
124+
LogicalResult
125+
matchAndRewrite(AtenReplicationPad1dOp op, OpAdaptor adaptor,
126+
ConversionPatternRewriter &rewriter) const override {
127+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
128+
return failure();
129+
130+
Location loc = op.getLoc();
131+
Value input = adaptor.getSelf();
132+
auto inputType = llvm::cast<RankedTensorType>(input.getType());
133+
int64_t inputRank = inputType.getRank();
134+
135+
if (inputRank < 2)
136+
return rewriter.notifyMatchFailure(op, "input rank must be at least 2");
137+
138+
SmallVector<int64_t> padInts;
139+
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
140+
return rewriter.notifyMatchFailure(
141+
op, "only support constant int pad ranges");
142+
143+
if (padInts.size() != 2)
144+
return rewriter.notifyMatchFailure(
145+
op, "pad range must have exactly two values");
146+
147+
int64_t leftPad = padInts[0];
148+
int64_t rightPad = padInts[1];
149+
150+
int64_t dimToPad = inputRank - 1;
151+
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
152+
153+
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
154+
Value widthSize = inputShape[dimToPad];
155+
Value widthMinusOne = rewriter.create<arith::SubIOp>(loc, widthSize, one);
156+
157+
// Build offset and size arrays for slicing
158+
SmallVector<OpFoldResult> allOneStrides(inputRank,
159+
rewriter.getIndexAttr(1));
160+
SmallVector<OpFoldResult> leftOffsets(inputRank, rewriter.getIndexAttr(0));
161+
SmallVector<OpFoldResult> rightOffsets(inputRank, rewriter.getIndexAttr(0));
162+
SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
163+
for (int i = 0; i < inputRank; ++i)
164+
sizes[i] = (i == dimToPad) ? rewriter.getIndexAttr(1)
165+
: getAsOpFoldResult(inputShape[i]);
166+
167+
rightOffsets[dimToPad] = getAsOpFoldResult(widthMinusOne);
168+
169+
// Extract leftmost and rightmost slices
170+
Value leftSlice = rewriter.create<tensor::ExtractSliceOp>(
171+
loc, input, leftOffsets, sizes, allOneStrides);
172+
Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
173+
loc, input, rightOffsets, sizes, allOneStrides);
174+
175+
// Aggregate slices to concat together
176+
SmallVector<Value> resultParts;
177+
resultParts.reserve(leftPad + rightPad + 1);
178+
179+
resultParts.append(leftPad, leftSlice);
180+
resultParts.push_back(input);
181+
resultParts.append(rightPad, rightSlice);
182+
183+
Value result =
184+
rewriter.create<tensor::ConcatOp>(loc, dimToPad, resultParts);
185+
Type resultType = getTypeConverter()->convertType(op.getType());
186+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
187+
188+
return success();
189+
}
190+
};
191+
192+
} // namespace
193+
194+
namespace {
195+
119196
// Lower aten.replication_pad2d operator into a sequence of
120197
// tensor.extract_slice and tensor.concat operations.
121198

@@ -621,6 +698,8 @@ void mlir::torch::torch_to_linalg::
621698
MLIRContext *context = patterns.getContext();
622699
target.addIllegalOp<AtenReplicationPad2dOp>();
623700
patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
701+
target.addIllegalOp<AtenReplicationPad1dOp>();
702+
patterns.add<ConvertAtenReplicationPad1dOp>(typeConverter, context);
624703
target.addIllegalOp<AtenConstantPadNdOp>();
625704
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
626705
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10830,6 +10830,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1083010830
" } : (!torch.int, !torch.bool) -> ()\n"
1083110831
" return %arg0 : !torch.list<int>\n"
1083210832
" }\n"
10833+
" func.func @\"__torch_mlir_shape_fn.aten.replication_pad1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
10834+
" %false = torch.constant.bool false\n"
10835+
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
10836+
" %none = torch.constant.none\n"
10837+
" %str_0 = torch.constant.str \"AssertionError: \"\n"
10838+
" %int2 = torch.constant.int 2\n"
10839+
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
10840+
" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
10841+
" torch.prim.If %1 -> () {\n"
10842+
" torch.prim.If.yield\n"
10843+
" } else {\n"
10844+
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
10845+
" torch.prim.If.yield\n"
10846+
" }\n"
10847+
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
10848+
" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n"
10849+
" torch.prim.If %3 -> () {\n"
10850+
" torch.prim.If.yield\n"
10851+
" } else {\n"
10852+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
10853+
" torch.prim.If.yield\n"
10854+
" }\n"
10855+
" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
10856+
" return %4 : !torch.list<int>\n"
10857+
" }\n"
1083310858
" func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
1083410859
" %false = torch.constant.bool false\n"
1083510860
" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n"
@@ -10856,6 +10881,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1085610881
" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
1085710882
" return %4 : !torch.list<int>\n"
1085810883
" }\n"
10884+
" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
10885+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
10886+
" return %0#1 : !torch.int\n"
10887+
" }\n"
1085910888
" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
1086010889
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1086110890
" return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8490,12 +8490,6 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
84908490
}
84918491
}
84928492

8493-
// we don't have support for 1-D replicate pad, so pass it as 2d if
8494-
// possible.
8495-
// TODO: add support for AtenReplicatePad1dOp and remove this.
8496-
if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
8497-
usefulPadIndexEnd = 4;
8498-
84998493
// make a new list of padding ints if dimensionality reduction can be
85008494
// performed
85018495
if (usefulPadIndexEnd < padValues.size()) {
@@ -8533,11 +8527,20 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
85338527
}
85348528

85358529
if (mode == "replicate") {
8536-
// only support for replication pad 2d
8537-
if (numPadDims != 2)
8538-
return failure();
8539-
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
8540-
op, op.getType(), op.getSelf(), usefulPads);
8530+
switch (numPadDims) {
8531+
case 1:
8532+
rewriter.replaceOpWithNewOp<AtenReplicationPad1dOp>(
8533+
op, op.getType(), op.getSelf(), usefulPads);
8534+
break;
8535+
case 2:
8536+
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
8537+
op, op.getType(), op.getSelf(), usefulPads);
8538+
break;
8539+
default:
8540+
return rewriter.notifyMatchFailure(
8541+
op, "unsupported number of dims for 'replicate' mode: " +
8542+
std::to_string(numPadDims));
8543+
}
85418544
return success();
85428545
}
85438546

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,8 @@
840840
"ReflectionPad3dModuleRight_basic",
841841
"ReflectionPad3dModuleFront_basic",
842842
"ReflectionPad3dModuleBack_basic",
843+
"ReplicationPad1dModule_2DInput_basic",
844+
"ReplicationPad1dModule_3DInput_basic",
843845
"ReplicationPad2dModule_basic",
844846
"ReplicationPad2dModule_bottom0",
845847
"ReplicationPad2dModule_left0",
@@ -3927,6 +3929,8 @@
39273929
"IndexPutImpl1DFloatNonAccumulateModule_basic",
39283930
"IndexPutImpl1DIntNonAccumulateModule_basic",
39293931
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
3932+
"ReplicationPad1dModule_2DInput_basic",
3933+
"ReplicationPad1dModule_3DInput_basic",
39303934
}
39313935

39323936
ONNX_TOSA_CRASHING_SET = {
@@ -4766,6 +4770,8 @@
47664770
"RMSNormWithoutWeightModule_basic",
47674771
"RMSNormAllNormalizeModule_basic",
47684772
"RMSNormDynamicModule_basic",
4773+
"ReplicationPad1dModule_2DInput_basic",
4774+
"ReplicationPad1dModule_3DInput_basic",
47694775
"RollModule_basic",
47704776
"RsubIntModule_noalpha_basic",
47714777
"ScalarConstantTupleModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,11 +2250,20 @@ def pad_shape_fn(input: List[int], pad: List[int], validate_pad : bool = False):
22502250
def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float = 0) -> List[int]:
22512251
return pad_shape_fn(self, pad)
22522252

2253+
def aten〇replication_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]:
    # Input must have rank >= 2; `padding` is (left, right) for the last dim.
    assert len(self) >= 2
    assert len(padding) == 2, 'padding size expected to be 2'
    # pad_shape_fn grows the trailing dims of `self` by the paired pad amounts.
    return pad_shape_fn(self, padding)
2257+
22532258
def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
22542259
assert len(self) >= 2
22552260
assert len(padding) == 4, 'padding size expected to be 4'
22562261
return pad_shape_fn(self, padding)
22572262

2263+
def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
    # Replication padding only repeats existing elements, so the result
    # keeps the input tensor's dtype regardless of `padding`.
    return self_rank_dtype[1]
2266+
22582267
def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
22592268
self_rank, self_dtype = self_rank_dtype
22602269
return self_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,7 @@ def emit_with_mutating_variants(key, **kwargs):
805805

806806
# Misc tensor ops.
807807
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
808+
emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)")
808809
emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)")
809810
emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
810811
emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,52 @@
1313
# ==============================================================================
1414

1515

16+
class ReplicationPad1dModule_3DInput(torch.nn.Module):
    # E2E test fixture: replication_pad1d on a fully-dynamic 3-D float32
    # input with asymmetric padding (3 left, 5 right).
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),  # dynamic 3-D float32 input
        ]
    )
    def forward(self, x):
        return torch.ops.aten.replication_pad1d(x, [3, 5])
29+
30+
31+
@register_test_case(module_factory=lambda: ReplicationPad1dModule_3DInput())
def ReplicationPad1dModule_3DInput_basic(module, tu: TestUtils):
    # low=-1 draws from a range including negatives, so boundary replication
    # is exercised on non-positive values too.
    module.forward(tu.rand(1, 15, 20, low=-1))
34+
35+
36+
# ==============================================================================
37+
38+
39+
class ReplicationPad1dModule_2DInput(torch.nn.Module):
    # E2E test fixture: replication_pad1d on a fully-dynamic 2-D float32
    # input with asymmetric padding (2 left, 3 right).
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),  # dynamic 2-D float32 input
        ]
    )
    def forward(self, x):
        return torch.ops.aten.replication_pad1d(x, [2, 3])
52+
53+
54+
@register_test_case(module_factory=lambda: ReplicationPad1dModule_2DInput())
def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils):
    # low=-1 draws from a range including negatives, so boundary replication
    # is exercised on non-positive values too.
    module.forward(tu.rand(7, 12, low=-1))
57+
58+
59+
# ==============================================================================
60+
61+
1662
class ReflectionPad2dModule(torch.nn.Module):
1763
def __init__(self):
1864
super().__init__()

0 commit comments

Comments
 (0)