@@ -21,7 +21,7 @@ def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
     :param image:
     :return:
     """
-    output = model(image) + model(F.torch_fliplp(image))
+    output = model(image) + model(F.torch_fliplr(image))
     one_over_2 = float(1.0 / 2.0)
     return output * one_over_2

@@ -30,10 +30,10 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> T
     """Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
     and averages predictions from them.

-    :param model:
-    :param image:
-    :param crop_size:
-    :return:
+    :param model: Classification model
+    :param image: Input image tensor
+    :param crop_size: Crop size. Must be smaller than image size
+    :return: Averaged logits
     """
     image_height, image_width = int(image.size(2)), int(image.size(3))
     crop_height, crop_width = crop_size
@@ -70,6 +70,55 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> T
     return output * one_over_5


+def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
+    """Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
+    and averages predictions from them and from their horizontally-flipped versions (10-Crop TTA).
+
+    :param model: Classification model
+    :param image: Input image tensor
+    :param crop_size: Crop size. Must be smaller than image size
+    :return: Averaged logits
+    """
+    image_height, image_width = int(image.size(2)), int(image.size(3))
+    crop_height, crop_width = crop_size
+
+    assert crop_height <= image_height
+    assert crop_width <= image_width
+
+    bottom_crop_start = image_height - crop_height
+    right_crop_start = image_width - crop_width
+    crop_tl = image[..., :crop_height, :crop_width]
+    crop_tr = image[..., :crop_height, right_crop_start:]
+    crop_bl = image[..., bottom_crop_start:, :crop_width]
+    crop_br = image[..., bottom_crop_start:, right_crop_start:]
+
+    assert crop_tl.size(2) == crop_height
+    assert crop_tr.size(2) == crop_height
+    assert crop_bl.size(2) == crop_height
+    assert crop_br.size(2) == crop_height
+
+    assert crop_tl.size(3) == crop_width
+    assert crop_tr.size(3) == crop_width
+    assert crop_bl.size(3) == crop_width
+    assert crop_br.size(3) == crop_width
+
+    center_crop_y = (image_height - crop_height) // 2
+    center_crop_x = (image_width - crop_width) // 2
+
+    crop_cc = image[..., center_crop_y:center_crop_y + crop_height, center_crop_x:center_crop_x + crop_width]
+    assert crop_cc.size(2) == crop_height
+    assert crop_cc.size(3) == crop_width
+
+    output = model(crop_tl) + model(F.torch_fliplr(crop_tl)) + \
+             model(crop_tr) + model(F.torch_fliplr(crop_tr)) + \
+             model(crop_bl) + model(F.torch_fliplr(crop_bl)) + \
+             model(crop_br) + model(F.torch_fliplr(crop_br)) + \
+             model(crop_cc) + model(F.torch_fliplr(crop_cc))
+
+    one_over_10 = float(1.0 / 10.0)
+    return output * one_over_10
+
+
 def fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
     """Test-time augmentation for image segmentation that averages predictions
     for input image and vertically flipped one.
@@ -80,7 +129,7 @@ def fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
     :param image: Model input.
     :return: Arithmetically averaged predictions
     """
-    output = model(image) + F.torch_fliplp(model(F.torch_fliplp(image)))
+    output = model(image) + F.torch_fliplr(model(F.torch_fliplr(image)))
     one_over_2 = float(1.0 / 2.0)
     return output * one_over_2

@@ -129,7 +178,7 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:

     for aug, deaug in zip([F.torch_none, F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90]):
         x = deaug(model(aug(image)))
-        output = output + x
+        output = output + F.torch_transpose(x)

     one_over_8 = float(1.0 / 8.0)
     return output * one_over_8
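
Usage note: the new tencrop_image2label follows the same call pattern as the other TTA helpers in this module. Below is a minimal sketch, assuming the module is importable as pytorch_toolbelt.inference.tta and the model is an ordinary NCHW image classifier; the toy model, input shapes, and crop size are illustrative only and not part of this commit.

    import torch
    from torch import nn
    from pytorch_toolbelt.inference.tta import tencrop_image2label  # assumed module path

    # Hypothetical classifier: any module mapping NCHW tensors to logits will do.
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    model.eval()

    image = torch.rand(4, 3, 224, 224)  # batch of 4 RGB images

    with torch.no_grad():
        # Logits averaged over 4 corner crops + central crop and their horizontal flips (10 forward passes).
        logits = tencrop_image2label(model, image, crop_size=(192, 192))

    print(logits.shape)  # torch.Size([4, 10])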