From ec43437e40c0b501e722633dbf549bcd19bef02b Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Wed, 19 Mar 2025 15:05:55 -0700 Subject: [PATCH] torch cond --- torchvision/transforms/_functional_tensor.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py index 618bbfbab7c..57378258f59 100644 --- a/torchvision/transforms/_functional_tensor.py +++ b/torchvision/transforms/_functional_tensor.py @@ -919,8 +919,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool dtype = tensor.dtype mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device) - if (std == 0).any(): - raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") + + def stdzero(): + raise ValueError( + f"std evaluated to zero after conversion to {dtype}, leading to division by zero." + ) + + torch.cond((std == 0).any(), stdzero, lambda: None) if mean.ndim == 1: mean = mean.view(-1, 1, 1) if std.ndim == 1: