
v2.9.0 #747

Open · wants to merge 17 commits into master
Binary file added docs/imgs/smooth_ap_approx_equation.png
Binary file added docs/imgs/smooth_ap_loss_equation.png
Binary file added docs/imgs/smooth_ap_sigmoid_equation.png
31 changes: 31 additions & 0 deletions docs/losses.md
@@ -1087,6 +1087,37 @@ losses.SignalToNoiseRatioContrastiveLoss(pos_margin=0, neg_margin=1, **kwargs):
* **pos_loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
* **neg_loss**: The loss per negative pair in the batch. Reduction type is ```"neg_pair"```.

## SmoothAPLoss
[Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163){target=_blank}

```python
losses.SmoothAPLoss(
temperature=0.01,
**kwargs
)
```

**Equations**:

![smooth_ap_loss_equation1](imgs/smooth_ap_sigmoid_equation.png){: style="height:100px"}
![smooth_ap_loss_equation2](imgs/smooth_ap_approx_equation.png){: style="height:100px"}
![smooth_ap_loss_equation3](imgs/smooth_ap_loss_equation.png){: style="height:100px"}
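
In text form, the three equations above correspond approximately to the following (a transcription of the paper's notation, not rendered from the images):

$$\mathcal{G}(x;\tau) = \frac{1}{1 + e^{-x/\tau}}$$

$$AP_q \approx \frac{1}{|\mathcal{S}_P|} \sum_{i \in \mathcal{S}_P} \frac{1 + \sum_{j \in \mathcal{S}_P, j \neq i} \mathcal{G}(D_{ij};\tau)}{1 + \sum_{j \in \mathcal{S}_\Omega, j \neq i} \mathcal{G}(D_{ij};\tau)}$$

$$\mathcal{L}_{AP} = \frac{1}{m} \sum_{k=1}^{m} \left(1 - AP_k\right)$$

where $D_{ij} = s_j - s_i$ is the difference between similarity scores, $\mathcal{S}_P$ is the positive set, and $\mathcal{S}_\Omega$ is the whole retrieval set.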


**Parameters**:

* **temperature**: The desired temperature for scaling the sigmoid function. This is denoted by $\tau$ in the first and second equations.


**Other info**:

* The loss requires the same number of elements for each class in the batch labels. For example, `[1, 1, 2, 2, 3, 3]` is valid, whereas `[1, 1, 1, 2, 2, 3, 3]` is invalid because class `1` has three elements while classes `2` and `3` have only two. Such batches can be formed with `samplers.MPerClassSampler` by setting the `batch_size` and `m` hyperparameters, as in the sketch below.
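
A minimal usage sketch (the `train_labels` variable and the hyperparameter values are illustrative assumptions, not part of the library's docs):

```python
from pytorch_metric_learning import losses, samplers

loss_func = losses.SmoothAPLoss(temperature=0.01)

# m=4 samples per class and a batch_size divisible by m guarantee that every
# class appearing in a batch contributes the same number of elements.
sampler = samplers.MPerClassSampler(train_labels, m=4, batch_size=32)
```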

**Default distance**:

- [```CosineSimilarity()```](distances.md#cosinesimilarity)
- This is the only compatible distance.

## SoftTripleLoss
[SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf){target=_blank}
```python
@@ -9,7 +9,7 @@ def __init__(self, **kwargs):
assert self.is_inverted

def compute_mat(self, query_emb, ref_emb):
return torch.matmul(query_emb, ref_emb.t())
return torch.matmul(query_emb, ref_emb.transpose(-1, -2))

def pairwise_distance(self, query_emb, ref_emb):
return torch.sum(query_emb * ref_emb, dim=1)
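
This one-line change (in what appears to be `DotProductSimilarity.compute_mat`, which `CosineSimilarity` inherits) lets the similarity matrix be computed for batched, 3-D inputs; `SmoothAPLoss` below calls `compute_mat` on a tensor of shape `(groups, samples_per_group, dim)`. A minimal sketch of the difference, with illustrative shapes:

```python
import torch

emb = torch.randn(4, 3, 8)  # (groups, samples_per_group, embedding_dim)

# emb.t() is only defined for tensors with <= 2 dimensions and would raise a
# RuntimeError here; transpose(-1, -2) swaps the last two dims for any rank.
sims = torch.matmul(emb, emb.transpose(-1, -2))
print(sims.shape)  # torch.Size([4, 3, 3])
```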
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -30,6 +30,7 @@
from .ranked_list_loss import RankedListLoss
from .self_supervised_loss import SelfSupervisedLoss
from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
from .smooth_ap import SmoothAPLoss
from .soft_triple_loss import SoftTripleLoss
from .sphereface_loss import SphereFaceLoss
from .subcenter_arcface_loss import SubCenterArcFaceLoss
5 changes: 5 additions & 0 deletions src/pytorch_metric_learning/losses/generic_pair_loss.py
@@ -28,6 +28,7 @@ def mat_based_loss(self, mat, indices_tuple):
pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
pos_mask[a1, p] = 1
neg_mask[a2, n] = 1
self._assert_either_pos_or_neg(pos_mask, neg_mask)
return self._compute_loss(mat, pos_mask, neg_mask)

def pair_based_loss(self, mat, indices_tuple):
@@ -38,3 +39,7 @@ def pair_based_loss(self, mat, indices_tuple):
if len(a2) > 0:
neg_pair = mat[a2, n]
return self._compute_loss(pos_pair, neg_pair, indices_tuple)

@staticmethod
def _assert_either_pos_or_neg(pos_mask, neg_mask):
assert not torch.any((pos_mask != 0) & (neg_mask != 0)), "Each pair should be either positive or negative"
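
A short sketch of the overlap this new assertion guards against (the masks here are hypothetical, not taken from the library's tests):

```python
import torch

pos_mask = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
neg_mask = torch.tensor([[0.0, 1.0], [0.0, 0.0]])  # pair (0, 1) marked both positive and negative

# This is the condition _assert_either_pos_or_neg checks; here it fails:
assert not torch.any((pos_mask != 0) & (neg_mask != 0)), "Each pair should be either positive or negative"
```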
103 changes: 103 additions & 0 deletions src/pytorch_metric_learning/losses/smooth_ap.py
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class SmoothAPLoss(BaseMetricLossFunction):
"""
Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163
"""

def __init__(self, temperature=0.01, **kwargs):
super().__init__(**kwargs)
c_f.assert_distance_type(self, CosineSimilarity)
self.temperature = temperature

def get_default_distance(self):
return CosineSimilarity()

# Implementation is based on the original repository:
# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
# The loss expects the batch labels to contain the same number of elements for each class.
# The number of classes and their order do not matter; only the per-class counts must match, e.g.
#
# The following labels are valid:
# [ A,A,A, B,B,B, C,C,C ]
# The following labels are NOT valid:
# [ B,B,B, A,A,A,A, C,C,C ]
#
c_f.labels_required(labels)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)

counts = torch.bincount(labels)
nonzero_indices = torch.nonzero(counts, as_tuple=True)[0]
nonzero_counts = counts[nonzero_indices]
if nonzero_counts.unique().size(0) != 1:
raise ValueError(
"All classes must have the same number of elements in the labels.\n"
"The given labels have the following number of elements: {}.\n"
"You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
nonzero_counts.cpu().tolist()
)
)

batch_size = embeddings.size(0)
num_classes_batch = batch_size // torch.unique(labels).size(0)

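# Mask that zeroes out self-comparisons when approximating the ranks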
mask = 1.0 - torch.eye(batch_size)
mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1)

sims = self.distance(embeddings)

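# Smoothed rank of each sample over the whole batch: pairwise differences of
# similarity scores are passed through a temperature-scaled sigmoid instead of
# the non-differentiable indicator function, then summed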
sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1)
sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1)
sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device)
sims_ranks = torch.sum(sims_sigm, dim=-1) + 1

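# The same smoothed ranking, computed within each group of consecutive
# samples (the positive sets)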
xs = embeddings.view(
num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1)
)
pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch)
pos_mask = (
pos_mask.unsqueeze(dim=0)
.unsqueeze(dim=0)
.repeat(num_classes_batch, batch_size // num_classes_batch, 1, 1)
)

# Circumvent the shape check in forward method
xs_norm = self.distance.maybe_normalize(xs, dim=-1)
sims_pos = self.distance.compute_mat(xs_norm, xs_norm)

sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat(
1, 1, batch_size // num_classes_batch, 1
)
sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2)

sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to(
sims_diff.device
)
sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1

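# Smooth-AP per sample: ranks of its positives within the positive set divided
# by their ranks over the whole batch, averaged; the loss is 1 - AP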
g = batch_size // num_classes_batch
ap = torch.zeros(batch_size).to(embeddings.device)
for i in range(num_classes_batch):
for j in range(g):
pos_rank = sims_pos_ranks[i, j]
all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g]
ap[i * g + j] = torch.sum(pos_rank / all_rank) / g

miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
loss = (1 - ap) * miner_weights

return {
"ap_loss": {
"losses": loss,
"indices": c_f.torch_arange_from_size(loss),
"reduction_type": "element",
}
}
8 changes: 6 additions & 2 deletions src/pytorch_metric_learning/losses/subcenter_arcface_loss.py
@@ -1,4 +1,5 @@
import math
from copy import deepcopy

import numpy as np
import torch
@@ -13,9 +14,12 @@ class SubCenterArcFaceLoss(ArcFaceLoss):
"""

def __init__(self, *args, margin=28.6, scale=64, sub_centers=3, **kwargs):
num_classes, embedding_size = kwargs["num_classes"], kwargs["embedding_size"]
num_classes = deepcopy(kwargs["num_classes"])
embedding_size = deepcopy(kwargs["embedding_size"])
del kwargs["num_classes"]
del kwargs["embedding_size"]
super().__init__(
num_classes * sub_centers, embedding_size, margin=margin, scale=scale
num_classes=num_classes * sub_centers, embedding_size=embedding_size, margin=margin, scale=scale, **kwargs
)
self.sub_centers = sub_centers
self.num_classes = num_classes
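
With this change, extra keyword arguments reach `ArcFaceLoss` instead of being dropped, and `num_classes`/`embedding_size` are passed by keyword. A small sketch of the effect (`collect_stats` is just one example of a kwarg that previously never made it to the parent constructor):

```python
from pytorch_metric_learning.losses import SubCenterArcFaceLoss

# Previously, the call to ArcFaceLoss.__init__ omitted **kwargs, so arguments
# like collect_stats were silently ignored; now they are forwarded.
loss_func = SubCenterArcFaceLoss(
    num_classes=100,
    embedding_size=128,
    sub_centers=3,
    collect_stats=True,
)
```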
32 changes: 0 additions & 32 deletions tests/losses/test_cross_batch_memory.py
@@ -238,7 +238,6 @@ def test_loss(self):
batch_size = 32
for inner_loss in [ContrastiveLoss(), MultiSimilarityLoss()]:
inner_miner = MultiSimilarityMiner(0.3)
outer_miner = MultiSimilarityMiner(0.2)
self.loss = CrossBatchMemory(
loss=inner_loss,
embedding_size=self.embedding_size,
@@ -267,10 +266,6 @@ def test_loss(self):
labels = torch.randint(0, num_labels, (batch_size,)).to(TEST_DEVICE)
loss = self.loss(embeddings, labels)
loss_with_miner = self.loss_with_miner(embeddings, labels)
oa1, op, oa2, on = outer_miner(embeddings, labels)
loss_with_miner_and_input_indices = self.loss_with_miner2(
embeddings, labels, (oa1, op, oa2, on)
)
all_embeddings = torch.cat([all_embeddings, embeddings])
all_labels = torch.cat([all_labels, labels])

@@ -308,33 +303,6 @@ def test_loss(self):
torch.isclose(loss_with_miner, correct_loss_with_miner)
)

# loss with inner and outer miner
indices_tuple = inner_miner(
embeddings, labels, all_embeddings, all_labels
)
a1, p, a2, n = lmu.remove_self_comparisons(
indices_tuple,
self.loss_with_miner2.curr_batch_idx,
self.loss_with_miner2.memory_size,
)
a1 = torch.cat([oa1, a1])
p = torch.cat([op, p])
a2 = torch.cat([oa2, a2])
n = torch.cat([on, n])
correct_loss_with_miner_and_input_indice = inner_loss(
embeddings,
labels,
(a1, p, a2, n),
all_embeddings,
all_labels,
)
self.assertTrue(
torch.isclose(
loss_with_miner_and_input_indices,
correct_loss_with_miner_and_input_indice,
)
)

def test_queue(self):
for test_enqueue_mask in [False, True]:
for dtype in TEST_DTYPES: