diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c851f82dba72..cd0bf6206850 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9452,6 +9452,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
   }];
 }
 
+def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction,
+    Torch_BoolType:$log_target
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenKlDivOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenKlDivOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 9c355f4ea4a8..5b9c2beecd2c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10688,6 +10688,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.kl_div\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.prim.Uninitialized : !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.If.yield %6 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield %0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14517,6 +14542,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %int3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.kl_div\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 62ce02df50a6..74f4813d1910 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10471,6 +10471,84 @@ class DecomposeAtenNllLossForwardOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
+  using OpRewritePattern<AtenKlDivOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenKlDivOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value target = op.getTarget();
+    Value reductionValue = op.getReduction();
+    Value logTargetValue = op.getLogTarget();
+
+    auto selfTy = cast<BaseTensorType>(self.getType());
+    auto targetTy = cast<BaseTensorType>(target.getType());
+    auto outTy = cast<BaseTensorType>(op.getType());
+
+    if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having sizes!");
+    }
+
+    if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having dtype!");
+    }
+
+    // Extract boolean value from logTarget argument
+    bool logTargetBool;
+    if (!matchPattern(logTargetValue, m_TorchConstantBool(&logTargetBool)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for logTargetBool");
+
+    // Default: target tensor is not in log space
+    Value logOfTarget;
+    if (!logTargetBool) {
+      logOfTarget = rewriter.create<AtenLogOp>(loc, targetTy, target);
+    } else {
+      logOfTarget = target;
+    }
+
+    Value constOne =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value subValue = rewriter.create<AtenSubTensorOp>(loc, selfTy, logOfTarget,
+                                                      self, constOne);
+
+    // if target tensor is already in log space
+    if (logTargetBool) {
+      target = rewriter.create<AtenExpOp>(loc, targetTy, target);
+    }
+    Value lossPointwise =
+        rewriter.create<AtenMulTensorOp>(loc, targetTy, target, subValue);
+
+    // Extract reduction int value from reduction argument
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+
+    Value loss;
+    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+    // reduction: mean
+    if (reduction == 1) {
+      loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
+    } else if (reduction == 2) {
+      // reduction: sum
+      loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
+    } else {
+      // reduction: none
+      loss = lossPointwise;
+    }
+
+    rewriter.replaceOp(op, loss);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12386,6 +12464,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
         patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 13d5a1f2ab8b..6bccc1dff4ef 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -585,6 +585,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenKlDivOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3b5f651c8903..311cc7657d84 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -39,6 +39,8 @@
     "AtenSymConstrainRange_basic",
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
+    # RuntimeError: attribute lookup is not defined on builtin:
+    "KlDivLossModule_batchmean_reduction_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -386,6 +388,12 @@
     "MaxPool3dStaticModule_basic",
     # Looks like incorrect fx graph conversion
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
+    # error: failed to legalize operation 'torch.aten.xlogy.Tensor'
+    "KlDivLossModule_default_basic",
+    "KlDivLossModule_reduction_is_none_basic",
+    "KlDivLossModule_mean_reduction_basic",
+    "KlDivLossModule_sum_reduction_basic",
+    "KlDivLossModule_batchmean_reduction_basic",
 }
 
 FX_IMPORTER_XFAIL_SET = {
@@ -3068,6 +3076,7 @@
     "NllLossStaticModule_mean_basic",
     "NllLossModule_sum_basic",
     "NllLossStaticModule_sum_basic",
+    "KlDivLossModule_batchmean_reduction_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
@@ -3953,6 +3962,12 @@
     "NllLossStaticModule_mean_basic",
     "NllLossStaticModule_sum_basic",
     "NllLossStaticModule_weight_basic",
+    "KlDivLossModule_default_basic",
+    "KlDivLossModule_reduction_is_none_basic",
+    "KlDivLossModule_reduction_is_none_log_target_is_true_basic",
+    "KlDivLossModule_mean_reduction_basic",
+    "KlDivLossModule_sum_reduction_basic",
+    "KlDivLossModule_batchmean_reduction_basic",
     "Exp2StaticModule_basic",
     "ElementwiseRreluWithNoiseEvalModule_basic",
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index a069550ec669..f3000cb8fe11 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2171,6 +2171,14 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
 def aten〇deg2rad〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1, log_target: bool = False) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    elif reduction in [1, 2]:
+        return []
+    else:
+        assert False, "Invalid reduction value."
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100),  # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100),  # No batch dim.
@@ -4523,6 +4531,14 @@ def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tu
     assert mat2_dtype == torch.int8
     return torch.int32
 
+def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, log_target: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op(
     output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
 def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 6817d285faea..28b97ac0eedc 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -760,6 +760,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
     )
+    emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index 8166562b0527..bd82cd1c11b0 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -62,3 +62,4 @@ def register_all_tests():
     from . import gridsampler
     from . import meshgrid
     from . import timeout
+    from . import kl_div_loss
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/kl_div_loss.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/kl_div_loss.py
new file mode 100644
index 000000000000..2b963751ddf6
--- /dev/null
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/kl_div_loss.py
@@ -0,0 +1,168 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import functorch
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class KlDivLossModule_default(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_default())
+def KlDivLossModule_default_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_reduction_is_none(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=0)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_reduction_is_none())
+def KlDivLossModule_reduction_is_none_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_reduction_is_none_log_target_is_true(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=0, log_target=True)
+
+
+@register_test_case(
+    module_factory=lambda: KlDivLossModule_reduction_is_none_log_target_is_true()
+)
+def KlDivLossModule_reduction_is_none_log_target_is_true_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_mean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=1)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_mean_reduction())
+def KlDivLossModule_mean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_sum_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.kl_div(x, y, reduction=2)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_sum_reduction())
+def KlDivLossModule_sum_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class KlDivLossModule_batchmean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, input, target):
+        # torch.ops.aten.kl_div has no direct way to pass batchmean as reduction mode.
+        # https://github.com/pytorch/pytorch/blob/53ecb8159aa28b3c015917acaa89604cfae0d2c6/torch/nn/_reduction.py#L8-L24
+        # F.kl_div(input, target, reduction="batchmean"):
+        #   out = torch.kl_div(input, target, reduction="sum")
+        #   batch_size = input.shape[0]
+        #   out = out / batch_size
+        # https://github.com/pytorch/pytorch/blob/53ecb8159aa28b3c015917acaa89604cfae0d2c6/torch/nn/functional.py#L3379-L3381
+        loss = torch.ops.aten.kl_div(input, target, reduction=2)
+        batch_size = input.shape[0]
+        return torch.ops.aten.div.Scalar(loss, batch_size)
+
+
+@register_test_case(module_factory=lambda: KlDivLossModule_batchmean_reduction())
+def KlDivLossModule_batchmean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
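
For reference, a minimal eager-mode PyTorch sketch of the computation that the DecomposeAtenKlDivOp pattern above materializes (log/exp of the target, a subtraction, a multiply, then an optional mean/sum reduction); the helper name kl_div_reference is illustrative only and is not part of the patch.

import torch

def kl_div_reference(input, target, reduction=1, log_target=False):
    # Pointwise loss: target * (log(target) - input); when the target is
    # already given in log space, use it directly and exponentiate it for
    # the multiplicative factor.
    log_of_target = target if log_target else torch.log(target)
    sub = log_of_target - input
    factor = torch.exp(target) if log_target else target
    loss_pointwise = factor * sub
    if reduction == 1:  # mean
        return loss_pointwise.mean()
    if reduction == 2:  # sum
        return loss_pointwise.sum()
    return loss_pointwise  # reduction == 0: none

# Sanity check against the ATen op this patch decomposes (default reduction is mean).
x, y = torch.rand(3, 5, 2), torch.rand(3, 5, 2)
assert torch.allclose(kl_div_reference(x, y), torch.ops.aten.kl_div(x, y))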