|
4 | 4 | from torchmetrics.functional import confusion_matrix
|
5 | 5 | import warnings
|
6 | 6 |
|
| 7 | + |
7 | 8 | class BalancedAccuracy(Metric):
|
8 |
| - def __init__( |
9 |
| - self, |
10 |
| - num_classes: int, |
11 |
| - task: str, |
12 |
| - adjusted: |
13 |
| - bool = False |
14 |
| - ): |
| 9 | + def __init__(self, num_classes: int, task: str, adjusted: bool = False): |
15 | 10 | """
|
16 |
| - Compute the balanced accuracy. |
| 11 | + Compute the balanced accuracy. |
| 12 | +
|
| 13 | + The balanced accuracy in binary, multiclass, and multilabel classification problems |
| 14 | + deals with imbalanced datasets. It is defined as the average of recall obtained on each class. |
| 15 | +
|
| 16 | + Parameters |
| 17 | + ---------- |
| 18 | + num_classes : int |
| 19 | + The number of classes in the target data. |
17 | 20 |
|
18 |
| - The balanced accuracy in binary, multiclass, and multilabel classification problems |
19 |
| - deals with imbalanced datasets. It is defined as the average of recall obtained on each class. |
| 21 | + task : str |
| 22 | + The type of classification task, should be one of 'binary' or 'multiclass' |
20 | 23 |
|
21 |
| - Parameters |
22 |
| - ---------- |
23 |
| - num_classes : int |
24 |
| - The number of classes in the target data. |
25 |
| - |
26 |
| - task : str |
27 |
| - The type of classification task, should be one of 'binary' or 'multiclass' |
28 |
| - |
29 |
| - adjusted : bool, optional (default=False) |
30 |
| - When true, the result is adjusted for chance, so that random performance would score 0, |
31 |
| - while keeping perfect performance at a score of 1. |
| 24 | + adjusted : bool, optional (default=False) |
| 25 | + When true, the result is adjusted for chance, so that random performance would score 0, |
| 26 | + while keeping perfect performance at a score of 1. |
32 | 27 |
|
33 |
| - Attributes |
34 |
| - ---------- |
35 |
| - confmat : torch.Tensor |
36 |
| - Confusion matrix to keep track of true positives, false positives, true negatives, and false negatives. |
| 28 | + Attributes |
| 29 | + ---------- |
| 30 | + confmat : torch.Tensor |
| 31 | + Confusion matrix to keep track of true positives, false positives, true negatives, and false negatives. |
37 | 32 |
|
38 |
| - Examples |
39 |
| - -------- |
40 |
| - >>> y_true = torch.tensor([0, 1, 0, 0, 1, 0]) |
41 |
| - >>> y_pred = torch.tensor([0, 1, 0, 0, 0, 1]) |
42 |
| - >>> metric = BalancedAccuracy(num_classes=2, task='binary') |
43 |
| - >>> metric(y_pred, y_true) |
44 |
| - 0.625 |
| 33 | + Examples |
| 34 | + -------- |
| 35 | + >>> y_true = torch.tensor([0, 1, 0, 0, 1, 0]) |
| 36 | + >>> y_pred = torch.tensor([0, 1, 0, 0, 0, 1]) |
| 37 | + >>> metric = BalancedAccuracy(num_classes=2, task='binary') |
| 38 | + >>> metric(y_pred, y_true) |
| 39 | + 0.625 |
45 | 40 | """
|
46 | 41 | super().__init__()
|
47 | 42 | self.num_classes = num_classes
|
48 | 43 | self.adjusted = adjusted
|
49 | 44 | self.task = task
|
50 |
| - self.add_state("confmat", default=torch.zeros((num_classes, num_classes)), dist_reduce_fx="sum") |
| 45 | + self.add_state( |
| 46 | + "confmat", |
| 47 | + default=torch.zeros((num_classes, num_classes)), |
| 48 | + dist_reduce_fx="sum", |
| 49 | + ) |
51 | 50 |
|
52 | 51 | def update(self, preds: torch.Tensor, target: torch.Tensor):
|
53 |
| - self.confmat += confusion_matrix(preds, target, num_classes=self.num_classes, task=self.task) |
| 52 | + self.confmat += confusion_matrix( |
| 53 | + preds, target, num_classes=self.num_classes, task=self.task |
| 54 | + ) |
54 | 55 |
|
55 | 56 | def compute(self):
|
56 | 57 | with torch.no_grad():
|
|
0 commit comments