31 changes: 25 additions & 6 deletions src/torchmetrics/utilities/compute.py
@@ -234,13 +234,32 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Optional[Literal[
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
-            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
+            if normalization == "sigmoid":
+                # Apply numerically stable sigmoid by subtracting max to prevent overflow
+                # For large positive logits (>16.7 for float32, >36.7 for float64), sigmoid(x) overflows to 1.0
+                # Only apply stabilization when min value is also large (indicating all values will overflow)
+                # This avoids the issue where subtracting max creates artificial ties for widely spread values
Comment on lines +240 to +241
Copilot AI Dec 23, 2025

The code comment says "Only apply stabilization when min value is also large (indicating all values will overflow)" but this should also note the limitation: mixed-range logits with some large values won't be stabilized.

Consider expanding the comment to be clearer about the trade-off:
"Only apply stabilization when min value is also large (indicating all values will overflow).
Note: This means mixed-range logits (e.g., [0, 50, 100]) won't be stabilized, and large values
may still overflow to 1.0. However, stabilizing mixed ranges would create artificial ties for
smaller values, which is worse for ranking-based metrics."
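To make the trade-off concrete (float32; a wider spread than [0, 50, 100] is used here so both failure modes produce exact ties — values are illustrative only):

    import torch

    t = torch.tensor([0.0, 200.0, 400.0])
    print(t.sigmoid())              # tensor([0.5000, 1.0000, 1.0000]) -> large logits tie at 1.0
    print((t - t.max()).sigmoid())  # tensor([0.0000, 0.0000, 0.5000]) -> small logits now tie at 0.0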

+                min_val = tensor.min()
+                max_val = tensor.max()
+                if min_val > 15:  # All values are large enough to potentially overflow
Comment on lines +243 to +244
Copilot AI Dec 23, 2025

The CPU path always computes min_val and max_val (lines 242-243) even when they might not be needed. If min_val <= 15, then max_val is computed but never used (since we take the else branch on line 247).

For better performance, consider restructuring to only compute max_val when it's actually needed (i.e., when min_val > 15). This would look like:

min_val = tensor.min()
if min_val > 15:
    max_val = tensor.max()
    tensor = (tensor - max_val).sigmoid()
else:
    tensor = tensor.sigmoid()
Suggested change:
-                max_val = tensor.max()
-                if min_val > 15:  # All values are large enough to potentially overflow
+                if min_val > 15:  # All values are large enough to potentially overflow
+                    max_val = tensor.max()

+                    tensor = (tensor - max_val).sigmoid()
+                else:
+                    tensor = tensor.sigmoid()
+            else:
+                tensor = torch.softmax(tensor, dim=1)
         return tensor

     # decrease device-host sync on device .
     condition = ((tensor < 0) | (tensor > 1)).any()
-    return torch.where(
-        condition,
-        torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
-        tensor,
-    )
+    if normalization == "sigmoid":
+        # Apply numerically stable sigmoid by subtracting max to prevent overflow
+        # Only stabilize when all values are large to avoid creating artificial ties
+        min_val = tensor.min()
+        max_val = tensor.max()
+        # Use stable sigmoid only when minimum value is also large (all values will overflow)
+        needs_stabilization = min_val > 15
Comment on lines +257 to +260
Copilot AI Dec 23, 2025

The device path undermines its own optimization goal. Computing min_val and max_val stays on-device, but the Python-level branch `if needs_stabilization:` a few lines later calls Tensor.__bool__ on the result of `min_val > 15`, which forces a device-to-host transfer, defeating the purpose of the "decrease device-host sync on device" comment on line 252.

The original code avoided any such sync: condition on line 253 is a scalar boolean tensor, but it is consumed on-device by torch.where rather than by a Python if. The new code introduces a synchronization point, making the device path potentially much slower than intended.

Consider either:

  1. Accepting the performance cost for correctness, or
  2. Checking the stabilization condition without device-to-host sync, e.g. by folding it into torch.where (see the sketch below), or
  3. Using a different stabilization approach that doesn't require checking min/max values
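A minimal sketch of option 2 (hypothetical, not part of this PR; assumes the diff's tensor, condition, and threshold of 15):

    min_val, max_val = tensor.min(), tensor.max()
    # `min_val > 15` stays a 0-dim bool tensor on device; torch.where consumes it
    # without ever transferring a value to the host.
    shift = torch.where(min_val > 15, max_val, torch.zeros_like(max_val))
    return torch.where(condition, (tensor - shift).sigmoid(), tensor)

Like the original device path, this computes sigmoid over the whole tensor unconditionally and lets torch.where select, trading a little extra compute for zero synchronization points.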

Comment on lines +244 to +260
Copilot AI Dec 23, 2025

The threshold value of 15 for detecting when stabilization is needed appears arbitrary and may not align with the documented overflow thresholds. According to the PR description and comments in the code:

  • float32 sigmoid overflows at x > 16.7
  • float64 sigmoid overflows at x > 36.7

Using min_val > 15 means stabilization kicks in when the minimum value is just above 15, which is reasonable for float32 but very conservative for float64.

However, there's a more fundamental issue: this logic doesn't account for the tensor's dtype. A float64 tensor with min=20 doesn't need stabilization (since 20 < 36.7), but the current code would apply it anyway. Consider checking tensor.dtype and using dtype-specific thresholds, or use a single conservative threshold with clear documentation explaining why.
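A dtype-aware check might look like the following sketch (illustrative only; the cutoffs fall out of float precision, since sigmoid(x) rounds to exactly 1.0 once exp(-x) drops below eps/2):

    import math

    import torch

    def sigmoid_saturation_threshold(dtype: torch.dtype) -> float:
        # sigmoid(x) = 1 / (1 + exp(-x)) rounds to 1.0 once exp(-x) < eps / 2,
        # i.e. x > -log(eps / 2): ~16.6 for float32, ~36.7 for float64.
        return -math.log(torch.finfo(dtype).eps / 2)

    # e.g.: needs_stabilization = min_val > sigmoid_saturation_threshold(tensor.dtype)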

+        if needs_stabilization:
+            tensor_stable = tensor - max_val
Comment on lines +242 to +262
Copilot AI Dec 23, 2025

The PR description states "Preserves ranking: Subtracts max value before sigmoid, maintaining relative ordering since sigmoid is monotonic." However, the code computes both min_val and max_val yet only uses max_val for the actual stabilization.

The min_val computation on lines 242 and 257 is only used for the threshold check (line 244/260), not for the actual stabilization. While this approach is correct (subtracting max preserves ranking because sigmoid is monotonic), the code could be slightly clearer by adding a comment explaining that we check min but subtract max.

For clarity, consider adding an inline comment on lines 245 and 262 explaining: "Subtract max (not min) to preserve ranking while preventing overflow, since sigmoid is monotonic."
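A quick illustration of why checking min but subtracting max is safe — sigmoid is strictly monotonic, so a constant shift cannot reorder anything (values arbitrary):

    import torch

    x = torch.tensor([97.6037, 98.0950, 99.9623])
    # Stabilized outputs sort in exactly the same order as the raw logits.
    assert torch.equal(x.argsort(), (x - x.max()).sigmoid().argsort())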

+            return torch.where(condition, tensor_stable.sigmoid(), tensor)
Comment on lines +245 to +263
Copilot AI Dec 23, 2025

The CPU and device code paths are structured differently, which could lead to subtle divergence.

In the CPU path (line 245), the local name is rebound unconditionally inside the branch: tensor = (tensor - max_val).sigmoid().

In the device path (lines 262-263), a shifted tensor is created (tensor_stable = tensor - max_val) and then selected per-element via torch.where.

Note that neither path mutates the caller's tensor — assignment to tensor merely rebinds a local name — but the selection logic differs: the CPU path decides once for the whole tensor with a Python if, while the device path folds the decision into torch.where. Keeping the two paths structurally aligned, e.g. via a shared helper (sketched below), would make it harder for them to drift apart.
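One possible shape for that shared helper (hypothetical sketch; as written, the bool() check would still force a sync on GPU, so it mirrors the CPU path's semantics):

    def _stable_sigmoid(t: torch.Tensor, threshold: float = 15.0) -> torch.Tensor:
        # Shift by the max only when every value is large enough to saturate sigmoid;
        # sigmoid is monotonic, so the shift preserves the ranking.
        if bool(t.min() > threshold):
            return (t - t.max()).sigmoid()
        return t.sigmoid()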

+        return torch.where(condition, tensor.sigmoid(), tensor)
+    return torch.where(condition, torch.softmax(tensor, dim=1), tensor)
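For reference, a quick interactive check of the new behavior (illustrative; logit values borrowed from the regression test below):

    import torch
    from torchmetrics.utilities.compute import normalize_logits_if_needed

    logits = torch.tensor([97.6037, 98.0950, 99.9623])
    probs = normalize_logits_if_needed(logits, "sigmoid")
    # With the stable path, outputs stay strictly ordered instead of all saturating to 1.0.
    assert torch.all(probs[:-1] < probs[1:])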
50 changes: 50 additions & 0 deletions tests/unittests/classification/test_auroc.py
@@ -140,6 +140,56 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn):
         assert torch.allclose(ap1, ap2)


+def test_binary_auroc_large_logits():
+    """Test that large logits don't cause numerical overflow in sigmoid.
+
+    Regression test for issue where very large logits (>16.7 for float32) cause naive sigmoid to
+    overflow to 1.0 for all values, losing ranking information needed for AUROC.
+
+    """
+    # Test case from the issue: all logits in range 97-100
+    preds = torch.tensor([
+        98.0950, 98.4612, 98.1145, 98.1506, 97.6037,
+        98.9425, 99.2644, 99.5014, 99.7280, 99.6595,
+        99.6931, 99.4667, 99.9623, 99.8949, 99.8768,
+    ])
+    target = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+
+    result = binary_auroc(preds, target, thresholds=None)
+
+    # Expected AUROC is 0.9286 (as computed by sklearn): the single negative (98.0950)
+    # out-ranks exactly one positive (97.6037), so 13 of the 14 positives score higher,
+    # and 13/14 = 0.9286. Ranking must be preserved through normalization for this to hold.
+    expected_sklearn = sk_roc_auc_score(target.numpy(), preds.numpy())
Comment on lines +143 to +175
Copilot AI Dec 23, 2025

The test doesn't cover the edge case where values are right at the overflow boundary. Consider adding test cases with:

  • Values just below the threshold (e.g., min=14.5) to verify normal sigmoid is used
  • Values at exactly the threshold (e.g., min=15.0) to verify behavior at the boundary
  • Mixed positive and negative large logits (e.g., [-100, -99, ..., 99, 100]) to ensure stabilization doesn't break when min < 15 but max is very large

These edge cases would help ensure the threshold logic works correctly at boundary conditions; a rough sketch follows.
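A sketch of such tests (the specific values, and the expectation that they pass, are illustrative assumptions rather than part of this PR):

    def test_binary_auroc_logits_near_threshold():
        target = torch.tensor([0, 0, 1, 1])
        # Just below, exactly at, and just above the min_val > 15 cutoff.
        for offset in (14.5, 15.0, 15.5):
            preds = torch.tensor([0.0, 1.0, 2.0, 3.0]) + offset
            expected = sk_roc_auc_score(target.numpy(), preds.numpy())
            assert torch.allclose(binary_auroc(preds, target, thresholds=None), torch.tensor(expected), atol=1e-4)

        # Mixed signs: min < 15 but max is very large, so stabilization must not kick in.
        preds_mixed = torch.tensor([-100.0, -50.0, 50.0, 100.0])
        expected_mixed = sk_roc_auc_score(target.numpy(), preds_mixed.numpy())
        assert torch.allclose(binary_auroc(preds_mixed, target, thresholds=None), torch.tensor(expected_mixed), atol=1e-4)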

+    assert torch.allclose(result, torch.tensor(expected_sklearn), atol=1e-4)
Comment on lines +143 to +176
Copilot AI Dec 23, 2025

The test function is defined at the module level but should follow the pattern of the test class above it. All other binary AUROC tests in this file are methods of the TestBinaryAUROC class (lines 47-141).

Defining this test as a standalone function means it won't be grouped with related tests and may not benefit from shared test infrastructure. Consider either:

  1. Adding it as a method to the TestBinaryAUROC class (sketched below), or
  2. If it must be standalone, adding a comment explaining why it's separate from the class.
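For example, option 1 would look roughly like this (sketch; assumes the file's existing TestBinaryAUROC class and imports):

    class TestBinaryAUROC(MetricTester):
        ...  # existing parametrized tests

        def test_binary_auroc_large_logits(self):
            preds = torch.tensor([200.0, 201.0, 202.0, 203.0])
            target = torch.tensor([0, 0, 1, 1])
            expected = sk_roc_auc_score(target.numpy(), preds.numpy())
            assert torch.allclose(binary_auroc(preds, target, thresholds=None), torch.tensor(expected), atol=1e-4)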


+    # Test with even larger logits
+    preds_huge = torch.tensor([200.0, 201.0, 202.0, 203.0])
+    target_huge = torch.tensor([0, 0, 1, 1])
+    result_huge = binary_auroc(preds_huge, target_huge, thresholds=None)
+    expected_huge = sk_roc_auc_score(target_huge.numpy(), preds_huge.numpy())
+    assert torch.allclose(result_huge, torch.tensor(expected_huge), atol=1e-4)
+
+    # Test with mixed large and normal logits
+    preds_mixed = torch.tensor([-5.0, 0.0, 5.0, 50.0, 100.0])
+    target_mixed = torch.tensor([0, 0, 1, 1, 1])
+    result_mixed = binary_auroc(preds_mixed, target_mixed, thresholds=None)
+    expected_mixed = sk_roc_auc_score(target_mixed.numpy(), preds_mixed.numpy())
+    assert torch.allclose(result_mixed, torch.tensor(expected_mixed), atol=1e-4)


 def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None):
     preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1]))
     target = target.numpy().flatten()