
Commit 7898b2e

Merge pull request #14 from BloodAxe/develop
Release 0.1.0
2 parents: 3ef7f37 + ba72200

File tree

10 files changed: +808 −64 lines

demo/demo_losses.py

Lines changed: 30 additions & 13 deletions
@@ -9,33 +9,50 @@
 def main():
     losses = {
         "bce": BCEWithLogitsLoss(),
-        "focal": L.BinaryFocalLoss(),
-        "jaccard": L.BinaryJaccardLoss(),
+        # "focal": L.BinaryFocalLoss(),
+        # "jaccard": L.BinaryJaccardLoss(),
         # "jaccard_log": L.BinaryJaccardLogLoss(),
-        "lovasz": L.BinaryLovaszLoss(),
-        # "bce+jaccard_log": L.BinaryJaccardLogLoss(),
-        "reduced_focal": L.BinaryFocalLoss(reduced=True)
+        # "dice": L.BinaryDiceLoss(),
+        # "dice_log": L.BinaryDiceLogLoss(),
+        # "sdice": L.BinarySymmetricDiceLoss(),
+        # "sdice_log": L.BinarySymmetricDiceLoss(log_loss=True),
+
+        "bce+lovasz": L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss()),
+        # "lovasz": L.BinaryLovaszLoss(),
+        # "bce+jaccard": L.JointLoss(BCEWithLogitsLoss(),
+        #                            L.BinaryJaccardLoss(), 1, 0.5),
+
+        # "bce+log_jaccard": L.JointLoss(BCEWithLogitsLoss(),
+        #                                L.BinaryJaccardLogLoss(), 1, 0.5),
+
+        # "bce+log_dice": L.JointLoss(BCEWithLogitsLoss(),
+        #                             L.BinaryDiceLogLoss(), 1, 0.5)
+
+        # "reduced_focal": L.BinaryFocalLoss(reduced=True)
     }

-    x_vec = torch.arange(-5, 5, 0.01)
+    dx = 0.01
+    x_vec = torch.arange(-5, 5, dx).view(-1, 1).expand((-1, 100))

-    plt.figure()
+    f, ax = plt.subplots(3, figsize=(16, 16))

     for name, loss in losses.items():
         x_arr = []
         y_arr = []
-        target = torch.tensor(1.0)
+        target = torch.tensor(1.0).view(1).expand((100))

         for x in x_vec:
-            y = loss(x, target)
+            y = loss(x, target).item()

-            x_arr.append(float(x))
+            x_arr.append(float(x[0]))
             y_arr.append(float(y))

-        plt.plot(x_arr, y_arr, label=name)
+        ax[0].plot(x_arr, y_arr, label=name)
+        ax[1].plot(x_arr, np.gradient(y_arr, dx))
+        ax[2].plot(x_arr, np.gradient(np.gradient(y_arr, dx), dx))

-    plt.legend()
-    plt.show()
+    f.legend()
+    f.show()


 if __name__ == '__main__':
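The rewritten demo expands each scalar logit to a 100-element vector, presumably so that batch-oriented losses such as BinaryLovaszLoss receive batch-shaped input, and adds two extra subplots with the numerical first and second derivatives of each loss curve via np.gradient. A minimal, self-contained sketch of that derivative-plotting idea using only BCEWithLogitsLoss (variable names here are illustrative, not from the demo):

```python
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss

dx = 0.01
loss_fn = BCEWithLogitsLoss()
xs = torch.arange(-5, 5, dx)      # sweep of raw logits
target = torch.tensor([1.0])      # fixed positive target

# Loss value at each logit, then numerical derivatives of the curve.
ys = [loss_fn(x.view(1), target).item() for x in xs]
dy = np.gradient(ys, dx)          # slope: the gradient the optimizer sees
d2y = np.gradient(dy, dx)         # curvature: how fast that slope changes
```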

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = '0.0.9'
+__version__ = '0.1.0'
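The only change here is the version bump. After installing this release, the new version can be checked directly:

```python
import pytorch_toolbelt

print(pytorch_toolbelt.__version__)  # expected: '0.1.0'
```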

pytorch_toolbelt/inference/tta.py

Lines changed: 63 additions & 11 deletions
@@ -4,13 +4,20 @@
 transformation written in PyTorch and respect gradients flow.
 """
 from functools import partial
-from typing import Tuple
+from typing import Tuple, List

 from torch import Tensor, nn
+
 from . import functional as F

-__all__ = ['d4_image2label', 'd4_image2mask', 'fivecrop_image2label', 'fliplr_image2mask',
-           'fliplr_image2label', 'TTAWrapper']
+__all__ = ['d4_image2label',
+           'd4_image2mask',
+           'fivecrop_image2label',
+           'tencrop_image2label',
+           'fliplr_image2mask',
+           'fliplr_image2label',
+           'TTAWrapper',
+           'MultiscaleTTAWrapper']


 def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
@@ -26,7 +33,8 @@ def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
     return output * one_over_2


-def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
+def fivecrop_image2label(model: nn.Module, image: Tensor,
+                         crop_size: Tuple) -> Tensor:
     """Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
     and averages predictions from them.

@@ -61,16 +69,19 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
     center_crop_y = (image_height - crop_height) // 2
     center_crop_x = (image_width - crop_width) // 2

-    crop_cc = image[..., center_crop_y:center_crop_y + crop_height, center_crop_x:center_crop_x + crop_width]
+    crop_cc = image[..., center_crop_y:center_crop_y + crop_height,
+              center_crop_x:center_crop_x + crop_width]
     assert crop_cc.size(2) == crop_height
     assert crop_cc.size(3) == crop_width

-    output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(crop_br) + model(crop_cc)
+    output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(
+        crop_br) + model(crop_cc)
     one_over_5 = float(1.0 / 5.0)
     return output * one_over_5


-def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
+def tencrop_image2label(model: nn.Module, image: Tensor,
+                        crop_size: Tuple) -> Tensor:
     """Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
     and averages predictions from them and from their horisontally-flipped versions (10-Crop TTA).

@@ -105,7 +116,8 @@ def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
     center_crop_y = (image_height - crop_height) // 2
     center_crop_x = (image_width - crop_width) // 2

-    crop_cc = image[..., center_crop_y:center_crop_y + crop_height, center_crop_x:center_crop_x + crop_width]
+    crop_cc = image[..., center_crop_y:center_crop_y + crop_height,
+              center_crop_x:center_crop_x + crop_width]
     assert crop_cc.size(2) == crop_height
     assert crop_cc.size(3) == crop_width

@@ -170,13 +182,16 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
     """
     output = model(image)

-    for aug, deaug in zip([F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90]):
+    for aug, deaug in zip([F.torch_rot90, F.torch_rot180, F.torch_rot270],
+                          [F.torch_rot270, F.torch_rot180, F.torch_rot90]):
         x = deaug(model(aug(image)))
         output = output + x

     image = F.torch_transpose(image)

-    for aug, deaug in zip([F.torch_none, F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90]):
+    for aug, deaug in zip(
+            [F.torch_none, F.torch_rot90, F.torch_rot180, F.torch_rot270],
+            [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90]):
         x = deaug(model(aug(image)))
         output = output + F.torch_transpose(x)

@@ -185,10 +200,47 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:


 class TTAWrapper(nn.Module):
-    def __init__(self, model, tta_function, **kwargs):
+    def __init__(self, model: nn.Module, tta_function, **kwargs):
         super().__init__()
         self.model = model
         self.tta = partial(tta_function, **kwargs)

     def forward(self, *input):
         return self.tta(self.model, *input)
+
+
+class MultiscaleTTAWrapper(nn.Module):
+    """
+    Multiscale TTA wrapper module
+    """
+
+    def __init__(self, model: nn.Module, scale_levels: List[float]):
+        """
+        Initialize multi-scale TTA wrapper
+
+        :param model: Base model for inference
+        :param scale_levels: List of additional scale levels,
+                             e.g: [0.5, 0.75, 1.25]
+        """
+        super().__init__()
+        assert len(scale_levels)
+        self.model = model
+        self.scale_levels = scale_levels
+
+    def forward(self, input: Tensor) -> Tensor:
+        h = input.size(2)
+        w = input.size(3)
+
+        out_size = h, w
+        output = self.model(input)
+
+        for scale in self.scale_levels:
+            dst_size = int(h * scale), int(w * scale)
+            input_scaled = F.interpolate(input, dst_size, mode='bilinear',
+                                         align_corners=True)
+            output_scaled = self.model(input_scaled)
+            output_scaled = F.interpolate(output_scaled, out_size,
+                                          mode='bilinear', align_corners=True)
+            output += output_scaled
+
+        return output / (1 + len(self.scale_levels))
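Taken together, the new exports make test-time augmentation a drop-in wrapper around any model. A hedged usage sketch, assuming these names import from pytorch_toolbelt.inference.tta as declared in __all__ above (the Conv2d stand-in model and tensor sizes are purely illustrative):

```python
import torch
from torch import nn
from pytorch_toolbelt.inference.tta import (TTAWrapper, MultiscaleTTAWrapper,
                                            fliplr_image2mask)

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in segmentation net

# Flip TTA: averages predictions over the image and its horizontal mirror.
flip_tta = TTAWrapper(model, fliplr_image2mask)

# Multiscale TTA: averages the native-scale output with 0.5x/0.75x/1.25x
# passes, each resized back to the input resolution before averaging.
ms_tta = MultiscaleTTAWrapper(model, scale_levels=[0.5, 0.75, 1.25])

x = torch.rand(2, 3, 64, 64)
with torch.no_grad():
    mask_flip = flip_tta(x)    # same shape as model(x)
    mask_multi = ms_tta(x)
```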

pytorch_toolbelt/modules/abn.py

Lines changed: 26 additions & 16 deletions
@@ -1,23 +1,20 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as functional
+from pytorch_toolbelt.modules.activations import ACT_LEAKY_RELU, ACT_NONE, \
+    ACT_HARD_SIGMOID, ACT_HARD_SWISH, ACT_SWISH, ACT_SELU, ACT_ELU, ACT_RELU6, \
+    ACT_RELU, hard_swish, hard_sigmoid, swish

-__all__ = ['ACT_RELU', 'ACT_ELU', 'ACT_SELU', 'ACT_NONE', 'ACT_LEAKY_RELU', 'ABN']
-
-# Activation names
-ACT_RELU = "relu"
-ACT_LEAKY_RELU = "leaky_relu"
-ACT_ELU = "elu"
-ACT_NONE = "none"
-ACT_SELU = "selu"
+__all__ = ['ABN']


 class ABN(nn.Module):
     """Activated Batch Normalization
     This gathers a `BatchNorm2d` and an activation function in a single module
     """

-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+                 activation="leaky_relu", slope=0.01):
         """Create an Activated Batch Normalization module
         Parameters
         ----------
@@ -52,26 +49,39 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation
         self.reset_parameters()

     def reset_parameters(self):
-        nn.init.constant_(self.running_mean, 0)
-        nn.init.constant_(self.running_var, 1)
+        nn.init.zeros_(self.running_mean)
+        nn.init.ones_(self.running_var)
         if self.affine:
-            nn.init.constant_(self.weight, 1)
-            nn.init.constant_(self.bias, 0)
+            nn.init.ones_(self.weight)
+            nn.init.zeros_(self.bias)

     def forward(self, x):
-        x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
+        x = functional.batch_norm(x,
+                                  self.running_mean, self.running_var,
+                                  self.weight, self.bias,
                                   self.training, self.momentum, self.eps)

         if self.activation == ACT_RELU:
             return functional.relu(x, inplace=True)
+        elif self.activation == ACT_RELU6:
+            return functional.relu6(x, inplace=True)
         elif self.activation == ACT_LEAKY_RELU:
-            return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
+            return functional.leaky_relu(x, negative_slope=self.slope,
+                                         inplace=True)
         elif self.activation == ACT_ELU:
             return functional.elu(x, inplace=True)
         elif self.activation == ACT_SELU:
             return functional.selu(x, inplace=True)
-        else:
+        elif self.activation == ACT_SWISH:
+            return swish(x)
+        elif self.activation == ACT_HARD_SWISH:
+            return hard_swish(x, inplace=True)
+        elif self.activation == ACT_HARD_SIGMOID:
+            return hard_sigmoid(x, inplace=True)
+        elif self.activation == ACT_NONE:
             return x
+        else:
+            raise KeyError(self.activation)

     def __repr__(self):
         rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
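With the activation registry moved into activations.py and the final else turned into raise KeyError, a misspelled activation name now fails loudly instead of silently behaving as identity. A short sketch of both behaviors, assuming the names shown in this diff:

```python
import torch
from pytorch_toolbelt.modules.abn import ABN
from pytorch_toolbelt.modules.activations import ACT_HARD_SWISH

abn = ABN(num_features=64, activation=ACT_HARD_SWISH)
y = abn(torch.rand(1, 64, 8, 8))  # batch norm followed by hard swish

try:
    ABN(num_features=64, activation="mish")(torch.rand(1, 64, 8, 8))
except KeyError as err:
    # Validation happens in forward(), not __init__, so the error
    # surfaces on the first call with the offending name.
    print("unsupported activation:", err)
```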

pytorch_toolbelt/modules/activations.py

Lines changed: 21 additions & 0 deletions
@@ -3,6 +3,24 @@
 from torch import nn
 from torch.nn import functional as F

+__all__ = ['ACT_ELU',
+           'ACT_HARD_SIGMOID', 'ACT_HARD_SWISH', 'ACT_LEAKY_RELU', 'ACT_NONE',
+           'ACT_RELU', 'ACT_RELU6', 'ACT_SELU', 'ACT_SWISH',
+           'swish', 'hard_sigmoid', 'hard_swish', 'HardSigmoid', 'HardSwish',
+           'Swish', 'get_activation_module'
+           ]
+
+# Activation names
+ACT_RELU = "relu"
+ACT_RELU6 = "relu6"
+ACT_LEAKY_RELU = "leaky_relu"
+ACT_ELU = "elu"
+ACT_NONE = "none"
+ACT_SELU = "selu"
+ACT_SWISH = "swish"
+ACT_HARD_SWISH = "hard_swish"
+ACT_HARD_SIGMOID = "hard_sigmoid"
+

 def swish(x):
     return x * x.sigmoid()
@@ -70,6 +88,9 @@ def get_activation_module(activation_name: str, **kwargs) -> nn.Module:
     if activation_name.lower() == 'hard_sigmoid':
         return partial(HardSigmoid, **kwargs)

+    if activation_name.lower() == 'swish':
+        return partial(Swish, **kwargs)
+
     if activation_name.lower() == 'hard_swish':
         return partial(HardSwish, **kwargs)
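Note that get_activation_module returns a constructor (a functools.partial over the module class), not an instance, so it must be called to build the module. A minimal sketch of the newly wired 'swish' branch, under the assumption that Swish() takes no required arguments:

```python
import torch
from pytorch_toolbelt.modules.activations import (ACT_SWISH,
                                                  get_activation_module)

act_factory = get_activation_module(ACT_SWISH)  # partial(Swish, **kwargs)
act = act_factory()                             # instantiate the nn.Module

x = torch.randn(4)
print(act(x))  # x * sigmoid(x), per swish() defined above
```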
