
Commit e41b4fe

Improving ensembling
1 parent 1fc6e38 commit e41b4fe

File tree

2 files changed: +29 -53 lines changed


pytorch_toolbelt/inference/ensembling.py

Lines changed: 16 additions & 18 deletions
@@ -1,8 +1,11 @@
+import torch
 from torch import nn, Tensor
-from typing import List, Union
+from typing import List, Union, Iterable, Optional

 __all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput"]

+from pytorch_toolbelt.inference.tta import _deaugment_averaging
+

 class ApplySoftmaxTo(nn.Module):
     def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", dim=1, temperature=1):
@@ -55,40 +58,35 @@ class Ensembler(nn.Module):
     Compute sum (or average) of outputs of several models.
     """

-    def __init__(self, models: List[nn.Module], average=True, outputs=None):
+    def __init__(self, models: List[nn.Module], reduction: str = "mean", outputs: Optional[Iterable[str]] = None):
         """

         :param models:
-        :param average:
+        :param reduction: Reduction key ('mean', 'sum', 'gmean', 'hmean' or None)
         :param outputs: Name of model outputs to average and return from Ensembler.
             If None, all outputs from the first model will be used.
         """
         super().__init__()
         self.outputs = outputs
         self.models = nn.ModuleList(models)
-        self.average = average
+        self.reduction = reduction

     def forward(self, *input, **kwargs):  # skipcq: PYL-W0221
-        output_0 = self.models[0](*input, **kwargs)
-        num_models = len(self.models)
+        outputs = [model(*input, **kwargs) for model in self.models]

         if self.outputs:
             keys = self.outputs
         else:
-            keys = output_0.keys()
-
-        for index in range(1, num_models):
-            output_i = self.models[index](*input, **kwargs)
-
-            # Sum outputs
-            for key in keys:
-                output_0[key].add_(output_i[key])
+            keys = outputs[0].keys()

-        if self.average:
-            for key in keys:
-                output_0[key].mul_(1.0 / num_models)
+        averaged_output = {}
+        for key in keys:
+            predictions = [output[key] for output in outputs]
+            predictions = torch.stack(predictions)
+            predictions = _deaugment_averaging(predictions, self.reduction)
+            averaged_output[key] = predictions

-        return output_0
+        return averaged_output


 class PickModelOutput(nn.Module):
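A usage sketch of the reworked Ensembler (editorial, not part of the commit): the dict-returning toy model and the tensor shapes are illustrative assumptions, while the constructor signature and reduction keys follow the diff above.

import torch
from torch import nn
from pytorch_toolbelt.inference.ensembling import Ensembler


class ToyModel(nn.Module):
    """Toy model that returns a dict output, which Ensembler expects."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return {"logits": self.fc(x)}


# reduction accepts 'mean', 'sum', 'gmean', 'hmean' or None (forwarded to _deaugment_averaging)
ensemble = Ensembler([ToyModel(), ToyModel()], reduction="mean")

x = torch.randn(4, 8)
out = ensemble(x)
print(out["logits"].shape)  # torch.Size([4, 3]); with reduction=None the model axis (dim 0) is kept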

pytorch_toolbelt/inference/tta.py

Lines changed: 13 additions & 35 deletions
@@ -48,7 +48,7 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
     Helper method to average predictions of TTA-ed model.
     This function assumes TTA dimension is 0, e.g [T, B, C, Ci, Cj, ..]
     Args:
-        x:
+        x: Input tensor of shape [T, B, ... ]
         reduction: Reduction mode ("sum", "mean", "gmean", "hmean", function, None)

     Returns:
@@ -64,6 +64,11 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
         x = F.harmonic_mean(x, dim=0)
     elif callable(reduction):
         x = reduction(x, dim=0)
+    elif reduction in {None, "None", "none"}:
+        pass
+    else:
+        raise KeyError(f"Unsupported reduction mode {reduction}")
+
     return x

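For reference, a standalone sketch of what each reduction mode computes along the TTA dimension (dim 0), written with plain torch ops rather than the library's F.geometric_mean / F.harmonic_mean helpers; the epsilon clamp is an added assumption to keep the geometric and harmonic means defined.

import torch

preds = torch.rand(4, 2, 3)  # [T, B, C]: e.g. probabilities from 4 TTA passes
eps = 1e-8                   # assumed epsilon; the library helpers may handle this differently

mean_red = preds.mean(dim=0)                                # "mean"
sum_red = preds.sum(dim=0)                                  # "sum"
gmean_red = preds.clamp_min(eps).log().mean(dim=0).exp()    # "gmean" (positive inputs)
hmean_red = 1.0 / (1.0 / preds.clamp_min(eps)).mean(dim=0)  # "hmean" (positive inputs)
identity = preds                                            # None / "none": TTA axis is kept

for reduced in (mean_red, sum_red, gmean_red, hmean_red):
    assert reduced.shape == (2, 3)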
@@ -94,10 +99,7 @@ def fivecrop_image_augment(image: Tensor, crop_size: Tuple[int, int]) -> Tensor:
     center_crop_x = (image_width - crop_width) // 2
     crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width]

-    return torch.cat(
-        [crop_tl, crop_tr, crop_bl, crop_br, crop_cc],
-        dim=0,
-    )
+    return torch.cat([crop_tl, crop_tr, crop_bl, crop_br, crop_cc], dim=0,)


 def fivecrop_label_deaugment(logits: Tensor, reduction: MaybeStrOrCallable = "mean") -> Tensor:
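A hedged sketch (not from the commit) of the five-crop TTA pairing: fivecrop_image_augment stacks the four corner crops plus the center crop along the batch dimension, and fivecrop_label_deaugment reduces back over the five crops. The toy pooling classifier and the expected output shape are assumptions.

import torch
from torch import nn
from pytorch_toolbelt.inference import tta


class GlobalPoolClassifier(nn.Module):
    """Toy classifier: global-average-pool the feature map, then a linear head."""

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.head = nn.Linear(3, num_classes)

    def forward(self, x):
        return self.head(x.mean(dim=(2, 3)))


model = GlobalPoolClassifier()
images = torch.randn(2, 3, 64, 64)

crops = tta.fivecrop_image_augment(images, crop_size=(48, 48))  # [5 * B, C, 48, 48]
logits = model(crops)                                           # [5 * B, num_classes]
deaug = tta.fivecrop_label_deaugment(logits, reduction="mean")  # expected [B, num_classes]
print(deaug.shape)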
@@ -275,15 +277,7 @@ def d2_image_augment(image: Tensor) -> Tensor:
         - Vertically-flipped tensor

     """
-    return torch.cat(
-        [
-            image,
-            F.torch_rot180(image),
-            F.torch_fliplr(image),
-            F.torch_flipud(image),
-        ],
-        dim=0,
-    )
+    return torch.cat([image, F.torch_rot180(image), F.torch_fliplr(image), F.torch_flipud(image),], dim=0,)


 def d2_image_deaugment(image: Tensor, reduction: MaybeStrOrCallable = "mean") -> Tensor:
@@ -302,12 +296,7 @@ def d2_image_deaugment(image: Tensor, reduction: MaybeStrOrCallable = "mean") -> Tensor:
     b1, b2, b3, b4 = torch.chunk(image, 4)

     image: Tensor = torch.stack(
-        [
-            b1,
-            F.torch_rot180(b2),
-            F.torch_fliplr(b3),
-            F.torch_flipud(b4),
-        ]
+        [b1, F.torch_rot180(b2), F.torch_fliplr(b3), F.torch_flipud(b4),]
     )

     return _deaugment_averaging(image, reduction=reduction)
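Similarly, a hedged round-trip sketch for the D2 helpers (identity, 180-degree rotation, horizontal and vertical flip) with a shape-preserving toy predictor; the conv layer is a stand-in for a real dense (e.g. segmentation) model.

import torch
from torch import nn
from pytorch_toolbelt.inference import tta

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in for a dense prediction model
images = torch.randn(2, 3, 32, 32)

augmented = tta.d2_image_augment(images)                             # [4 * B, 3, 32, 32]
predictions = model(augmented)                                       # [4 * B, 1, 32, 32]
deaugmented = tta.d2_image_deaugment(predictions, reduction="mean")  # [B, 1, 32, 32]
print(deaugmented.shape)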
@@ -440,10 +429,7 @@ def flips_image_augment(image: Tensor) -> Tensor:
     return torch.cat([image, F.torch_fliplr(image), F.torch_flipud(image)], dim=0)


-def flips_image_deaugment(
-    image: Tensor,
-    reduction: MaybeStrOrCallable = "mean",
-) -> Tensor:
+def flips_image_deaugment(image: Tensor, reduction: MaybeStrOrCallable = "mean",) -> Tensor:
     """
     Deaugment input tensor (output of the model) assuming the input was flip-augmented image (See flips_augment).
     Args:
@@ -464,10 +450,7 @@ def flips_image_deaugment(
     return _deaugment_averaging(image, reduction=reduction)


-def fliplr_labels_deaugment(
-    logits: Tensor,
-    reduction: MaybeStrOrCallable = "mean",
-) -> Tensor:
+def fliplr_labels_deaugment(logits: Tensor, reduction: MaybeStrOrCallable = "mean",) -> Tensor:
     """
     Deaugment input tensor (output of the model) assuming the input was fliplr-augmented image (See fliplr_image_augment).
     Args:
@@ -485,10 +468,7 @@ def fliplr_labels_deaugment(
     return _deaugment_averaging(logits, reduction=reduction)


-def flips_labels_deaugment(
-    logits: Tensor,
-    reduction: MaybeStrOrCallable = "mean",
-) -> Tensor:
+def flips_labels_deaugment(logits: Tensor, reduction: MaybeStrOrCallable = "mean",) -> Tensor:
     """
     Deaugment input tensor (output of the model) assuming the input was flip-augmented image (See flips_image_augment).
     Args:
@@ -543,9 +523,7 @@ def ms_image_augment(


 def ms_labels_deaugment(
-    logits: List[Tensor],
-    size_offsets: List[Union[int, Tuple[int, int]]],
-    reduction: MaybeStrOrCallable = "mean",
+    logits: List[Tensor], size_offsets: List[Union[int, Tuple[int, int]]], reduction: MaybeStrOrCallable = "mean",
 ):
     """
     Deaugment logits
