 #include "PassDetail.h"
 
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -400,9 +399,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
   auto inputType = cast<ValueTensorType>(input.getType());
   auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
                    reduceDimsLength;
-  SmallVector<OpFoldResult> inputShapeTensor;
+  SmallVector<Value> inputShapeTensor;
   for (auto i = 0; i < inputRank; ++i) {
-    inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
+    inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
         loc, input,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(i))));
@@ -413,23 +412,13 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
       rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
   auto dimOffset = 0;
 
-  auto materializeIntFold = [&](OpFoldResult thing) {
-    if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
-      Value result = rewriter.create<Torch::ConstantIntOp>(
-          loc, cast<mlir::IntegerAttr>(attr));
-      return result;
-    }
-    return cast<mlir::Value>(thing);
-  };
-
   auto appendDims = [&](int64_t dimLength) {
-    OpFoldResult prod = getAsOpFoldResult(constOne);
+    Value prod = constOne;
     for (auto i = 0; i < dimLength; ++i) {
-      prod = rewriter.createOrFold<AtenMulIntOp>(
-          loc, materializeIntFold(prod),
-          materializeIntFold(inputShapeTensor[i + dimOffset]));
+      prod = rewriter.create<AtenMulIntOp>(loc, prod,
+                                           inputShapeTensor[i + dimOffset]);
     }
-    outShapeTensor.emplace_back(materializeIntFold(prod));
+    outShapeTensor.emplace_back(prod);
     dimOffset += dimLength;
   };
 
@@ -581,32 +570,21 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
   Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                         : rhsType.getOptionalDtype();
 
-  auto materializeIntFold = [&](OpFoldResult thing) {
-    if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
-      Value result = rewriter.create<Torch::ConstantIntOp>(
-          loc, cast<mlir::IntegerAttr>(attr));
-      return result;
-    }
-    return cast<mlir::Value>(thing);
-  };
-
   llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
   for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
     char d = lhsTokens[idx];
-    OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
+    lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
         loc, lhs,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(idx)));
-    lhsDimShapeMap[d] = materializeIntFold(lhsFold);
   }
   llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
   for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
     char d = rhsTokens[idx];
-    OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
+    rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
         loc, rhs,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(idx)));
-    rhsDimShapeMap[d] = materializeIntFold(rhsFold);
   }
 
   // parse batch, contracting, other, reduce dims of lhs and rhs
@@ -626,9 +604,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
     bool lhsContains = lhsDimShapeMap.count(d) > 0;
     bool rhsContains = rhsDimShapeMap.count(d) > 0;
     if (lhsContains && rhsContains) {
-      OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
+      outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
           loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
-      outDimShapeMap[d] = materializeIntFold(out);
     } else if (lhsContains) {
       outDimShapeMap[d] = lhsDimShapeMap[d];
     } else if (rhsContains) {
@@ -1996,125 +1973,6 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
 };
 } // namespace
 
-namespace {
-// Trilinear einstein sum, decomposed to:
-// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
-// .sum(sumdim)
-// The unrollDim operand does not impact the output of the operation, so
-// it is ignored.
-
-class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_TrilinearOp op,
-                                PatternRewriter &rewriter) const override {
-
-    Location loc = op.getLoc();
-
-    Value input1 = op.getI1();
-    Value input2 = op.getI2();
-    Value input3 = op.getI3();
-
-    // Expansions
-    SmallVector<int64_t> expand1;
-    SmallVector<int64_t> expand2;
-    SmallVector<int64_t> expand3;
-    if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
-      return rewriter.notifyMatchFailure(op, "expand1 should be constant");
-    }
-    if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
-      return rewriter.notifyMatchFailure(op, "expand2 should be constant");
-    }
-    if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
-      return rewriter.notifyMatchFailure(op, "expand3 should be constant");
-    }
-
-    SmallVector<int64_t> sumDim;
-    if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
-      return rewriter.notifyMatchFailure(op, "sumDim should be constant");
-    }
-
-    // Check if there are any dimensions that intersect between expand1,
-    // expand2, and expand3.
-    int64_t totalDims =
-        cast<BaseTensorType>(input1.getType()).getSizes().size() +
-        expand1.size();
-    if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
-      // pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
-      // TODO: Remove warning when issue gets resolved.
-      op->emitWarning("aten::_trilinear implementation in this case is "
-                      "non-functional (returns an empty dimension). We will "
-                      "intentionally deviate from this behavior.");
-    }
-
-    // Apply unsqueeze to respective input tensors at the specified dimensions
-    SmallVector<int64_t> sortedExpand1 = expand1;
-    std::sort(sortedExpand1.begin(), sortedExpand1.end());
-    for (auto expand : sortedExpand1) {
-      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(expand));
-      input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
-    }
-    SmallVector<int64_t> sortedExpand2 = expand2;
-    std::sort(sortedExpand2.begin(), sortedExpand2.end());
-    for (auto expand : sortedExpand2) {
-      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(expand));
-      input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
-    }
-    SmallVector<int64_t> sortedExpand3 = expand3;
-    std::sort(sortedExpand3.begin(), sortedExpand3.end());
-    for (auto expand : sortedExpand3) {
-      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(expand));
-      input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
-    }
-
-    // Apply multiplication operation.
-    auto mul1 =
-        rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
-    auto mul2 =
-        rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
-
-    // Apply sum operation.
-    // Parse sumDim in descending order to avoid any issues with the
-    // dimensions being removed.
-    Value result = mul2;
-    SmallVector<int64_t> sortedSumDims = sumDim;
-    std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
-    for (int64_t dim : sortedSumDims) {
-      Value dimValue = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(dim));
-      result =
-          createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
-private:
-  // Determine if there are any dimensions that intersect between expand1,
-  // expand2, and expand3.
-  bool sharedExpandDims(const int64_t &totalDims,
-                        const SmallVector<int64_t> &expand1,
-                        const SmallVector<int64_t> &expand2,
-                        const SmallVector<int64_t> &expand3,
-                        const SmallVector<int64_t> &sumDim) const {
-    for (int64_t i = 0; i < totalDims; ++i) {
-      if (!contains(sumDim, i) && contains(expand1, i) &&
-          contains(expand2, i) && contains(expand3, i)) {
-        return true;
-      }
-    }
-    return false;
-  }
-  bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
-    return std::find(vec.begin(), vec.end(), value) != vec.end();
-  }
-};
-} // namespace
-
 namespace {
 // Calculate the trace of the input tensor as the sum over its diagonal
 // elements. This computation is performed as:
@@ -10220,7 +10078,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
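
For reference, the decomposition that the removed pattern's comment describes, (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3)).sum(sumdim), can be checked at the PyTorch level. The sketch below is illustrative only: the shapes and the expand/sumdim lists are made-up values, and the helper unsqueeze_all is not part of the patch; it simply mirrors how the pattern unsqueezed each input along its sorted expand dims before multiplying and summing.

import torch

# Illustrative shapes and dim lists (assumptions, not taken from the patch).
i1 = torch.randn(3, 4)
i2 = torch.randn(3, 5)
i3 = torch.randn(4, 5)
expand1, expand2, expand3, sumdim = [2], [1], [0], [0]

def unsqueeze_all(t, dims):
    # Insert size-1 dims at the given positions in ascending order,
    # mirroring the per-input unsqueeze loops in the removed pattern.
    for d in sorted(dims):
        t = t.unsqueeze(d)
    return t

# The decomposition from the removed pattern's comment:
# broadcasted elementwise product, then sum over sumdim.
decomposed = (unsqueeze_all(i1, expand1) *
              unsqueeze_all(i2, expand2) *
              unsqueeze_all(i3, expand3)).sum(dim=sumdim)

# Reference op; expected to match when the expand lists do not all overlap.
reference = torch._trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim)
print(torch.allclose(decomposed, reference))  # expected: True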