-
Notifications
You must be signed in to change notification settings - Fork 466
BinaryAUROC
support for Masked Labels
#3268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
VijayVignesh1
wants to merge
9
commits into
Lightning-AI:master
Choose a base branch
from
VijayVignesh1:feature/maskedbinaryauroc
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4a5619b
Initial commit of masked binary auroc
VijayVignesh1 129e656
Bebugging testcases and linting
VijayVignesh1 7e8d12f
Adding mask type and shape check along with testcase
VijayVignesh1 9628383
Merge branch 'master' into feature/maskedbinaryauroc
VijayVignesh1 45670a2
Skipping matplotlib test and adding masked binary to wrapper class
VijayVignesh1 8f311a7
Update src/torchmetrics/classification/auroc.py
VijayVignesh1 303c68e
Pushing _masked_binary_cases down in _inputs.py
VijayVignesh1 7f04bc5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0a92ce4
Making mask parameter mandatory and removing the checks
VijayVignesh1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
from collections.abc import Sequence | ||
from typing import Any, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
from typing_extensions import Literal | ||
|
||
|
@@ -38,7 +39,7 @@ | |
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE | ||
|
||
if not _MATPLOTLIB_AVAILABLE: | ||
__doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"] | ||
__doctest_skip__ = ["BinaryAUROC.plot", "MaskedBinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"] | ||
|
||
|
||
class BinaryAUROC(BinaryPrecisionRecallCurve): | ||
|
@@ -167,6 +168,124 @@ def plot( # type: ignore[override] | |
return self._plot(val, ax) | ||
|
||
|
||
class MaskedBinaryAUROC(BinaryAUROC): | ||
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks with masking. | ||
|
||
The Masked AUROC score summarizes the ROC curve into an single number that describes the performance of a model for | ||
multiple thresholds at the same time with an output mask. | ||
Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing. | ||
|
||
As input to ``forward`` and ``update`` the metric accepts the following input: | ||
|
||
- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for | ||
each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply | ||
sigmoid per element. | ||
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and | ||
therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the | ||
positive class. | ||
- ``mask`` (:class:`~torch.Tensor`): A boolean tensor of shape ``(N, ...)`` indicating which elements to include | ||
in the metric computation. Elements with a value of `True` will be included, while elements with a value of | ||
`False` will be ignored. | ||
|
||
As output to ``forward`` and ``compute`` the metric returns the following output: | ||
|
||
- ``b_auroc`` (:class:`~torch.Tensor`): A single scalar with the auroc score of unmasked elements. | ||
|
||
Additional dimension ``...`` will be flattened into the batch dimension. | ||
|
||
The implementation both supports calculating the metric in a non-binned but accurate version and a | ||
binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will | ||
activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the | ||
`thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of | ||
size :math:`\mathcal{O}(n_{thresholds})` (constant memory). | ||
|
||
Args: | ||
max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``. | ||
thresholds: | ||
Can be one of: | ||
|
||
- If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from | ||
all the data. Most accurate but also most memory consuming approach. | ||
- If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from | ||
0 to 1 as bins for the calculation. | ||
- If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation | ||
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as | ||
bins for the calculation. | ||
|
||
validate_args: bool indicating if input arguments and tensors should be validated for correctness. | ||
Set to ``False`` for faster computations. | ||
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. | ||
|
||
Example: | ||
>>> import torch | ||
>>> from torch import tensor | ||
>>> from torchmetrics.classification import MaskedBinaryAUROC | ||
>>> preds = tensor([0, 0.5, 0.7, 0.8]) | ||
>>> target = tensor([0, 1, 1, 0]) | ||
>>> mask = tensor([1, 1, 0, 1], dtype=torch.bool) | ||
>>> metric = MaskedBinaryAUROC(thresholds=None) | ||
>>> metric(preds, target, mask) | ||
tensor(0.5000) | ||
>>> b_auroc = MaskedBinaryAUROC(thresholds=5) | ||
>>> b_auroc(preds, target, mask) | ||
tensor(0.5000) | ||
|
||
""" | ||
|
||
def update(self, preds: Tensor, target: Tensor, mask: Tensor = None) -> None: | ||
VijayVignesh1 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
"""Update the state with the new data.""" | ||
if mask is not None: | ||
|
||
if mask.dtype != torch.bool: | ||
raise ValueError(f"Mask must be boolean, got {mask.dtype}") | ||
if mask.shape != preds.shape: | ||
raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}") | ||
preds = preds[mask] | ||
target = target[mask] | ||
super().update(preds, target) # call the original BinaryAUROC update | ||
|
||
def plot( # type: ignore[override] | ||
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None | ||
) -> _PLOT_OUT_TYPE: | ||
"""Plot a single or multiple values from the metric. | ||
|
||
Args: | ||
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. | ||
If no value is provided, will automatically call `metric.compute` and plot that result. | ||
ax: An matplotlib axis object. If provided will add plot to that axis | ||
|
||
Returns: | ||
Figure and Axes object | ||
|
||
Raises: | ||
ModuleNotFoundError: | ||
If `matplotlib` is not installed | ||
|
||
.. plot:: | ||
:scale: 75 | ||
|
||
>>> # Example plotting a single | ||
>>> import torch | ||
>>> from torchmetrics.classification import MaskedBinaryAUROC | ||
>>> metric = MaskedBinaryAUROC() | ||
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5) | ||
>>> fig_, ax_ = metric.plot() | ||
|
||
.. plot:: | ||
:scale: 75 | ||
|
||
>>> # Example plotting multiple values | ||
>>> import torch | ||
>>> from torchmetrics.classification import MaskedBinaryAUROC | ||
>>> metric = MaskedBinaryAUROC() | ||
>>> values = [ ] | ||
>>> for _ in range(10): | ||
... values.append(metric(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5)) | ||
>>> fig_, ax_ = metric.plot(values) | ||
|
||
""" | ||
return self._plot(val, ax) | ||
|
||
|
||
class MulticlassAUROC(MulticlassPrecisionRecallCurve): | ||
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks. | ||
|
||
|
@@ -482,10 +601,11 @@ class AUROC(_ClassificationTaskWrapper): | |
corresponds to random guessing. | ||
|
||
This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the | ||
``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of | ||
:class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MulticlassAUROC` and | ||
:class:`~torchmetrics.classification.MultilabelAUROC` for the specific details of each argument influence and | ||
examples. | ||
``task`` argument to either ``'binary'``, ``'maskedbinary'``, ``'multiclass'`` or ``'multilabel'``. | ||
See the documentation of | ||
:class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MaskedBinaryAUROC`, | ||
:class:`~torchmetrics.classification.MulticlassAUROC` and :class:`~torchmetrics.classification.MultilabelAUROC` | ||
for the specific details of each argument influence and examples. | ||
|
||
Legacy Example: | ||
>>> from torch import tensor | ||
|
@@ -509,7 +629,7 @@ class AUROC(_ClassificationTaskWrapper): | |
|
||
def __new__( # type: ignore[misc] | ||
cls: type["AUROC"], | ||
task: Literal["binary", "multiclass", "multilabel"], | ||
task: Literal["binary", "maskedbinary", "multiclass", "multilabel"], | ||
thresholds: Optional[Union[int, list[float], Tensor]] = None, | ||
num_classes: Optional[int] = None, | ||
num_labels: Optional[int] = None, | ||
|
@@ -524,6 +644,8 @@ def __new__( # type: ignore[misc] | |
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) | ||
if task == ClassificationTask.BINARY: | ||
return BinaryAUROC(max_fpr, **kwargs) | ||
if task == ClassificationTask.MASKEDBINARY: | ||
return MaskedBinaryAUROC(max_fpr, **kwargs) | ||
if task == ClassificationTask.MULTICLASS: | ||
if not isinstance(num_classes, int): | ||
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need a functional version of this as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need the functional version since I'm basically reusing the binaryauroc function. The only change will be in the way we update the preds and target. The rest remains the same.