transformations are written in PyTorch and respect gradient flow.
"""
from functools import partial
- from typing import Tuple
+ from typing import Tuple, List

from torch import Tensor, nn
+
from . import functional as F

- __all__ = ['d4_image2label', 'd4_image2mask', 'fivecrop_image2label', 'fliplr_image2mask',
-            'fliplr_image2label', 'TTAWrapper']
+ __all__ = ['d4_image2label',
+            'd4_image2mask',
+            'fivecrop_image2label',
+            'tencrop_image2label',
+            'fliplr_image2mask',
+            'fliplr_image2label',
+            'TTAWrapper',
+            'MultiscaleTTAWrapper']


def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
@@ -26,7 +33,8 @@ def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
    return output * one_over_2


- def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
+ def fivecrop_image2label(model: nn.Module, image: Tensor,
+                          crop_size: Tuple) -> Tensor:
    """Test-time augmentation for image classification that takes five crops of the input tensor (four corner crops and a central crop)
    and averages the predictions from them.

@@ -61,16 +69,19 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
    center_crop_y = (image_height - crop_height) // 2
    center_crop_x = (image_width - crop_width) // 2

-     crop_cc = image[..., center_crop_y:center_crop_y + crop_height, center_crop_x:center_crop_x + crop_width]
+     crop_cc = image[..., center_crop_y:center_crop_y + crop_height,
+                     center_crop_x:center_crop_x + crop_width]
    assert crop_cc.size(2) == crop_height
    assert crop_cc.size(3) == crop_width

-     output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(crop_br) + model(crop_cc)
+     output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(
+         crop_br) + model(crop_cc)
    one_over_5 = float(1.0 / 5.0)
    return output * one_over_5


- def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
+ def tencrop_image2label(model: nn.Module, image: Tensor,
+                         crop_size: Tuple) -> Tensor:
    """Test-time augmentation for image classification that takes five crops of the input tensor (four corner crops and a central crop)
    and averages the predictions from them and from their horizontally-flipped versions (10-Crop TTA).

@@ -105,7 +116,8 @@ def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
    center_crop_y = (image_height - crop_height) // 2
    center_crop_x = (image_width - crop_width) // 2

-     crop_cc = image[..., center_crop_y:center_crop_y + crop_height, center_crop_x:center_crop_x + crop_width]
+     crop_cc = image[..., center_crop_y:center_crop_y + crop_height,
+                     center_crop_x:center_crop_x + crop_width]
    assert crop_cc.size(2) == crop_height
    assert crop_cc.size(3) == crop_width

@@ -170,13 +182,16 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
    """
    output = model(image)

-     for aug, deaug in zip([F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90]):
+     for aug, deaug in zip([F.torch_rot90, F.torch_rot180, F.torch_rot270],
+                           [F.torch_rot270, F.torch_rot180, F.torch_rot90]):
        x = deaug(model(aug(image)))
        output = output + x

    image = F.torch_transpose(image)

-     for aug, deaug in zip([F.torch_none, F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90]):
+     for aug, deaug in zip(
+             [F.torch_none, F.torch_rot90, F.torch_rot180, F.torch_rot270],
+             [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90]):
        x = deaug(model(aug(image)))
        output = output + F.torch_transpose(x)

@@ -185,10 +200,47 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:


class TTAWrapper(nn.Module):
-     def __init__(self, model, tta_function, **kwargs):
+     def __init__(self, model: nn.Module, tta_function, **kwargs):
        super().__init__()
        self.model = model
        self.tta = partial(tta_function, **kwargs)

    def forward(self, *input):
        return self.tta(self.model, *input)
+
+
+ class MultiscaleTTAWrapper(nn.Module):
+     """
+     Multiscale TTA wrapper module
+     """
+
+     def __init__(self, model: nn.Module, scale_levels: List[float]):
+         """
+         Initialize multi-scale TTA wrapper
+
+         :param model: Base model for inference
+         :param scale_levels: List of additional scale levels,
+             e.g: [0.5, 0.75, 1.25]
+         """
+         super().__init__()
+         assert len(scale_levels)
+         self.model = model
+         self.scale_levels = scale_levels
+
+     def forward(self, input: Tensor) -> Tensor:
+         h = input.size(2)
+         w = input.size(3)
+
+         out_size = h, w
+         output = self.model(input)
+
+         for scale in self.scale_levels:
+             dst_size = int(h * scale), int(w * scale)
+             input_scaled = F.interpolate(input, dst_size, mode='bilinear',
+                                          align_corners=True)
+             output_scaled = self.model(input_scaled)
+             output_scaled = F.interpolate(output_scaled, out_size,
+                                           mode='bilinear', align_corners=True)
+             output += output_scaled
+
+         return output / (1 + len(self.scale_levels))
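
For context, below is a minimal usage sketch of the two wrappers touched by this diff. The import path pytorch_toolbelt.inference.tta, the toy model, and the input tensor are assumptions for illustration only, not part of the commit; note also that MultiscaleTTAWrapper relies on F.interpolate being reachable through the local functional module aliased as F.

# Hypothetical usage sketch, not part of this commit.
# Assumes this file is importable as pytorch_toolbelt.inference.tta.
import torch
from pytorch_toolbelt.inference.tta import TTAWrapper, MultiscaleTTAWrapper, d4_image2mask

model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in for a real segmentation model
image = torch.rand(4, 3, 256, 256)

# D4 TTA: averages predictions over the 8 dihedral transforms of the input.
tta_model = TTAWrapper(model, d4_image2mask)
mask = tta_model(image)  # same spatial size as the input

# Multi-scale TTA: averages the original-scale prediction with predictions
# made at 0.75x and 1.25x resolution, each resized back to the input size.
ms_model = MultiscaleTTAWrapper(model, scale_levels=[0.75, 1.25])
mask_ms = ms_model(image)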