
Commit e44ea22

Revert "support aten._trilinear and improve einsum decomposition (#3784)"
This reverts commit 9c1e3b8.
1 parent 3b25ba3 commit e44ea22

8 files changed: +14 -553 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 0 additions & 30 deletions
@@ -14248,36 +14248,6 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
   }];
 }
 
-def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$i1,
-    AnyTorchTensorType:$i2,
-    AnyTorchTensorType:$i3,
-    AnyTorchListOfTorchIntType:$expand1,
-    AnyTorchListOfTorchIntType:$expand2,
-    AnyTorchListOfTorchIntType:$expand3,
-    AnyTorchListOfTorchIntType:$sumdim,
-    Torch_IntType:$unroll_dim
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 8, 1);
-    }
-    void Aten_TrilinearOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 8, 1);
-    }
-  }];
-}
-
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
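
For reference, the hunk above deregisters the generated Torch dialect op for aten::_trilinear, whose schema takes three tensors, four int lists, and one int (eight operands, one result, matching parseDefaultTorchOp(parser, result, 8, 1)). A minimal sketch of that calling convention from Python, with made-up shapes and index lists, assuming the op is reachable as torch.ops.aten._trilinear:

import torch

i1 = torch.randn(3, 4)     # rank 2; one expand dim brings it to rank 3
i2 = torch.randn(3, 4, 5)  # rank 3; used as-is
i3 = torch.randn(3, 5)     # rank 2; one expand dim brings it to rank 3
out = torch.ops.aten._trilinear(i1, i2, i3,
                                [2],  # expand1: unsqueeze i1 at dim 2
                                [],   # expand2: no unsqueeze for i2
                                [1],  # expand3: unsqueeze i3 at dim 1
                                [2],  # sumdim: reduce the shared last dim
                                1)    # unroll_dim
print(out.shape)  # expected: torch.Size([3, 4])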

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 0 additions & 115 deletions
@@ -8864,112 +8864,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
-" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.list<int> {\n"
-" %int3 = torch.constant.int 3\n"
-" %int-1 = torch.constant.int -1\n"
-" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n"
-" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n"
-" %true = torch.constant.bool true\n"
-" %none = torch.constant.none\n"
-" %str_1 = torch.constant.str \"AssertionError: \"\n"
-" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n"
-" %false = torch.constant.bool false\n"
-" %int0 = torch.constant.int 0\n"
-" %int1 = torch.constant.int 1\n"
-" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
-" %1 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
-" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
-" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
-" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %23 : !torch.bool\n"
-" } else {\n"
-" torch.prim.If.yield %false : !torch.bool\n"
-" }\n"
-" torch.prim.If %4 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
-" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n"
-" %25 = torch.aten.add.str %str_1, %24 : !torch.str, !torch.str -> !torch.str\n"
-" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
-" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
-" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
-" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
-" torch.prim.Loop %int3, %true, init() {\n"
-" ^bb0(%arg8: !torch.int):\n"
-" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
-" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
-" %25 = torch.aten.len.t %24 : !torch.list<int> -> !torch.int\n"
-" %26 = torch.aten.len.t %23 : !torch.list<int> -> !torch.int\n"
-" torch.prim.Loop %26, %true, init() {\n"
-" ^bb0(%arg9: !torch.int):\n"
-" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list<int>, !torch.int -> !torch.int\n"
-" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If %28 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
-" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list<int> -> !torch.str\n"
-" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n"
-" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
-" torch.aten.insert.t %29, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
-" torch.prim.Loop.condition %true, iter()\n"
-" } : (!torch.int, !torch.bool) -> ()\n"
-" torch.prim.Loop.condition %true, iter()\n"
-" } : (!torch.int, !torch.bool) -> ()\n"
-" %10 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n"
-" %11 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
-" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n"
-" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
-" %23 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
-" %24 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
-" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %25 : !torch.bool\n"
-" } else {\n"
-" torch.prim.If.yield %false : !torch.bool\n"
-" }\n"
-" torch.prim.If %13 -> () {\n"
-" torch.prim.If.yield\n"
-" } else {\n"
-" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-" torch.prim.If.yield\n"
-" }\n"
-" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
-" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list<bool>\n"
-" %16 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
-" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list<bool>, !torch.int) -> !torch.list<bool> \n"
-" %18 = torch.aten.len.t %arg6 : !torch.list<int> -> !torch.int\n"
-" torch.prim.Loop %18, %true, init() {\n"
-" ^bb0(%arg8: !torch.int):\n"
-" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list<int>, !torch.int -> !torch.int\n"
-" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list<bool>, !torch.int, !torch.bool -> !torch.list<bool>\n"
-" torch.prim.Loop.condition %true, iter()\n"
-" } : (!torch.int, !torch.bool) -> ()\n"
-" %19 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
-" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
-" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
-" %22 = torch.prim.Loop %21, %true, init(%14) {\n"
-" ^bb0(%arg8: !torch.int, %arg9: !torch.list<int>):\n"
-" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
-" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list<bool>, !torch.int -> !torch.bool\n"
-" %25 = torch.prim.If %24 -> (!torch.list<int>) {\n"
-" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
-" torch.prim.If.yield %26 : !torch.list<int>\n"
-" } else {\n"
-" torch.prim.If.yield %arg9 : !torch.list<int>\n"
-" }\n"
-" torch.prim.Loop.condition %true, iter(%25 : !torch.list<int>)\n"
-" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
-" return %22 : !torch.list<int>\n"
-" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
 " %int-1 = torch.constant.int -1\n"
 " %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
@@ -15400,15 +15294,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 " return %4 : !torch.int\n"
 " }\n"
-" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
-" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
-" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-" return %5 : !torch.int\n"
-" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
 " %true = torch.constant.bool true\n"
 " %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 9 additions & 152 deletions
@@ -9,7 +9,6 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -400,9 +399,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
   auto inputType = cast<ValueTensorType>(input.getType());
   auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
                    reduceDimsLength;
-  SmallVector<OpFoldResult> inputShapeTensor;
+  SmallVector<Value> inputShapeTensor;
   for (auto i = 0; i < inputRank; ++i) {
-    inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
+    inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
         loc, input,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(i))));
@@ -413,23 +412,13 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
       rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
   auto dimOffset = 0;
 
-  auto materializeIntFold = [&](OpFoldResult thing) {
-    if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
-      Value result = rewriter.create<Torch::ConstantIntOp>(
-          loc, cast<mlir::IntegerAttr>(attr));
-      return result;
-    }
-    return cast<mlir::Value>(thing);
-  };
-
   auto appendDims = [&](int64_t dimLength) {
-    OpFoldResult prod = getAsOpFoldResult(constOne);
+    Value prod = constOne;
     for (auto i = 0; i < dimLength; ++i) {
-      prod = rewriter.createOrFold<AtenMulIntOp>(
-          loc, materializeIntFold(prod),
-          materializeIntFold(inputShapeTensor[i + dimOffset]));
+      prod = rewriter.create<AtenMulIntOp>(loc, prod,
+                                           inputShapeTensor[i + dimOffset]);
     }
-    outShapeTensor.emplace_back(materializeIntFold(prod));
+    outShapeTensor.emplace_back(prod);
     dimOffset += dimLength;
   };
 
@@ -581,32 +570,21 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
   Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                          : rhsType.getOptionalDtype();
 
-  auto materializeIntFold = [&](OpFoldResult thing) {
-    if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
-      Value result = rewriter.create<Torch::ConstantIntOp>(
-          loc, cast<mlir::IntegerAttr>(attr));
-      return result;
-    }
-    return cast<mlir::Value>(thing);
-  };
-
   llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
   for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
     char d = lhsTokens[idx];
-    OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
+    lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
         loc, lhs,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(idx)));
-    lhsDimShapeMap[d] = materializeIntFold(lhsFold);
   }
   llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
   for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
     char d = rhsTokens[idx];
-    OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
+    rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
         loc, rhs,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(idx)));
-    rhsDimShapeMap[d] = materializeIntFold(rhsFold);
   }
 
   // parse batch, contracting, other, reduce dims of lhs and rhs
@@ -626,9 +604,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
     bool lhsContains = lhsDimShapeMap.count(d) > 0;
     bool rhsContains = rhsDimShapeMap.count(d) > 0;
     if (lhsContains && rhsContains) {
-      OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
+      outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
           loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
-      outDimShapeMap[d] = materializeIntFold(out);
     } else if (lhsContains) {
       outDimShapeMap[d] = lhsDimShapeMap[d];
     } else if (rhsContains) {
@@ -1996,125 +1973,6 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
 };
 } // namespace
 
-namespace {
-// Trilinear einstein sum, decomposed to:
-// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
-// .sum(sumdim)
-// The unrollDim operand does not impact the output of the operation, so
-// it is ignored.
-
-class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_TrilinearOp op,
-                                PatternRewriter &rewriter) const override {
-
-    Location loc = op.getLoc();
-
-    Value input1 = op.getI1();
-    Value input2 = op.getI2();
-    Value input3 = op.getI3();
-
-    // Expansions
-    SmallVector<int64_t> expand1;
-    SmallVector<int64_t> expand2;
-    SmallVector<int64_t> expand3;
-    if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
-      return rewriter.notifyMatchFailure(op, "expand1 should be constant");
-    }
-    if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
-      return rewriter.notifyMatchFailure(op, "expand2 should be constant");
-    }
-    if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
-      return rewriter.notifyMatchFailure(op, "expand3 should be constant");
-    }
-
-    SmallVector<int64_t> sumDim;
-    if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
-      return rewriter.notifyMatchFailure(op, "sumDim should be constant");
-    }
-
-    // Check if there are any dimensions that intersect between expand1,
-    // expand2, and expand3.
-    int64_t totalDims =
-        cast<BaseTensorType>(input1.getType()).getSizes().size() +
-        expand1.size();
-    if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
-      // pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
-      // TODO: Remove warning when issue gets resolved.
-      op->emitWarning("aten::_trilinear implementation in this case is "
-                      "non-functional (returns an empty dimension). We will "
-                      "intentionally deviate from this behavior.");
-    }
-
-    // Apply unsqueeze to respective input tensors at the specified dimensions
-    SmallVector<int64_t> sortedExpand1 = expand1;
-    std::sort(sortedExpand1.begin(), sortedExpand1.end());
-    for (auto expand : sortedExpand1) {
-      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(expand));
-      input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
-    }
-    SmallVector<int64_t> sortedExpand2 = expand2;
-    std::sort(sortedExpand2.begin(), sortedExpand2.end());
-    for (auto expand : sortedExpand2) {
-      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(expand));
-      input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
-    }
-    SmallVector<int64_t> sortedExpand3 = expand3;
-    std::sort(sortedExpand3.begin(), sortedExpand3.end());
-    for (auto expand : sortedExpand3) {
-      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(expand));
-      input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
-    }
-
-    // Apply multiplication operation.
-    auto mul1 =
-        rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
-    auto mul2 =
-        rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
-
-    // Apply sum operation.
-    // Parse sumDim in descending order to avoid any issues with the
-    // dimensions being removed.
-    Value result = mul2;
-    SmallVector<int64_t> sortedSumDims = sumDim;
-    std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
-    for (int64_t dim : sortedSumDims) {
-      Value dimValue = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(dim));
-      result =
-          createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
-private:
-  // Determine if there are any dimensions that intersect between expand1,
-  // expand2, and expand3.
-  bool sharedExpandDims(const int64_t &totalDims,
-                        const SmallVector<int64_t> &expand1,
-                        const SmallVector<int64_t> &expand2,
-                        const SmallVector<int64_t> &expand3,
-                        const SmallVector<int64_t> &sumDim) const {
-    for (int64_t i = 0; i < totalDims; ++i) {
-      if (!contains(sumDim, i) && contains(expand1, i) &&
-          contains(expand2, i) && contains(expand3, i)) {
-        return true;
-      }
-    }
-    return false;
-  }
-  bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
-    return std::find(vec.begin(), vec.end(), value) != vec.end();
-  }
-};
-} // namespace
-
 namespace {
 // Calculate the trace of the input tensor as the sum over its diagonal
 // elements. This computation is performed as:
@@ -10220,7 +10078,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
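
The deleted pattern above implements the decomposition described in its header comment: unsqueeze each input at its expand dims, multiply the three results, then sum over sumdim (unroll_dim is ignored because it does not affect the output, only how ATen schedules the computation). A PyTorch-level sketch of the same rewrite, useful for spot-checking against the reference op; trilinear_decomposed is a hypothetical helper and the shapes and index lists are illustrative:

import torch

def trilinear_decomposed(i1, i2, i3, expand1, expand2, expand3, sumdim):
    # Unsqueeze each input at its expand dims, in ascending order.
    for d in sorted(expand1):
        i1 = i1.unsqueeze(d)
    for d in sorted(expand2):
        i2 = i2.unsqueeze(d)
    for d in sorted(expand3):
        i3 = i3.unsqueeze(d)
    out = i1 * i2 * i3
    # Sum the highest dims first so the remaining indices stay valid.
    for d in sorted(sumdim, reverse=True):
        out = out.sum(dim=d)
    return out

i1, i2, i3 = torch.randn(3, 4), torch.randn(3, 4, 5), torch.randn(3, 5)
ref = torch.ops.aten._trilinear(i1, i2, i3, [2], [], [1], [2], 1)
dec = trilinear_decomposed(i1, i2, i3, [2], [], [1], [2])
print(torch.allclose(ref, dec, atol=1e-6))  # expected: True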
