
[TORCH] Add support for aten.heaviside Op #4220


Open · wants to merge 1 commit into main

Conversation


@sharavak sharavak commented Jun 2, 2025

  • Decomposed the heaviside op into a sequence of Aten ops.
  • Added e2e test cases.

This implementation addresses and closes #4211
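For context, torch.heaviside(input, values) returns 0 for negative inputs, 1 for positive inputs, and the corresponding element of values where the input is exactly 0; NaN inputs produce 0. A minimal plain-Python sketch of the per-element semantics this decomposition targets (the function name is illustrative, not part of the patch):

```python
import math

def heaviside(x, value):
    # Illustrative per-element sketch of torch.heaviside semantics.
    # NaN inputs map to 0, matching PyTorch's reference behavior.
    if math.isnan(x) or x < 0:
        return 0.0
    if x > 0:
        return 1.0
    return value  # x == 0 takes the user-supplied value

[heaviside(v, 2.0) for v in (-3.0, 0.0, 5.0, float('nan'))]  # → [0.0, 2.0, 1.0, 0.0]
```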

sharavak (Author) commented Jun 2, 2025

@stellaraccident @vivekkhandelwal1 @penguin-wwy @zjgarvey @AmosLewis, I’d be grateful if any of you could take a look at this PR. Your feedback would be greatly appreciated!

Comment on lines +318 to +333
class ElementwiseHeavisideIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.int32, True), ([-1], torch.int32, True)])
    def forward(self, x, values):
        return torch.heaviside(x, values)


@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(5, 1, low=-100, high=1000).to(torch.int32),
        tu.randint(1, low=-100, high=1000).to(torch.int32),
    )
Collaborator:
Why did you cast to int32 here?

Comment on lines +10997 to +11012
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
                      broadcastShapeValue);

auto broadcastType = ValueTensorType::get(
    op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
auto boolBroadcastType = ValueTensorType::get(
    op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
    loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
    broadcastShapeValue);
auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
    loc, broadcastType, input, indexBroadcastShapeTorchList);
auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
    loc, broadcastType, value, indexBroadcastShapeTorchList);
Collaborator:
I think this is not needed. Since you are decomposing this op into elementwise ops, the broadcasting part will be handled during Torch->Linalg lowering.
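As background for this point: elementwise lowerings resolve operand shapes with the usual right-aligned broadcasting rule, so an explicit broadcast in the decomposition is redundant. A small illustrative sketch of that rule (not the torch-mlir computeBroadcastShape helper itself):

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    # Right-align the two shapes and combine them dimension by dimension;
    # a size-1 dimension stretches to match the other operand.
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"incompatible shapes {a} and {b}")
        result.append(max(x, y))
    return tuple(reversed(result))

broadcast_shape((5, 1), (1,))  # → (5, 1), as in the e2e test above
```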

Comment on lines +11026 to +11028
// Compute mask: isnan(input)
auto isNan =
    rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
Collaborator:
I did not see the mention of this case here: https://docs.pytorch.org/docs/stable/generated/torch.heaviside.html. Can you please share any reference?

Author:

Thanks for the review, @vivekkhandelwal1.
I tested this behavior with PyTorch: if the input contains NaN values, they are mapped to 0 in the output. To handle this explicitly, I used AtenIsnanOp to detect NaN values.

input = torch.tensor([[0, float('nan')]])
values = torch.tensor([2], dtype=torch.float32)
torch.heaviside(input, values)

Output:
tensor([[2., 0.]])

Ref: https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L1448
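The cited torch/_refs implementation groups NaN with the negative branch (roughly logical_or(lt(input, 0), isnan(input))), which is why heaviside(nan, v) yields 0 rather than v. A plain-Python paraphrase of that logic (names are illustrative, not the actual _refs code):

```python
import math

def heaviside_ref(x, value):
    # Paraphrase of the cited reference logic: NaN is grouped with
    # negative inputs, so it falls into the zero branch.
    input_lt_zero_or_nan = (x < 0) or math.isnan(x)
    zeros_and_ones = 0.0 if input_lt_zero_or_nan else 1.0
    return value if x == 0 else zeros_and_ones

[heaviside_ref(v, 2.0) for v in (0.0, float('nan'))]  # → [2.0, 0.0]
```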

Successfully merging this pull request may close these issues:

[TORCH] Add support for aten.heaviside