
Commit 1ac131c

[TORCH] Add support for logcumsumexp Op (#4187)
- Decomposed the logcumsumexp op into Aten ops.
- The decomposition follows the formula: **logcumsumexp(x) = log(cumsum(exp(x)))**
- Added test cases to the e2e test suite.

This implementation addresses and closes #4183.

---------

Signed-off-by: sharavana20 <sharavana.kumar@multicorewareinc.com>
1 parent dcdb77e commit 1ac131c
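
The formula above can be sanity-checked directly in PyTorch. The snippet below is a minimal sketch (not part of this commit) that compares the native op against the log(cumsum(exp(x))) decomposition on a small random tensor.

```python
import torch

# Minimal sketch (not part of this commit): verify the decomposition
# logcumsumexp(x) = log(cumsum(exp(x))) against PyTorch's native op.
x = torch.rand(2, 3, 4)
reference = torch.logcumsumexp(x, dim=1)
decomposed = torch.log(torch.cumsum(torch.exp(x), dim=1))
# For moderate-magnitude inputs the two agree up to floating-point error
# (exp can overflow for large values, so this is only a sanity check).
assert torch.allclose(reference, decomposed, atol=1e-6)
```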

File tree

8 files changed: +150 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -8764,6 +8764,30 @@ def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
   }];
 }
 
+def Torch_AtenLogcumsumexpOp : Torch_Op<"aten.logcumsumexp", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logcumsumexp : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogcumsumexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 7 additions & 0 deletions
@@ -9383,6 +9383,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -12561,6 +12564,10 @@
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
 "    %int4 = torch.constant.int 4\n"
 "    %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 52 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -2962,6 +2963,56 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
 };
 } // namespace
 
+// Decompose AtenLogCumsumExpOp into: AtenExpOp,
+// AtenCumsumOp and AtenLogOp
+// logcumsumexp(x)[i][j] = log(sum_{k=0}^{j} exp(x[i][k]))
+namespace {
+class DecomposeAtenLogCumsumExpOp
+    : public OpRewritePattern<AtenLogcumsumexpOp> {
+public:
+  using OpRewritePattern<AtenLogcumsumexpOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogcumsumexpOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+    auto resultType = dyn_cast<BaseTensorType>(op.getType());
+
+    if (!inputType || !inputType.hasDtype())
+      return rewriter.notifyMatchFailure(op, "input should have dtype.");
+
+    if (isa<mlir::IntegerType>(inputType.getDtype()))
+      return rewriter.notifyMatchFailure(op, "integer dtype is not allowed.");
+
+    // TODO: support complex type in future.
+    if (isa<mlir::ComplexType>(inputType.getDtype()))
+      return rewriter.notifyMatchFailure(op,
+                                         "doesn't support complex type now");
+
+    if (!inputType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input should have known size.");
+
+    int64_t inputRank = inputType.getSizes().size();
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: Only constant dim value is supported.");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "invalid dim.");
+
+    Value dtypeVal =
+        getDtypeIntValueForType(rewriter, loc, inputType.getDtype());
+    Value expInput = rewriter.create<AtenExpOp>(loc, resultType, input);
+    Value cumsum = rewriter.create<AtenCumsumOp>(loc, resultType, expInput,
+                                                 op.getDim(), dtypeVal);
+    rewriter.replaceOpWithNewOp<AtenLogOp>(op, resultType, cumsum);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenLogSigmoidOp : public OpRewritePattern<AtenLogSigmoidOp> {
 public:
@@ -12114,6 +12165,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLogCumsumExpOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExpOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExp2Op>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -375,6 +375,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<Aten_LogSoftmaxOp>();
   target.addIllegalOp<AtenLogSoftmaxIntOp>();
   target.addIllegalOp<AtenLogSigmoidOp>();
+  target.addIllegalOp<AtenLogcumsumexpOp>();
   target.addIllegalOp<AtenHardshrinkOp>();
   target.addIllegalOp<AtenSoftshrinkOp>();
   target.addIllegalOp<AtenEmptyLikeOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 9 additions & 0 deletions
@@ -2992,6 +2992,9 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpStaticFloat64DtypeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -3717,6 +3720,9 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LinspaceEmptyModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpStaticFloat64DtypeModule_basic",
     "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
@@ -4517,6 +4523,9 @@
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
     "LogSoftmaxIntModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpStaticFloat64DtypeModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "MatmulBroadcastBatchDim_basic",
     "MatmulSingleDynamicBatchDim_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
@@ -1552,6 +1552,9 @@ def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None
 def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -3245,6 +3248,11 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
         return torch.int64
     return self_dtype
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
+def aten〇logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
 
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
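
The shape and dtype functions added above simply return the input's shape and dtype. The snippet below is a minimal sketch (not part of this commit) of the property those functions encode, checked at the PyTorch level.

```python
import torch

# Minimal sketch (not part of this commit): aten.logcumsumexp preserves both
# the shape and the dtype of its input, which is what the shape fn
# (`return self`) and dtype fn (`return self_dtype`) above express.
x = torch.rand(5, 3, 6, dtype=torch.float64)
y = torch.ops.aten.logcumsumexp(x, 1)
assert y.shape == x.shape
assert y.dtype == x.dtype
```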

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -723,6 +723,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
     emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
+    emit("aten::logcumsumexp : (Tensor, int) -> (Tensor)")
     emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::logaddexp : (Tensor, Tensor) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 48 additions & 0 deletions
@@ -5141,6 +5141,54 @@ def CumsumWithDtypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class LogCumsumExpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpModule())
+def LogCumsumExpModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3))
+
+
+class LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([8, 5, 6], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=-2)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpStaticNegativeDimModule())
+def LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 5, 6))
+
+
+class LogCumsumExpStaticFloat64DtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5, 3, 6, 9], torch.float64, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpStaticFloat64DtypeModule())
+def LogCumsumExpStaticFloat64DtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, 6, 9).to(torch.float64))
+
+
+# ==============================================================================
+
+
 class CumprodModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
