Commit ebe05c5

Add Kullback-Leibler divergence loss support
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 716303a commit ebe05c5

File tree

8 files changed: +324 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -9332,6 +9332,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
   }];
 }

+def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction,
+    Torch_BoolType:$log_target
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenKlDivOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenKlDivOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,
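
For orientation, this registered op corresponds to the eager overload torch.ops.aten.kl_div(self, target, reduction, log_target). A minimal sketch of the pointwise semantics it names, assuming strictly positive targets (PyTorch's eager kernel zero-masks terms where target == 0):

import torch

x = torch.rand(3, 5)                # input, interpreted as log-probabilities
y = torch.rand(3, 5).softmax(-1)    # target probabilities (strictly positive here)

# Pointwise KL term: target * (log(target) - input); reduction=0 means 'none'.
ref = y * (y.log() - x)
out = torch.ops.aten.kl_div(x, y, 0, False)
assert torch.allclose(out, ref)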

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 33 additions & 0 deletions
@@ -10667,6 +10667,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.kl_div\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.prim.Uninitialized : !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.If.yield %6 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield %0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14447,6 +14472,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %int3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.kl_div\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 78 additions & 0 deletions
@@ -10374,6 +10374,83 @@ class DecomposeAtenNllLossForwardOp
 };
 } // namespace

+namespace {
+class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenKlDivOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value target = op.getTarget();
+    Value reductionValue = op.getReduction();
+    Value logTargetValue = op.getLogTarget();
+
+    auto selfTy = dyn_cast<ValueTensorType>(self.getType());
+    auto targetTy = dyn_cast<ValueTensorType>(target.getType());
+    auto outTy = dyn_cast<ValueTensorType>(op.getType());
+
+    if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having sizes!");
+    }
+
+    if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having dtype!");
+    }
+
+    // Extract boolean value from logTarget argument
+    bool logTargetBool;
+    if (!matchPattern(logTargetValue, m_TorchConstantBool(&logTargetBool)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for logTargetBool");
+
+    Value logOfTarget;
+    // Default: target tensor is not in log space
+    if (!logTargetBool) {
+      logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
+    } else {
+      logOfTarget = target;
+    }
+
+    Value constOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
+                                                      self, constOne);
+
+    // target tensor is already in log space
+    if (logTargetBool) {
+      target = rewriter.create<AtenExpOp>(loc, targetTy, target);
+    }
+    Value lossPointwise =
+        rewriter.create<AtenMulTensorOp>(loc, targetTy, target, subValue);
+
+    // Extract reduction int value from reduction argument
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+    Value loss;
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    // reduction: mean
+    if (reduction == 1) {
+      loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
+    } else if (reduction == 2) {
+      // reduction: sum
+      loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
+    } else {
+      // reduction: none
+      loss = lossPointwise;
+    }
+
+    rewriter.replaceOp(op, loss);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12238,6 +12315,7 @@ class DecomposeComplexOpsPass
   addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
       patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
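
In eager terms, the new pattern rewrites aten.kl_div into log/exp/sub/mul plus an optional mean or sum reduction. A rough Python equivalent of the rewritten computation (a sketch of the decomposition only; unlike PyTorch's eager kernel it does not zero-mask entries where the target is 0):

import torch

def kl_div_decomposed(self, target, reduction=1, log_target=False):
    # log(target), unless the target is already in log space.
    log_of_target = target if log_target else target.log()
    sub = log_of_target - self
    # When the target is in log space, bring it back for the final product.
    if log_target:
        target = target.exp()
    pointwise = target * sub
    if reduction == 1:       # mean
        return pointwise.mean()
    if reduction == 2:       # sum
        return pointwise.sum()
    return pointwise         # none (reduction == 0)

For strictly positive targets this should match torch.ops.aten.kl_div in every reduction mode.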

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -581,6 +581,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSpecialExpm1Op>();
   target.addIllegalOp<AtenFliplrOp>();
   target.addIllegalOp<AtenFlipudOp>();
+  target.addIllegalOp<AtenKlDivOp>();

   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 16 additions & 0 deletions
@@ -2156,6 +2156,14 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
 def aten〇deg2rad〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)

+def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1, log_target: bool = False) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    elif reduction in [1, 2]:
+        return []
+    else:
+        assert False, "Invalid reduction value."
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
@@ -4485,6 +4493,14 @@ def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tu
     assert mat2_dtype == torch.int8
     return torch.int32

+def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, log_target: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op(
     output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
 def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
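
The dtype rule simply promotes the dtypes of the two operands via promote_dtypes, which follows PyTorch's standard type-promotion rules. In eager PyTorch the same answer comes from torch.result_type (an illustration, not part of the commit):

import torch

# float32 combined with float64 promotes to float64; float32 with float16 stays float32.
assert torch.result_type(torch.rand(2, 3), torch.rand(2, 3).double()) == torch.float64
assert torch.result_type(torch.rand(2, 3), torch.rand(2, 3).half()) == torch.float32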

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -754,6 +754,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
     )
+    emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -62,3 +62,4 @@ def register_all_tests():
     from . import gridsampler
     from . import meshgrid
     from . import timeout
+    from . import kl_div_loss
projects/pt1/python/torch_mlir_e2e_test/test_suite/kl_div_loss.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import functorch
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class KlDivLossModule_default(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_default())
+def KlDivLossModule_default_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_reduction_is_none(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=0)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_reduction_is_none())
+def KlDivLossModule_reduction_is_none_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_reduction_is_none_log_target_is_true(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=0, log_target=True)
+
+
+@register_test_case(
+    module_factory=lambda: KlDivLossModule_reduction_is_none_log_target_is_true()
+)
+def KlDivLossModule_reduction_is_none_log_target_is_true_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_mean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=1)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_mean_reduction())
+def KlDivLossModule_mean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_sum_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=2)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_sum_reduction())
+def KlDivLossModule_sum_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_batchmean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input, target):
+        # torch.ops.aten.kl_div has no direct way to pass batchmean as reduction mode.
+        # https://github.com/pytorch/pytorch/blob/53ecb8159aa28b3c015917acaa89604cfae0d2c6/torch/nn/_reduction.py#L8-L24
+        # F.kl_div(input, target, reduction="batchmean"):
+        #   out = torch.kl_div(input, target, reduction="sum")
+        #   batch_size = input.shape[0]
+        #   out = out / batch_size
+        # https://github.com/pytorch/pytorch/blob/53ecb8159aa28b3c015917acaa89604cfae0d2c6/torch/nn/functional.py#L3379-L3381
+        loss = torch.ops.aten.kl_div(input, target, reduction=2)
+        batch_size = input.shape[0]
+        return torch.ops.aten.div.Scalar(loss, batch_size)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_batchmean_reduction())
+def KlDivLossModule_batchmean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================