
Commit b635ecf

Merge pull request #47 from BloodAxe/develop
Release 0.4.0
2 parents: 1054a4f + f932d16

31 files changed: 972 additions, 893 deletions

README.md

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ from pytorch_toolbelt import losses as L
 
 # Creates a loss function that is a weighted sum of focal loss
 # and lovasz loss with weigths 1.0 and 0.5 accordingly.
-loss = L.JointLoss(L.FocalLoss(), 1.0, L.LovaszLoss(), 0.5)
+loss = L.JointLoss(L.FocalLoss(), L.LovaszLoss(), 1.0, 0.5)
 ```
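For reference, the corrected call matches the JointLoss signature updated later in this commit (pytorch_toolbelt/losses/joint_loss.py): the two loss modules come first, their weights after. A hedged equivalent with explicit keyword arguments:

```python
from pytorch_toolbelt import losses as L

# Same call as the fixed README line, spelled out with the keyword names
# from the JointLoss signature shown in the joint_loss.py diff below.
loss = L.JointLoss(
    first=L.FocalLoss(),
    second=L.LovaszLoss(),
    first_weight=1.0,
    second_weight=0.5,
)
```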

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 from __future__ import absolute_import
 
-__version__ = "0.3.2"
+__version__ = "0.4.0"
pytorch_toolbelt/inference/ensembling.py

Lines changed: 31 additions & 16 deletions

@@ -5,39 +5,54 @@
 
 
 class ApplySoftmaxTo(nn.Module):
-    def __init__(self, model, output_key: Union[str, List[str]] = "logits", dim=1):
+    def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", dim=1, temperature=1):
+        """
+        Apply softmax activation on given output(s) of the model
+        :param model: Model to wrap
+        :param output_key: string or list of strings, indicating to what outputs softmax activation should be applied.
+        :param dim: Tensor dimension for softmax activation
+        :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
+        """
         super().__init__()
         output_key = output_key if isinstance(output_key, (list, tuple)) else [output_key]
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
         self.output_keys = set(output_key)
         self.model = model
         self.dim = dim
+        self.temperature = temperature
 
-    def forward(self, input):
-        output = self.model(input)
+    def forward(self, *input, **kwargs):
+        output = self.model(*input, **kwargs)
         for key in self.output_keys:
-            output[key] = output[key].softmax(dim=1)
+            output[key] = output[key].mul(self.temperature).softmax(dim=1)
         return output
 
 
 class ApplySigmoidTo(nn.Module):
-    def __init__(self, model, output_key: Union[str, List[str]] = "logits"):
+    def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", temperature=1):
+        """
+        Apply sigmoid activation on given output(s) of the model
+        :param model: Model to wrap
+        :param output_key: string or list of strings, indicating to what outputs sigmoid activation should be applied.
+        :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
+        """
         super().__init__()
         output_key = output_key if isinstance(output_key, (list, tuple)) else [output_key]
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
        self.output_keys = set(output_key)
         self.model = model
+        self.temperature = temperature
 
-    def forward(self, input):  # skipcq: PYL-W0221
-        output = self.model(input)
+    def forward(self, *input, **kwargs):  # skipcq: PYL-W0221
+        output = self.model(*input, **kwargs)
         for key in self.output_keys:
-            output[key] = output[key].sigmoid()
+            output[key] = output[key].mul(self.temperature).sigmoid()
         return output
 
 
 class Ensembler(nn.Module):
     """
-    Computes sum of outputs for several models with arithmetic averaging (optional).
+    Compute sum (or average) of outputs of several models.
     """
 
     def __init__(self, models: List[nn.Module], average=True, outputs=None):

@@ -53,8 +68,8 @@ def __init__(self, models: List[nn.Module], average=True, outputs=None):
         self.models = nn.ModuleList(models)
         self.average = average
 
-    def forward(self, x):  # skipcq: PYL-W0221
-        output_0 = self.models[0](x)
+    def forward(self, *input, **kwargs):  # skipcq: PYL-W0221
+        output_0 = self.models[0](*input, **kwargs)
         num_models = len(self.models)
 
         if self.outputs:

@@ -63,15 +78,15 @@ def forward(self, x):  # skipcq: PYL-W0221
             keys = output_0.keys()
 
         for index in range(1, num_models):
-            output_i = self.models[index](x)
+            output_i = self.models[index](*input, **kwargs)
 
             # Sum outputs
             for key in keys:
-                output_0[key] += output_i[key]
+                output_0[key].add_(output_i[key])
 
         if self.average:
             for key in keys:
-                output_0[key] /= num_models
+                output_0[key].mul_(1.0 / num_models)
 
         return output_0
 

@@ -86,6 +101,6 @@ def __init__(self, model: nn.Module, key: str):
         self.model = model
         self.target_key = key
 
-    def forward(self, input) -> Tensor:
-        output = self.model(input)
+    def forward(self, *input, **kwargs) -> Tensor:
+        output = self.model(*input, **kwargs)
         return output[self.target_key]
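For orientation, a minimal usage sketch of the updated wrappers. The DummyModel, its "logits" output key, and the tensor shapes are illustrative assumptions, not part of this commit:

```python
import torch
from torch import nn

from pytorch_toolbelt.inference.ensembling import ApplySigmoidTo, Ensembler


class DummyModel(nn.Module):
    """Toy model returning a dict of outputs, as the wrappers expect."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        return {"logits": self.conv(x)}


# Each wrapped model multiplies its logits by `temperature` before the sigmoid
# (new in this release) and now forwards *input/**kwargs to the inner model.
models = [ApplySigmoidTo(DummyModel(), output_key="logits", temperature=2) for _ in range(3)]

# Ensembler sums the dict outputs in-place and averages them when average=True.
ensemble = Ensembler(models, average=True)

with torch.no_grad():
    probs = ensemble(torch.randn(2, 3, 32, 32))["logits"]

print(probs.shape)  # torch.Size([2, 1, 32, 32])
```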

pytorch_toolbelt/losses/dice.py

Lines changed: 2 additions & 2 deletions

@@ -82,12 +82,12 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         y_true = y_true.view(bs, num_classes, -1)
         y_pred = y_pred.view(bs, num_classes, -1)
 
-        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), self.smooth, self.eps, dims=dims)
+        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
 
         if self.log_loss:
             loss = -torch.log(scores.clamp_min(self.eps))
         else:
-            loss = 1 - scores
+            loss = 1.0 - scores
 
         # Dice loss is undefined for non-empty classes
         # So we zero contribution of channel that does not have true pixels
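For orientation, a short hedged call of the loss whose forward pass changed above. The mode="binary" argument and tensor shapes are assumptions about the DiceLoss constructor, which this diff does not show:

```python
import torch

from pytorch_toolbelt.losses import DiceLoss

criterion = DiceLoss(mode="binary")                    # assumed constructor argument
logits = torch.randn(4, 1, 64, 64)                     # raw model predictions
target = torch.randint(0, 2, (4, 1, 64, 64)).float()   # binary ground-truth mask

print(criterion(logits, target).item())
```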

pytorch_toolbelt/losses/functional.py

Lines changed: 88 additions & 33 deletions

@@ -8,21 +8,25 @@
 
 
 def focal_loss_with_logits(
-    input: torch.Tensor,
+    output: torch.Tensor,
     target: torch.Tensor,
-    gamma=2.0,
+    gamma: float = 2.0,
     alpha: Optional[float] = 0.25,
-    reduction="mean",
-    normalized=False,
+    reduction: str = "mean",
+    normalized: bool = False,
     reduced_threshold: Optional[float] = None,
+    eps: float = 1e-6,
 ) -> torch.Tensor:
     """Compute binary focal loss between target and output logits.
 
     See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
 
     Args:
-        input: Tensor of arbitrary shape
+        output: Tensor of arbitrary shape (predictions of the model)
         target: Tensor of the same shape as input
+        gamma: Focal loss power factor
+        alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range,
+            high values will give more weight to positive class.
         reduction (string, optional): Specifies the reduction to apply to the output:
             'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied,
             'mean': the sum of the output will be divided by the number of

@@ -32,18 +36,18 @@ def focal_loss_with_logits(
             'batchwise_mean' computes mean loss per sample in batch. Default: 'mean'
         normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
         reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
-    References::
 
+    References:
         https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
     """
-    target = target.type(input.type())
+    target = target.type(output.type())
 
-    logpt = F.binary_cross_entropy_with_logits(input, target, reduction="none")
+    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
     pt = torch.exp(-logpt)
 
     # compute the loss
     if reduced_threshold is None:
-        focal_term = (1 - pt).pow(gamma)
+        focal_term = (1.0 - pt).pow(gamma)
     else:
         focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
         focal_term[pt < reduced_threshold] = 1

@@ -54,7 +58,7 @@ def focal_loss_with_logits(
         loss *= alpha * target + (1 - alpha) * (1 - target)
 
     if normalized:
-        norm_factor = focal_term.sum() + 1e-5
+        norm_factor = focal_term.sum().clamp_min(eps)
         loss /= norm_factor
 
     if reduction == "mean":

@@ -72,19 +76,22 @@
 
 
 # TODO: Mark as deprecated and emit warning
-def reduced_focal_loss(input: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"):
+def reduced_focal_loss(output: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"):
     return focal_loss_with_logits(
-        input, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold
+        output, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold
     )
 
 
-def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
+def soft_jaccard_score(
+    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
+) -> torch.Tensor:
     """
 
-    :param y_pred:
-    :param y_true:
+    :param output:
+    :param target:
     :param smooth:
     :param eps:
+    :param dims:
     :return:
 
     Shape:

@@ -94,25 +101,27 @@ def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, e
         - Output: scalar.
 
     """
-    assert y_pred.size() == y_true.size()
+    assert output.size() == target.size()
 
     if dims is not None:
-        intersection = torch.sum(y_pred * y_true, dim=dims)
-        cardinality = torch.sum(y_pred + y_true, dim=dims)
+        intersection = torch.sum(output * target, dim=dims)
+        cardinality = torch.sum(output + target, dim=dims)
     else:
-        intersection = torch.sum(y_pred * y_true)
-        cardinality = torch.sum(y_pred + y_true)
+        intersection = torch.sum(output * target)
+        cardinality = torch.sum(output + target)
 
     union = cardinality - intersection
-    jaccard_score = (intersection + smooth) / (union.clamp_min(eps) + smooth)
+    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
     return jaccard_score
 
 
-def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e-7, dims=None) -> torch.Tensor:
+def soft_dice_score(
+    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
+) -> torch.Tensor:
     """
 
-    :param y_pred:
-    :param y_true:
+    :param output:
+    :param target:
     :param smooth:
     :param eps:
     :return:

@@ -124,28 +133,28 @@ def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e
         - Output: scalar.
 
     """
-    assert y_pred.size() == y_true.size()
+    assert output.size() == target.size()
     if dims is not None:
-        intersection = torch.sum(y_pred * y_true, dim=dims)
-        cardinality = torch.sum(y_pred + y_true, dim=dims)
+        intersection = torch.sum(output * target, dim=dims)
+        cardinality = torch.sum(output + target, dim=dims)
     else:
-        intersection = torch.sum(y_pred * y_true)
-        cardinality = torch.sum(y_pred + y_true)
-    dice_score = (2.0 * intersection + smooth) / (cardinality.clamp_min(eps) + smooth)
+        intersection = torch.sum(output * target)
+        cardinality = torch.sum(output + target)
+    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
     return dice_score
 
 
-def wing_loss(prediction: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
+def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
     """
     https://arxiv.org/pdf/1711.06753.pdf
-    :param prediction:
+    :param output:
     :param target:
     :param width:
    :param curvature:
     :param reduction:
     :return:
     """
-    diff_abs = (target - prediction).abs()
+    diff_abs = (target - output).abs()
     loss = diff_abs.clone()
 
     idx_smaller = diff_abs < width

@@ -163,3 +172,49 @@ def wing_loss(prediction: torch.Tensor, target: torch.Tensor, width=5, curvature
         loss = loss.mean()
 
     return loss
+
+
+def label_smoothed_nll_loss(
+    lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1
+) -> torch.Tensor:
+    """
+
+    Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py
+
+    :param lprobs: Log-probabilities of predictions (e.g after log_softmax)
+    :param target:
+    :param epsilon:
+    :param ignore_index:
+    :param reduction:
+    :return:
+    """
+    if target.dim() == lprobs.dim() - 1:
+        target = target.unsqueeze(dim)
+
+    if ignore_index is not None:
+        pad_mask = target.eq(ignore_index)
+        target = target.masked_fill(pad_mask, 0)
+        nll_loss = -lprobs.gather(dim=dim, index=target)
+        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
+
+        # nll_loss.masked_fill_(pad_mask, 0.0)
+        # smooth_loss.masked_fill_(pad_mask, 0.0)
+        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
+        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
+    else:
+        nll_loss = -lprobs.gather(dim=dim, index=target)
+        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
+
+    nll_loss = nll_loss.squeeze(dim)
+    smooth_loss = smooth_loss.squeeze(dim)
+
+    if reduction == "sum":
+        nll_loss = nll_loss.sum()
+        smooth_loss = smooth_loss.sum()
+    if reduction == "mean":
+        nll_loss = nll_loss.mean()
+        smooth_loss = smooth_loss.mean()
+
+    eps_i = epsilon / lprobs.size(dim)
+    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+    return loss
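To make the renamed signatures concrete, a short hedged sketch of calling the updated functions directly; the tensor shapes and values are illustrative assumptions:

```python
import torch

from pytorch_toolbelt.losses.functional import (
    focal_loss_with_logits,
    label_smoothed_nll_loss,
    soft_dice_score,
    soft_jaccard_score,
)

logits = torch.randn(4, 1, 16, 16)                    # raw model outputs
target = torch.randint(0, 2, (4, 1, 16, 16)).float()  # binary ground truth

# The first positional argument is now named `output` instead of `input`/`y_pred`.
focal = focal_loss_with_logits(logits, target, gamma=2.0, alpha=0.25, reduction="mean")

# The soft scores expect probabilities; with dims=(0, 2, 3) they aggregate over
# batch and spatial dimensions and return one score per class channel.
dice = soft_dice_score(logits.sigmoid(), target, smooth=0.0, eps=1e-7, dims=(0, 2, 3))
iou = soft_jaccard_score(logits.sigmoid(), target, smooth=0.0, eps=1e-7, dims=(0, 2, 3))

# label_smoothed_nll_loss consumes log-probabilities, e.g. after log_softmax.
lprobs = torch.log_softmax(torch.randn(8, 5), dim=-1)
labels = torch.randint(0, 5, (8,))
smoothed = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1, reduction="mean")

print(focal.item(), dice.shape, iou.shape, smoothed.item())
```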

pytorch_toolbelt/losses/jaccard.py

Lines changed: 2 additions & 2 deletions

@@ -82,12 +82,12 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         y_true = y_true.view(bs, num_classes, -1)
         y_pred = y_pred.view(bs, num_classes, -1)
 
-        scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims)
+        scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), smooth=self.smooth, eps=self.eps, dims=dims)
 
         if self.log_loss:
             loss = -torch.log(scores.clamp_min(self.eps))
         else:
-            loss = 1 - scores
+            loss = 1.0 - scores
 
         # IoU loss is defined for non-empty classes
         # So we zero contribution of channel that does not have true pixels
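The change mirrors the one in dice.py above; a correspondingly hedged JaccardLoss call (again assuming the mode="binary" constructor argument, which this diff does not show):

```python
import torch

from pytorch_toolbelt.losses import JaccardLoss

criterion = JaccardLoss(mode="binary")                 # assumed constructor argument
logits = torch.randn(4, 1, 64, 64)
target = torch.randint(0, 2, (4, 1, 64, 64)).float()

print(criterion(logits, target).item())
```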

pytorch_toolbelt/losses/joint_loss.py

Lines changed: 6 additions & 1 deletion

@@ -1,3 +1,4 @@
+from torch import nn
 from torch.nn.modules.loss import _Loss
 
 __all__ = ["JointLoss", "WeightedLoss"]

@@ -18,7 +19,11 @@ def forward(self, *input):
 
 
 class JointLoss(_Loss):
-    def __init__(self, first, second, first_weight=1.0, second_weight=1.0):
+    """
+    Wrap two loss functions into one. This class computes a weighted sum of two losses.
+    """
+
+    def __init__(self, first: nn.Module, second: nn.Module, first_weight=1.0, second_weight=1.0):
         super().__init__()
         self.first = WeightedLoss(first, first_weight)
         self.second = WeightedLoss(second, second_weight)
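A brief hedged sketch of the documented wrapper. The wrapped criteria are standard torch losses chosen for illustration; only the JointLoss signature itself comes from this diff:

```python
import torch
from torch import nn

from pytorch_toolbelt.losses.joint_loss import JointLoss

# Weighted sum of two criteria: 1.0 * BCE-with-logits + 0.5 * L1.
criterion = JointLoss(nn.BCEWithLogitsLoss(), nn.L1Loss(), first_weight=1.0, second_weight=0.5)

logits = torch.randn(8, 1)
target = torch.randint(0, 2, (8, 1)).float()

print(criterion(logits, target).item())
```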
