Skip to content

Commit 3b25ba3

Browse files
Revert "[linalg] Fix torch.aten.add of torch.bool (#3820)"
This reverts commit 5aa323d.
1 parent f2c3191 commit 3b25ba3

File tree

2 files changed

+0
-32
lines changed

2 files changed

+0
-32
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,9 +827,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
827827
if (isa<mlir::FloatType>(dtype)) {
828828
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
829829
return b.create<arith::AddFOp>(loc, lhs, scaled);
830-
} else if (dtype.isInteger(1)) {
831-
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
832-
return b.create<arith::OrIOp>(loc, lhs, scaled);
833830
} else {
834831
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
835832
return b.create<arith::AddIOp>(loc, lhs, scaled);

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -685,35 +685,6 @@ def ElementwiseAddModule_basic(module, tu: TestUtils):
685685
# ==============================================================================
686686

687687

688-
# Addition is an interesting special case of a binary op, because under the hood
689-
# it carries a third scalar "alpha" parameter, which needs special handling.
690-
class ElementwiseAddBoolModule(torch.nn.Module):
691-
def __init__(self):
692-
super().__init__()
693-
694-
@export
695-
@annotate_args(
696-
[
697-
None,
698-
([4], torch.bool, True),
699-
([4], torch.bool, True),
700-
]
701-
)
702-
def forward(self, a, b):
703-
return a + b
704-
705-
706-
@register_test_case(module_factory=lambda: ElementwiseAddBoolModule())
707-
def ElementwiseAddBoolModule_basic(module, tu: TestUtils):
708-
module.forward(
709-
torch.tensor([False, False, True, True]),
710-
torch.tensor([False, True, False, False]),
711-
)
712-
713-
714-
# ==============================================================================
715-
716-
717688
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
718689
def __init__(self):
719690
super().__init__()

0 commit comments

Comments (0)