
Commit 6a341f8

Revert "[Torch] support AtenExp2Op (#3832)"
This reverts commit 9ce2a69.
1 parent: 9a89cd0 · commit: 6a341f8

File tree

8 files changed: +94 −152 lines

  include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
  lib/Conversion/TorchToStablehlo/Basic.cpp
  lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
  lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
  projects/pt1/e2e_testing/xfail_sets.py
  projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
  projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
  projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 0 additions & 45 deletions
@@ -996,51 +996,6 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [
   }];
 }
 
-def Torch_AtenExp2Op : Torch_Op<"aten.exp2", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::exp2 : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenExp2Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenExp2Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
-def Torch_AtenExp2_Op : Torch_Op<"aten.exp2_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::exp2_ : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    Torch_NonValueTensorType:$self
-  );
-  let results = (outs
-    AnyTorchOptionalNonValueTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenExp2_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenExp2_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
 def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [
     AllowsTypeRefinement,
     HasValueSemantics,
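
For context, `aten::exp2` is the elementwise base-2 exponential, exp2(x) = 2**x, and `aten::exp2_` is its in-place variant. A minimal PyTorch illustration of the op whose generated definitions are deleted above (not part of this diff):

    import torch

    x = torch.tensor([0.0, 1.0, 3.0])
    print(torch.ops.aten.exp2(x))  # tensor([1., 2., 8.])
    x.exp2_()                      # in-place variant, aten::exp2_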

lib/Conversion/TorchToStablehlo/Basic.cpp

Lines changed: 94 additions & 46 deletions
@@ -931,49 +931,79 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
   return success();
 }
 
-namespace {
-template <typename AtenOpT>
-class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto outType = cast<TensorType>(
-        OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
-            op.getType()));
+// AtenPowTensorScalarOp
+template <>
+LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
+    AtenPowTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  auto lhsType = dyn_cast<TensorType>(lhs.getType());
+  Value rhs = adaptor.getExponent();
+  TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
 
-    Type outElemTy = outType.getElementType();
-    if (!outElemTy.isIntOrFloat()) {
-      return op.emitError(
-          "only floating-point or integer datatype legalization supported");
-    }
+  if (!lhsType)
+    return op.emitError("only Tensor types supported in StableHLO");
 
-    Value lhs = adaptor.getSelf();
-    auto lhsType = dyn_cast<TensorType>(lhs.getType());
-    Value rhs = adaptor.getExponent();
-    auto rhsType = dyn_cast<TensorType>(rhs.getType());
+  auto outType = cast<TensorType>(
+      OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
+          ->convertType(op.getType()));
 
-    if (!lhsType && !rhsType) {
-      return op.emitError("only Tensor types supported in StableHLO");
-    }
-    if (!lhsType) {
-      lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
-    }
-    if (!rhsType) {
-      rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
-    }
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
 
-    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
-    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
-    DenseI64ArrayAttr bcastDimensions;
-    rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outType, lhs, rhs,
-                                                      bcastDimensions);
-    return success();
+  if (!rhsType) {
+    rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
   }
-};
-} // namespace
+  DenseI64ArrayAttr bcastDimensions;
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
+  auto loc = op.getLoc();
+  Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
+                                                       bcastDimensions);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// AtenPowScalarOp
+template <>
+LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
+    AtenPowScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  auto lhsType = dyn_cast<TensorType>(lhs.getType());
+  Value rhs = adaptor.getExponent();
+  auto rhsType = dyn_cast<TensorType>(rhs.getType());
+
+  if (!rhsType)
+    return op.emitError("only Tensor types supported in StableHLO");
+
+  auto outType = cast<TensorType>(
+      OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
+          op.getType()));
+
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+
+  if (!lhsType) {
+    lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
+  }
+  DenseI64ArrayAttr bcastDimensions;
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
+  auto loc = op.getLoc();
+  Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
+                                                       bcastDimensions);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
 
 // PrimNumToTensorScalarOp
 template <>
@@ -1767,6 +1797,29 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
+    AtenPowTensorTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  auto lhsTy = cast<TensorType>(lhs.getType());
+  Value rhs = adaptor.getExponent();
+  auto rhsTy = cast<TensorType>(rhs.getType());
+
+  if (!lhsTy || !rhsTy)
+    return op.emitError("only Tensor types supported");
+
+  auto outTy =
+      cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
+
+  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
+  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
+
+  rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
+                                                    /*broadcast_attr*/ nullptr);
+  return success();
+}
+
 // Converts `aten.empty.memory_format` to `tensor.empty` op.
 template <>
 LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
@@ -2197,14 +2250,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
 
 #undef INSERT_BINARY_LOGICAL_PATTERN
 
-#define INSERT_BINARY_POW_PATTERN(AtenOp)                                      \
-  target.addIllegalOp<AtenOp>();                                               \
-  patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context)
-  INSERT_BINARY_POW_PATTERN(AtenPowTensorScalarOp);
-  INSERT_BINARY_POW_PATTERN(AtenPowTensorTensorOp);
-  INSERT_BINARY_POW_PATTERN(AtenPowScalarOp);
-#undef INSERT_BINARY_ADDSUB_PATTERN
-
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
@@ -2215,6 +2260,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
   INSERT_ATENOP_PATTERN(AtenTensorIntOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
+  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
+  INSERT_ATENOP_PATTERN(AtenPowScalarOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
@@ -2238,6 +2285,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSizeIntOp);
   INSERT_ATENOP_PATTERN(AtenToDtypeOp);
   INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
+  INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
 
   INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
   INSERT_ATENOP_PATTERN(AtenFillScalarOp);
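
For reference, the three pow conversions registered above target different ATen overloads. A rough mapping from PyTorch call forms to the ops named here (illustrative, not part of this diff):

    import torch

    x = torch.arange(3, dtype=torch.float32)

    torch.pow(x, 2.0)  # aten::pow.Tensor_Scalar -> AtenPowTensorScalarOp
    torch.pow(2.0, x)  # aten::pow.Scalar        -> AtenPowScalarOp
    torch.pow(x, x)    # aten::pow.Tensor_Tensor -> AtenPowTensorTensorOp

All three lower to chlo::BroadcastPowOp after promoting the operands to the converted output element type, as the restored conversions above show.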

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 0 additions & 9 deletions
@@ -6487,10 +6487,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.exp2\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11260,11 +11256,6 @@
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %1 : !torch.int\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.exp2\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
-"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
-"    return %1 : !torch.int\n"
-"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 0 additions & 19 deletions
@@ -9008,24 +9008,6 @@ class DecomposeAtenBinaryCrossEntropyWithLogitsOp
 };
 } // namespace
 
-namespace {
-class DecomposeAtenExp2Op : public OpRewritePattern<AtenExp2Op> {
-  using OpRewritePattern<AtenExp2Op>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenExp2Op op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value self = op.getSelf();
-
-    auto two =
-        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
-    rewriter.replaceOpWithNewOp<AtenPowScalarOp>(op, op.getType(), two, self);
-
-    return success();
-  }
-};
-
-} // namespace
-
 namespace {
 class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
   using OpRewritePattern<AtenOneHotOp>::OpRewritePattern;
@@ -10164,7 +10146,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposePrimTolistOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenExp2Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
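
The deleted DecomposeAtenExp2Op pattern rewrote `aten.exp2` into `aten.pow.Scalar` with a constant base of 2. A quick sanity check of that identity in plain PyTorch (illustrative only):

    import torch

    x = torch.randn(3, 2)
    # exp2(x) == 2 ** x, which is what the removed decomposition emitted as pow(2, x)
    assert torch.allclose(torch.exp2(x), torch.pow(2.0, x))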

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 1 deletion
@@ -2707,7 +2707,6 @@
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
-    "Exp2StaticModule_basic",
     "MultinomialModule2D_basic",
     "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 0 additions & 8 deletions
@@ -216,9 +216,6 @@ def aten〇silu〡shape(self: List[int]) -> List[int]:
 def aten〇exp〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
-def aten〇exp2〡shape(self: List[int]) -> List[int]:
-    return upstream_shape_functions.unary(self)
-
 def aten〇expm1〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2570,11 +2567,6 @@ def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return _get_dtype_of_floating_point_op(self_dtype)
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
-def aten〇exp2〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
-    self_rank, self_dtype = self_rank_dtype
-    return _get_dtype_of_floating_point_op(self_dtype)
-
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
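
As the removed dtype function indicates, `aten.exp2` follows the usual floating-point promotion rule via `_get_dtype_of_floating_point_op`, so integer inputs produce a floating-point result. For example (illustrative, not part of this diff):

    import torch

    torch.exp2(torch.tensor([1, 2, 3])).dtype  # torch.float32, promoted from int64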

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 0 additions & 1 deletion
@@ -317,7 +317,6 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::asin : (Tensor) -> (Tensor)",
         "aten::asinh : (Tensor) -> (Tensor)",
         "aten::exp : (Tensor) -> (Tensor)",
-        "aten::exp2 : (Tensor) -> (Tensor)",
         "aten::expm1 : (Tensor) -> (Tensor)",
         "aten::cos : (Tensor) -> (Tensor)",
         "aten::cosh : (Tensor) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 0 additions & 23 deletions
@@ -2881,29 +2881,6 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class Exp2StaticModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args(
-        [
-            None,
-            ([3, 2], torch.float32, True),
-        ]
-    )
-    def forward(self, x):
-        return torch.ops.aten.exp2(x)
-
-
-@register_test_case(module_factory=lambda: Exp2StaticModule())
-def Exp2StaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 2))
-
-
-# ==============================================================================
-
-
 class ElementwisePowModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
