
Commit 53f85d5

update some notes about auc losses
1 parent 2fca2e6 commit 53f85d5

1 file changed (+15 -12)

libauc/losses/auc.py (+15 -12)
@@ -53,8 +53,9 @@ class AUCMLoss(torch.nn.Module):
 
     args:
         margin (float): margin for squared-hinge surrogate loss (default: ``1.0``).
-        imratio (float, optional): the ratio of the number of positive samples to the number of total samples in the training dataset.
-            If this value is not given, the mini-batch statistics will be used instead.
+        imratio (float, optional): the ratio of the number of positive samples to the number of total samples in the training dataset, i.e., :math:`p` in the above formulation.
+            If this value is not given, it will be automatically calculated with mini-batch samples.
+            This value is ignored when ``version`` is set to ``'v2'``.
         version (str, optional): whether to include prior :math:`p` in the objective function (default: ``'v1'``).
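For intuition, ``imratio`` is simply the fraction of positive samples in the training set, i.e., the prior :math:`p`. A minimal sketch of computing it from binary labels (this mirrors the docstring's definition, not libauc internals):

    import torch
    from libauc.losses import AUCMLoss

    targets = torch.tensor([1, 0, 0, 0, 1, 0, 0, 0])   # toy binary labels
    imratio = (targets == 1).float().mean().item()     # 2 positives / 8 total = 0.25
    loss_fn = AUCMLoss(imratio=imratio)                # pass the dataset-level prior p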
@@ -65,13 +66,12 @@ class AUCMLoss(torch.nn.Module):
         >>> loss = loss_fn(preds, target)
         >>> loss.backward()
 
-    .. note::
-        To use ``v2`` of AUCMLoss, plesae set ``version='v2'``. Otherwise, the default version is ``v1``. The ``v2`` version requires the use of :obj:`~libauc.sampler.DualSampler`.
 
     .. note::
         Practial Tips:
 
-        - ``epoch_decay`` is a regularization parameter similar to `weight_decay` that can be tuned in the same range.
+        - It is recommended to use ``v2`` of AUCMLoss by setting ``version='v2'`` to get better performance. The ``v2`` version requires the use of :obj:`~libauc.sampler.DualSampler`.
+        - ``epoch_decay`` is a regularization parameter similar to `weight_decay` that can be tuned in the same range.
         - For complex tasks, it is recommended to use regular loss to pretrain the model, and then switch to AUCMLoss for finetuning with a smaller learning rate.
 
     Reference:
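The updated note recommends ``version='v2'`` together with :obj:`~libauc.sampler.DualSampler`. A minimal training-loop sketch, assuming a recent libauc-style API; ``train_dataset``, ``model``, and ``optimizer`` are placeholders, and DualSampler's exact signature may differ between releases:

    import torch
    from torch.utils.data import DataLoader
    from libauc.losses import AUCMLoss
    from libauc.sampler import DualSampler

    # DualSampler keeps a fixed share of positives in every mini-batch, which is
    # why v2 can ignore the dataset-level prior p (imratio).
    sampler = DualSampler(train_dataset, batch_size=64, sampling_rate=0.5)  # assumed signature
    loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)

    loss_fn = AUCMLoss(version='v2')
    for data, targets in loader:
        preds = torch.sigmoid(model(data))   # AUCM losses expect scores in [0, 1]
        loss = loss_fn(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()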
@@ -140,7 +140,9 @@ class CompositionalAUCLoss(torch.nn.Module):
 
     args:
         margin (float): margin for squared-hinge surrogate loss (default: ``1.0``).
-        imratio (float, optional): the ratio of the number of positive samples to the number of total samples in the training dataset. If this value is not given, the mini-batch statistics will be used instead.
+        imratio (float, optional): the ratio of the number of positive samples to the number of total samples in the training dataset, i.e., :math:`p` in the above formulation.
+            If this value is not given, it will be automatically calculated with mini-batch samples.
+            This value is ignored when ``version`` is set to ``'v2'``.
         k (int, optional): number of steps for inner updates. For example, when k is set to 2, the optimizer will alternately execute two steps optimizing :obj:`~libauc.losses.losses.CrossEntropyLoss` followed by a single step optimizing :obj:`~libauc.losses.auc.AUCMLoss` during training (default: ``1``).
         version (str, optional): whether to include prior :math:`p` in the objective function (default: ``'v1'``).
@@ -152,7 +154,8 @@ class CompositionalAUCLoss(torch.nn.Module):
         >>> loss.backward()
 
     .. note::
-        As CompositionalAUCLoss is built on AUCMLoss, there are also two versions of CompositionalAUCLoss. To use ``v2`` version, plesae set ``version='v2'``. Otherwise, the default version is ``v1``.
+        As CompositionalAUCLoss is built on AUCMLoss, there are also two versions of CompositionalAUCLoss.
+        It is recommended to use ``v2`` version by setting ``version='v2'`` to get better performance.
 
     .. note::
@@ -173,8 +176,8 @@ def __init__(self,
                  version='v1',
                  imratio=None,
                  backend='ce',
-                 l_avg=None, # todo: loss placeholder
-                 l_imb=None, # todo: loss placeholder
+                 # l_avg=None, # todo: loss placeholder
+                 # l_imb=None, # todo: loss placeholder
                  device=None):
         super(CompositionalAUCLoss, self).__init__()
         if not device:
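For reference, the alternating schedule controlled by ``k`` in use; a sketch with placeholder ``model``, ``data``, and ``targets`` (in practice the loss is paired with a compositional optimizer such as libauc's PDSCA, omitted here):

    import torch
    from libauc.losses import CompositionalAUCLoss

    # k=2: each cycle runs two inner CrossEntropy steps followed by one AUCM step;
    # forward() tracks internally which objective the current call contributes to.
    loss_fn = CompositionalAUCLoss(k=2, version='v2', backend='ce')
    preds = torch.sigmoid(model(data))
    loss = loss_fn(preds, targets)
    loss.backward()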
@@ -232,7 +235,7 @@ def forward(self, y_pred, y_true, k=None, auto=True, **kwargs):
 
 class AveragePrecisionLoss(torch.nn.Module):
     r"""
-        Average Precision loss with squared-hinge surrogate loss for optimizing AUPRC. The objective is defined as
+        Average Precision loss with chosen surrogate loss for optimizing AUPRC. The objective is defined as
 
         .. math::
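The surrogate losses referenced here follow standard definitions; as a generic sketch (not the libauc implementation), a squared-hinge surrogate of a score difference ``t`` can be written as:

    import torch

    def squared_hinge(margin: float, t: torch.Tensor) -> torch.Tensor:
        # Standard squared-hinge surrogate: max(margin - t, 0)^2
        return torch.clamp(margin - t, min=0) ** 2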
@@ -794,14 +797,14 @@ class meanAveragePrecisionLoss(torch.nn.Module):
         num_labels (int): number of unique labels(tasks) in the dataset.
         margin (float, optional): margin for the squared-hinge surrogate loss (default: ``1.0``).
         gamma (float, optional): parameter for the moving average estimator (default: ``0.9``).
-        top_k (int, optional): If given, only top k items will be considered for optimizing mAP@k.
+        top_k (int, optional): mAP@k optimization is activated if top_k > 0; top_k=-1 represents mAP (default: ``-1``).
         surr_loss (str, optional): type of surrogate loss to use. Choices are 'squared_hinge', 'squared',
             'logistic', 'barrier_hinge' (default: ``'squared_hinge'``).
 
     This class is also aliased as :obj:`~libauc.losses.auc.mAPLoss`.
 
     Example:
-        >>> loss_fn = meanAveragePrecisionLoss(data_len=data_length, margin=1.0, num_labels=10, gamma=0.9)
+        >>> loss_fn = meanAveragePrecisionLoss(data_len=data_length, margin=1.0, num_labels=10, gamma=0.9, top_k=-1)
         >>> y_pred = torch.randn((32,10), requires_grad=True)
         >>> y_true = torch.empty((32,10), dtype=torch.long).random_(2)
         >>> index = torch.randint(32, (32,), requires_grad=False)
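Completing the example above, assuming the forward call takes ``(y_pred, y_true, index)`` like the other losses in this module (the exact signature may differ):

    >>> loss = loss_fn(y_pred, y_true, index)   # assumed forward signature
    >>> loss.backward()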
