
Commit 06f0e4a

Add Kullback-Leibler divergence loss support
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent c675b2f commit 06f0e4a

8 files changed: +324 additions, -0 deletions


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -9241,6 +9241,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
   }];
 }
 
+def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction,
+    Torch_BoolType:$log_target
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenKlDivOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenKlDivOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,
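For orientation, the four operands registered above map directly onto the eager-mode PyTorch call; a minimal sketch (input values and reduction mode chosen only for illustration):

```python
import torch

x = torch.rand(3, 5, 2).log()  # kl_div conventionally takes log-probabilities as input
y = torch.rand(3, 5, 2)

# Operand order follows the ODS definition: (self, target, reduction, log_target).
# reduction=1 is "mean", so the result is a 0-d tensor; reduction=0 keeps the
# elementwise shape.
out = torch.ops.aten.kl_div(x, y, 1, log_target=False)
print(out.shape)  # torch.Size([])
```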

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 33 additions & 0 deletions
@@ -10635,6 +10635,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.kl_div\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.prim.Uninitialized : !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.If.yield %6 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield %0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"

@@ -14403,6 +14428,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %int3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.kl_div\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 78 additions & 0 deletions
@@ -10283,6 +10283,83 @@ class DecomposeAtenNllLossForwardOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenKlDivOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value target = op.getTarget();
+    Value reductionValue = op.getReduction();
+    Value logTargetValue = op.getLogTarget();
+
+    auto selfTy = dyn_cast<ValueTensorType>(self.getType());
+    auto targetTy = dyn_cast<ValueTensorType>(target.getType());
+    auto outTy = dyn_cast<ValueTensorType>(op.getType());
+
+    if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having sizes!");
+    }
+
+    if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having dtype!");
+    }
+
+    // Extract boolean value from logTarget argument.
+    bool logTargetBool;
+    if (!matchPattern(logTargetValue, m_TorchConstantBool(&logTargetBool)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for logTargetBool");
+
+    Value logOfTarget;
+    // Default: target tensor is not in log space.
+    if (!logTargetBool) {
+      logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
+    } else {
+      logOfTarget = target;
+    }
+
+    Value constOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
+                                                      self, constOne);
+
+    // Target tensor is already in log space.
+    if (logTargetBool) {
+      target = rewriter.create<AtenExpOp>(loc, targetTy, target);
+    }
+    Value lossPointwise =
+        rewriter.create<AtenMulTensorOp>(loc, targetTy, target, subValue);
+
+    // Extract reduction int value from reduction argument.
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+    Value loss;
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    // reduction: mean
+    if (reduction == 1) {
+      loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
+    } else if (reduction == 2) {
+      // reduction: sum
+      loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
+    } else {
+      // reduction: none
+      loss = lossPointwise;
+    }
+
+    rewriter.replaceOp(op, loss);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {

@@ -12144,6 +12221,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
        patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
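For reference, the decomposition builds the same pointwise loss as torch.nn.functional.kl_div and then applies the requested reduction; with x = self (expected in log space) and y = target:

```latex
\ell_i =
\begin{cases}
  y_i \bigl(\log y_i - x_i\bigr) & \text{log\_target} = \text{False}\\[2pt]
  e^{y_i} \bigl(y_i - x_i\bigr)  & \text{log\_target} = \text{True}
\end{cases}
\qquad
\operatorname{kl\_div}(x, y) =
\begin{cases}
  \ell                       & \text{reduction} = 0\ (\text{none})\\
  \operatorname{mean}(\ell)  & \text{reduction} = 1\ (\text{mean})\\
  \operatorname{sum}(\ell)   & \text{reduction} = 2\ (\text{sum})
\end{cases}
```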

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -578,6 +578,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenFminOp>();
   target.addIllegalOp<AtenFmaxOp>();
   target.addIllegalOp<AtenSpecialExpm1Op>();
+  target.addIllegalOp<AtenKlDivOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 16 additions & 0 deletions
@@ -2145,6 +2145,14 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
 def aten〇deg2rad〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1, log_target: bool = False) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    elif reduction in [1, 2]:
+        return []
+    else:
+        assert False, "Invalid reduction value."
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.

@@ -4459,6 +4467,14 @@ def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tu
     assert mat2_dtype == torch.int8
     return torch.int32
 
+def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, log_target: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op(
     output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
 def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
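As a quick sanity check of the shape rule added above (callable from within this file, where aten〇kl_div〡shape is defined): reduction 0 keeps the elementwise shape, while mean and sum reduce to a scalar. A hypothetical snippet:

```python
# 0 = none, 1 = mean, 2 = sum (same convention as torch.nn._reduction).
assert aten〇kl_div〡shape([3, 5, 2], [3, 5, 2], reduction=0) == [3, 5, 2]
assert aten〇kl_div〡shape([3, 5, 2], [3, 5, 2], reduction=1) == []
assert aten〇kl_div〡shape([3, 5, 2], [3, 5, 2], reduction=2) == []
```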

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -751,6 +751,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
     )
+    emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -62,3 +62,4 @@ def register_all_tests():
     from . import gridsampler
     from . import meshgrid
     from . import timeout
+    from . import kl_div_loss

projects/pt1/python/torch_mlir_e2e_test/test_suite/kl_div_loss.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import functorch
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class KlDivLossModule_default(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_default())
+def KlDivLossModule_default_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_reduction_is_none(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=0)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_reduction_is_none())
+def KlDivLossModule_reduction_is_none_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_reduction_is_none_log_target_is_true(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=0, log_target=True)
+
+
+@register_test_case(
+    module_factory=lambda: KlDivLossModule_reduction_is_none_log_target_is_true()
+)
+def KlDivLossModule_reduction_is_none_log_target_is_true_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_mean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=1)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_mean_reduction())
+def KlDivLossModule_mean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_sum_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=2)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_sum_reduction())
+def KlDivLossModule_sum_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_batchmean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input, target):
+        # torch.ops.aten.kl_div has no direct way to pass batchmean as reduction mode.
+        # https://github.com/pytorch/pytorch/blob/53ecb8159aa28b3c015917acaa89604cfae0d2c6/torch/nn/_reduction.py#L8-L24
+        # F.kl_div(input, target, reduction="batchmean"):
+        #     out = torch.kl_div(input, target, reduction="sum")
+        #     batch_size = input.shape[0]
+        #     out = out / batch_size
+        # https://github.com/pytorch/pytorch/blob/53ecb8159aa28b3c015917acaa89604cfae0d2c6/torch/nn/functional.py#L3379-L3381
+        loss = torch.ops.aten.kl_div(input, target, reduction=2)
+        batch_size = input.shape[0]
+        return torch.ops.aten.div.Scalar(loss, batch_size)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_batchmean_reduction())
+def KlDivLossModule_batchmean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================