Skip to content

Commit a3f52fd

Browse files
committed
Added 10-Crop TTA
1 parent d8f2d45 commit a3f52fd

File tree

5 files changed

+169
-10
lines changed

5 files changed

+169
-10
lines changed

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = '0.0.4'
3+
__version__ = '0.0.5'

pytorch_toolbelt/inference/functional.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,20 @@ def torch_rot270(x: Tensor):
2525

2626

2727
def torch_flipud(x: Tensor):
28+
"""
29+
Flip image tensor vertically
30+
:param x:
31+
:return:
32+
"""
2833
return x.flip(2)
2934

3035

31-
def torch_fliplp(x: Tensor):
36+
def torch_fliplr(x: Tensor):
37+
"""
38+
Flip image tensor horizontally
39+
:param x:
40+
:return:
41+
"""
3242
return x.flip(3)
3343

3444

@@ -85,7 +95,7 @@ def pad_image_tensor(image_tensor: Tensor, pad_size: int = 32):
8595
return image_tensor, pad
8696

8797

88-
def unpad_tensor(image_tensor, pad):
98+
def unpad_image_tensor(image_tensor, pad):
8999
pad_left, pad_right, pad_top, pad_btm = pad
90100
rows, cols = image_tensor.size(2), image_tensor.size(3)
91101
return image_tensor[..., pad_top:rows - pad_btm, pad_left: cols - pad_right]

