Fix sigmoid overflow for large logits causing incorrect AUROC results #3283
base: master
Changes from all commits: 93eeddb, bd8a86a, 02195a6, 41f6d23, 944f63c, d3cb84e
@@ -234,13 +234,32 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Optional[Literal["
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
-            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
+            if normalization == "sigmoid":
+                # Apply numerically stable sigmoid by subtracting max to prevent overflow
+                # For large positive logits (>16.7 for float32, >36.7 for float64), sigmoid(x) overflows to 1.0
+                # Only apply stabilization when min value is also large (indicating all values will overflow)
+                # This avoids the issue where subtracting max creates artificial ties for widely spread values
+                min_val = tensor.min()
+                max_val = tensor.max()
+                if min_val > 15:  # All values are large enough to potentially overflow
Comment on lines +243 to +244

Suggested change:
-                max_val = tensor.max()
-                if min_val > 15:  # All values are large enough to potentially overflow
+                if min_val > 15:  # All values are large enough to potentially overflow
+                    max_val = tensor.max()
Copilot AI commented on Dec 23, 2025:
The device path undermines the optimization it sits inside. The function computes min_val and max_val on lines 257-258, which triggers device-to-host synchronization (defeating the purpose of the "decrease device-host sync on device" comment on line 252).
The original code on line 253 computes condition, which is already a scalar boolean costing one device-to-host sync. The new code adds two more synchronizations (min and max), making the device path potentially much slower than intended.
Because tensor.min() and tensor.max() return scalar tensors, comparing them with > 15 on the host requires transferring the values off the device, which contradicts the optimization goal stated in the comment.
Consider either:
- Accepting the performance cost for correctness, or
- Finding a way to check the stabilization condition without device-to-host sync, or
- Using a different stabilization approach that doesn't require checking min/max values
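For the second option, one possible shape is to keep the comparison on the device as a 0-dim tensor and let torch.where do the selection. This is only a sketch, not the PR's code, and the helper name is made up:

```python
import torch
from torch import Tensor


def _sigmoid_without_host_sync(tensor: Tensor) -> Tensor:
    # Hypothetical sketch: the min/max reductions and the >15 comparison all stay on the
    # device as 0-dim tensors, so nothing here forces a device-to-host transfer.
    needs_shift = tensor.min() > 15  # 0-dim bool tensor, never read on the host
    shifted = (tensor - tensor.max()).sigmoid()
    plain = tensor.sigmoid()
    # torch.where broadcasts the 0-dim condition; the price is evaluating both branches.
    return torch.where(needs_shift, shifted, plain)
```

The trade-off is extra compute (both branches are evaluated) in exchange for zero additional synchronization points.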
Copilot AI commented on Dec 23, 2025:
The threshold value of 15 for detecting when stabilization is needed appears arbitrary and may not align with the documented overflow thresholds. According to the PR description and comments in the code:
- float32 sigmoid overflows at x > 16.7
- float64 sigmoid overflows at x > 36.7
Using min_val > 15 means stabilization kicks in when the minimum value is just above 15, which is reasonable for float32 but very conservative for float64.
However, there's a more fundamental issue: this logic doesn't account for the tensor's dtype. A float64 tensor with min=20 doesn't need stabilization (since 20 < 36.7), but the current code would apply it anyway. Consider checking tensor.dtype and using dtype-specific thresholds, or use a single conservative threshold with clear documentation explaining why.
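A dtype-aware check could look roughly like the sketch below; the constants are the ones quoted above (16.7 for float32, 36.7 for float64), while the helper name and the fallback behavior are assumptions, not anything from the PR:

```python
import torch
from torch import Tensor

# Saturation points quoted in this review: sigmoid(x) rounds to exactly 1.0 above
# roughly 16.7 for float32 and 36.7 for float64.
_SIGMOID_SATURATION = {torch.float32: 16.7, torch.float64: 36.7}


def _needs_stabilization(tensor: Tensor) -> bool:
    # Hypothetical helper: pick the threshold from the dtype, falling back to the
    # conservative float32 value for any other floating dtype.
    threshold = _SIGMOID_SATURATION.get(tensor.dtype, 16.7)
    # Note: reading the comparison as a Python bool still costs one device-to-host sync.
    return bool(tensor.min() > threshold)
```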
Copilot AI commented on Dec 23, 2025:
The comment states "Preserves ranking: Subtracts max value before sigmoid, maintaining relative ordering since sigmoid is monotonic" in the PR description. However, the code computes min_val and max_val but then only uses max_val for stabilization.
The min_val computation on lines 242 and 257 is only used for the threshold check (line 244/260), not for the actual stabilization. While this approach is correct (subtracting max preserves ranking because sigmoid is monotonic), the code could be slightly clearer by adding a comment explaining that we check min but subtract max.
For clarity, consider adding an inline comment on lines 245 and 262 explaining: "Subtract max (not min) to preserve ranking while preventing overflow, since sigmoid is monotonic."
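A standalone toy version of that branch with the suggested comment in place (variable names and values are illustrative, not the PR's):

```python
import torch

logits = torch.tensor([97.6, 98.1, 99.9])  # all large enough to saturate float32 sigmoid
min_val, max_val = logits.min(), logits.max()
if min_val > 15:
    # Subtract max (not min) to preserve ranking while preventing overflow, since sigmoid is monotonic.
    logits = (logits - max_val).sigmoid()
print(logits)  # roughly tensor([0.0911, 0.1419, 0.5000]) -> distinct, correctly ordered values
```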
Copilot AI commented on Dec 23, 2025:
The CPU and device code paths build their results differently, which invites subtle divergence.
In the CPU path (line 245), the local name is rebound directly to the stabilized result: tensor = (tensor - max_val).sigmoid().
In the device path (lines 262-263), a shifted copy is created (tensor_stable = tensor - max_val) and the result is selected via torch.where.
Neither assignment mutates the caller's tensor, but the two branches implement the same operation with different structures, so fixes applied to one can easily miss the other. For consistency and easier maintenance, both paths should construct the result the same way.
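One way to keep the two paths symmetric is a single helper that both branches call. This is a sketch under the same >15 heuristic; none of these names come from the PR:

```python
import torch
from torch import Tensor


def _stabilized_sigmoid(tensor: Tensor, threshold: float = 15.0) -> Tensor:
    """Hypothetical shared helper so the CPU and device branches build their result identically."""
    # Always returns a fresh tensor; each branch rebinds its local name to the result,
    # so neither path mutates the tensor it was given.
    needs_shift = tensor.min() > threshold  # stays a 0-dim tensor on the device
    return torch.where(needs_shift, (tensor - tensor.max()).sigmoid(), tensor.sigmoid())


# Both branches would then reduce to the same line:
#     tensor = _stabilized_sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
```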
@@ -140,6 +140,56 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn):
         assert torch.allclose(ap1, ap2)


+def test_binary_auroc_large_logits():
+    """Test that large logits don't cause numerical overflow in sigmoid.
+
+    Regression test for issue where very large logits (>16.7 for float32) cause naive sigmoid to overflow to 1.0 for
+    all values, losing ranking information needed for AUROC.
+
+    """
+    # Test case from the issue: all logits in range 97-100
+    preds = torch.tensor([
+        98.0950,
+        98.4612,
+        98.1145,
+        98.1506,
+        97.6037,
+        98.9425,
+        99.2644,
+        99.5014,
+        99.7280,
+        99.6595,
+        99.6931,
+        99.4667,
+        99.9623,
+        99.8949,
+        99.8768,
+    ])
+    target = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+
+    result = binary_auroc(preds, target, thresholds=None)
+
+    # Expected AUROC is 13/14 ~= 0.9286 (matches sklearn): the single negative (98.0950)
+    # is ranked above only one positive (97.6037), so 13 of the 14 positive-negative
+    # pairs are ordered correctly
+    expected_sklearn = sk_roc_auc_score(target.numpy(), preds.numpy())
+    assert torch.allclose(result, torch.tensor(expected_sklearn), atol=1e-4)
+
+    # Test with even larger logits
+    preds_huge = torch.tensor([200.0, 201.0, 202.0, 203.0])
+    target_huge = torch.tensor([0, 0, 1, 1])
+    result_huge = binary_auroc(preds_huge, target_huge, thresholds=None)
+    expected_huge = sk_roc_auc_score(target_huge.numpy(), preds_huge.numpy())
+    assert torch.allclose(result_huge, torch.tensor(expected_huge), atol=1e-4)
+
+    # Test with mixed large and normal logits
+    preds_mixed = torch.tensor([-5.0, 0.0, 5.0, 50.0, 100.0])
+    target_mixed = torch.tensor([0, 0, 1, 1, 1])
+    result_mixed = binary_auroc(preds_mixed, target_mixed, thresholds=None)
+    expected_mixed = sk_roc_auc_score(target_mixed.numpy(), preds_mixed.numpy())
+    assert torch.allclose(result_mixed, torch.tensor(expected_mixed), atol=1e-4)
+
+
 def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None):
     preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1]))
     target = target.numpy().flatten()
The code comment says "Only apply stabilization when min value is also large (indicating all values will overflow)" but this should also note the limitation: mixed-range logits with some large values won't be stabilized.
Consider expanding the comment to be clearer about the trade-off:
"Only apply stabilization when min value is also large (indicating all values will overflow).
Note: This means mixed-range logits (e.g., [0, 50, 100]) won't be stabilized, and large values
may still overflow to 1.0. However, stabilizing mixed ranges would create artificial ties for
smaller values, which is worse for ranking-based metrics."
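To make that trade-off concrete, here is a small float32 sketch; the numbers are illustrative and the approximate outputs in the comments are hand-checked, not taken from the PR or its tests:

```python
import torch

# Case 1: every logit is huge -> naive sigmoid saturates to 1.0 and the ranking is lost,
# while shifting by the max recovers distinct, correctly ordered scores.
huge = torch.tensor([200.0, 201.0, 202.0, 203.0])
print(huge.sigmoid())                  # tensor([1., 1., 1., 1.])
print((huge - huge.max()).sigmoid())   # ~ tensor([0.0474, 0.1192, 0.2689, 0.5000])

# Case 2: mixed range -> naive sigmoid still orders the values, but shifting by the max
# pushes the small logits into underflow, creating exactly the artificial tie the
# expanded comment warns about.
mixed = torch.tensor([0.0, 10.0, 200.0])
print(mixed.sigmoid())                 # ~ 0.5, 0.99995, 1.0 (still strictly increasing)
print((mixed - mixed.max()).sigmoid()) # ~ tensor([0.0000, 0.0000, 0.5000]) -> 0.0 and 10.0 tie
```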