pytorch_toolbelt/inference/tta.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
2121
:param image:
2222
:return:
2323
"""
24-
output = model(image) + model(F.torch_fliplp(image))
24+
output = model(image) + model(F.torch_fliplr(image))
2525
one_over_2 = float(1.0 / 2.0)
2626
return output * one_over_2
2727

@@ -30,10 +30,10 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> T
3030
"""Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
3131
and averages predictions from them.
3232
33-
:param model:
34-
:param image:
35-
:param crop_size:
36-
:return:
33+
:param model: Classification model
34+
:param image: Input image tensor
35+
:param crop_size: Crop size. Must be smaller than image size
36+
:return: Averaged logits
3737
"""
3838
image_height, image_width = int(image.size(2)), int(image.size(3))
3939
crop_height, crop_width = crop_size
@@ -70,6 +70,55 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> T
7070
return output * one_over_5
7171

7272

73+
def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
74+
"""Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
75+
and averages predictions from them and from their horisontally-flipped versions (10-Crop TTA).
76+
77+
:param model: Classification model
78+
:param image: Input image tensor
79+
:param crop_size: Crop size. Must be smaller than image size
80+
:return: Averaged logits
81+
"""
82+
image_height, image_width = int(image.size(2)), int(image.size(3))
83+
crop_height, crop_width = crop_size
84+
85+
assert crop_height <= image_height
86+
assert crop_width <= image_width
87+
88+
bottom_crop_start = image_height - crop_height
89+
right_crop_start = image_width - crop_width
90+
crop_tl = image[..., :crop_height, :crop_width]
91+
crop_tr = image[..., :crop_height, right_crop_start:]
92+
crop_bl = image[..., bottom_crop_start:, :crop_width]
93+
crop_br = image[..., bottom_crop_start:, right_crop_start:]
94+
95+
assert crop_tl.size(2) == crop_height
96+
assert crop_tr.size(2) == crop_height
97+
assert crop_bl.size(2) == crop_height
98+
assert crop_br.size(2) == crop_height
99+
100+
assert crop_tl.size(3) == crop_width
101+
assert crop_tr.size(3) == crop_width
102+
assert crop_bl.size(3) == crop_width
103+
assert crop_br.size(3) == crop_width
104+
105+
center_crop_y = (image_height - crop_height) // 2
106+
center_crop_x = (image_width - crop_width) // 2
107+
108+
crop_cc = image[..., center_crop_y:center_crop_y + crop_height, center_crop_x:center_crop_x + crop_width]
109+
assert crop_cc.size(2) == crop_height
110+
assert crop_cc.size(3) == crop_width
111+
112+
output = model(crop_tl) + model(F.torch_fliplr(crop_tl)) + \
113+
model(crop_tr) + model(F.torch_fliplr(crop_tr)) + \
114+
model(crop_bl) + model(F.torch_fliplr(crop_bl)) + \
115+
model(crop_br) + model(F.torch_fliplr(crop_br)) + \
116+
model(crop_cc) + model(F.torch_fliplr(crop_cc))
117+
118+
one_over_10 = float(1.0 / 10.0)
119+
return output * one_over_10
120+
121+
73122
def fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
74123
"""Test-time augmentation for image segmentation that averages predictions
75124
for input image and vertically flipped one.
@@ -80,7 +129,7 @@ def fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
80129
:param image: Model input.
81130
:return: Arithmetically averaged predictions
82131
"""
83-
output = model(image) + F.torch_fliplp(model(F.torch_fliplp(image)))
132+
output = model(image) + F.torch_fliplr(model(F.torch_fliplr(image)))
84133
one_over_2 = float(1.0 / 2.0)
85134
return output * one_over_2
86135

@@ -129,7 +178,7 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
129178

130179
for aug, deaug in zip([F.torch_none, F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90]):
131180
x = deaug(model(aug(image)))
132-
output = output + x
181+
output = output + F.torch_transpose(x)
133182

134183
one_over_8 = float(1.0 / 8.0)
135184
return output * one_over_8

pytorch_toolbelt/utils/torch_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99
from torch import nn
1010

1111

12+
def freeze_bn(module: nn.Module):
13+
"""Freezes BatchNorm
14+
"""
15+
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
16+
module.track_running_stats = False
17+
18+
for m in module.modules():
19+
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
20+
module.track_running_stats = False
21+
22+
1223
def logit(x: torch.Tensor, eps=1e-5):
1324
x = torch.clamp(x.float(), eps, 1.0 - eps)
1425
return torch.log(x / (1.0 - x))

tests/test_tta.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import numpy as np
3+
from pytorch_toolbelt.inference import tta
4+
from pytorch_toolbelt.utils.torch_utils import to_numpy
5+
from torch import nn
6+
7+
8+
class NoOp(nn.Module):
9+
def __init__(self):
10+
super().__init__()
11+
12+
def forward(self, input):
13+
return input
14+
15+
16+
class SumAll(nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
20+
def forward(self, input):
21+
return input.sum(dim=[1, 2, 3])
22+
23+
24+
def test_d4_image2mask():
25+
input = torch.rand((4, 3, 224, 224))
26+
model = NoOp()
27+
28+
output = tta.d4_image2mask(model, input)
29+
np.testing.assert_allclose(to_numpy(output), to_numpy(input), atol=1e-6, rtol=1e-6)
30+
31+
32+
def test_fliplr_image2mask():
33+
input = torch.rand((4, 3, 224, 224))
34+
model = NoOp()
35+
36+
output = tta.fliplr_image2mask(model, input)
37+
np.testing.assert_allclose(to_numpy(output), to_numpy(input), atol=1e-6, rtol=1e-6)
38+
39+
40+
def test_d4_image2label():
41+
input = torch.tensor([[1, 2, 3, 4],
42+
[5, 6, 7, 8],
43+
[9, 0, 1, 2],
44+
[3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float()
45+
model = SumAll()
46+
47+
output = tta.d4_image2label(model, input)
48+
expected = int(input.sum())
49+
50+
assert int(output) == expected
51+
52+
53+
def test_fliplr_image2label():
54+
input = torch.tensor([[1, 2, 3, 4],
55+
[5, 6, 7, 8],
56+
[9, 0, 1, 2],
57+
[3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float()
58+
model = SumAll()
59+
60+
output = tta.fliplr_image2label(model, input)
61+
expected = int(input.sum())
62+
63+
assert int(output) == expected
64+
65+
66+
def test_fivecrop_image2label():
67+
input = torch.tensor([[1, 2, 3, 4],
68+
[5, 6, 7, 8],
69+
[9, 0, 1, 2],
70+
[3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float()
71+
model = SumAll()
72+
73+
output = tta.fivecrop_image2label(model, input, (2, 2))
74+
expected = ((1 + 2 + 5 + 6) + (3 + 4 + 7 + 8) + (9 + 0 + 3 + 4) + (1 + 2 + 5 + 6) + (6 + 7 + 0 + 1)) / 5
75+
76+
assert int(output) == expected
77+
78+
79+
def test_tencrop_image2label():
80+
input = torch.tensor([[1, 2, 3, 4],
81+
[5, 6, 7, 8],
82+
[9, 0, 1, 2],
83+
[3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float()
84+
model = SumAll()
85+
86+
output = tta.tencrop_image2label(model, input, (2, 2))
87+
expected = (2 * ((1 + 2 + 5 + 6) + (3 + 4 + 7 + 8) + (9 + 0 + 3 + 4) + (1 + 2 + 5 + 6) + (6 + 7 + 0 + 1))) / 10
88+
89+
assert int(output) == expected

0 commit comments

Comments
 (0)