
Commit b8796b8

Merge pull request #39 from BloodAxe/develop
Release 0.3.1
2 parents: 7320185 + b6ffa43

9 files changed (+291, -52 lines)


pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = "0.3.0"
+__version__ = "0.3.1"
```

pytorch_toolbelt/inference/tta.py

Lines changed: 24 additions & 12 deletions
```diff
@@ -25,7 +25,7 @@

 def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
     """Test-time augmentation for image classification that averages predictions
-    for input image and vertically flipped one.
+    for input image and horizontally flipped one.

     :param model:
     :param image:
@@ -176,7 +176,7 @@ def d4_image2label(model: nn.Module, image: Tensor) -> Tensor:


 def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
-    """Test-time augmentation for image classification that averages predictions
+    """Test-time augmentation for image segmentation that averages predictions
     of all D4 augmentations applied to input image.

     For segmentation we need to reverse the augmentation after making a prediction
```
```diff
@@ -222,7 +222,7 @@ class MultiscaleTTAWrapper(nn.Module):
     Multiscale TTA wrapper module
     """

-    def __init__(self, model: nn.Module, scale_levels: List[float]):
+    def __init__(self, model: nn.Module, scale_levels: List[float] = None, size_offsets: List[int] = None):
         """
         Initialize multi-scale TTA wrapper

@@ -231,9 +231,11 @@ def __init__(self, model: nn.Module, scale_levels: List[float]):
         e.g: [0.5, 0.75, 1.25]
         """
         super().__init__()
-        assert len(scale_levels)
+        assert scale_levels or size_offsets, "Either scale_levels or size_offsets must be set"
+        assert not (scale_levels and size_offsets), "Only one of scale_levels or size_offsets may be set"
         self.model = model
         self.scale_levels = scale_levels
+        self.size_offsets = size_offsets

     def forward(self, input: Tensor) -> Tensor:
         h = input.size(2)
@@ -242,11 +244,21 @@ def forward(self, input: Tensor) -> Tensor:
         out_size = h, w
         output = self.model(input)

-        for scale in self.scale_levels:
-            dst_size = int(h * scale), int(w * scale)
-            input_scaled = interpolate(input, dst_size, mode="bilinear", align_corners=False)
-            output_scaled = self.model(input_scaled)
-            output_scaled = interpolate(output_scaled, out_size, mode="bilinear", align_corners=False)
-            output += output_scaled
-
-        return output / (1 + len(self.scale_levels))
+        if self.scale_levels:
+            for scale in self.scale_levels:
+                dst_size = int(h * scale), int(w * scale)
+                input_scaled = interpolate(input, dst_size, mode="bilinear", align_corners=False)
+                output_scaled = self.model(input_scaled)
+                output_scaled = interpolate(output_scaled, out_size, mode="bilinear", align_corners=False)
+                output += output_scaled
+            output /= 1.0 + len(self.scale_levels)
+        elif self.size_offsets:
+            for offset in self.size_offsets:
+                dst_size = int(h + offset), int(w + offset)
+                input_scaled = interpolate(input, dst_size, mode="bilinear", align_corners=False)
+                output_scaled = self.model(input_scaled)
+                output_scaled = interpolate(output_scaled, out_size, mode="bilinear", align_corners=False)
+                output += output_scaled
+            output /= 1.0 + len(self.size_offsets)
+
+        return output
```
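The wrapper now supports two mutually exclusive TTA modes: relative `scale_levels` and the new absolute `size_offsets`. A minimal usage sketch (the convolutional stand-in model below is illustrative, not part of this commit):

```python
import torch
from torch import nn

from pytorch_toolbelt.inference.tta import MultiscaleTTAWrapper

# Illustrative stand-in: any fully convolutional model that preserves
# spatial size and tolerates varying input resolutions will do.
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)

# Relative-scale mode: also run the model at 0.5x, 0.75x and 1.25x
# of the input resolution and average all predictions.
tta_scaled = MultiscaleTTAWrapper(model, scale_levels=[0.5, 0.75, 1.25])

# New absolute-offset mode: grow/shrink H and W by a fixed pixel count.
tta_offset = MultiscaleTTAWrapper(model, size_offsets=[-32, 32])

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    y = tta_offset(x)  # averaged prediction, same spatial size as x
```

Passing both arguments, or neither, now fails fast with an assertion error instead of silently misbehaving.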

pytorch_toolbelt/losses/soft_ce.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,7 +6,7 @@


 class SoftCrossEntropyLoss(nn.Module):
-    def __init__(self, smooth_factor=1e-4, ignore_index=None):
+    def __init__(self, smooth_factor=1e-4, ignore_index: int = -100):
         super().__init__()
         self.smooth_factor = smooth_factor
         self.ignore_index = ignore_index
```
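The new default of `-100` matches `torch.nn.CrossEntropyLoss`, so targets padded with `-100` are ignored without extra configuration. A brief sketch, assuming the loss is re-exported from `pytorch_toolbelt.losses` like the package's other losses:

```python
from pytorch_toolbelt.losses import SoftCrossEntropyLoss

# Default now mirrors torch.nn.CrossEntropyLoss (ignore_index=-100),
# so -100-padded targets are excluded from the loss out of the box.
criterion = SoftCrossEntropyLoss(smooth_factor=1e-4)

# An explicit index still works, e.g. 255 for unlabeled pixels in masks.
criterion_masked = SoftCrossEntropyLoss(smooth_factor=1e-4, ignore_index=255)
```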

pytorch_toolbelt/modules/decoders/hrnet.py

Lines changed: 27 additions & 7 deletions
```diff
@@ -1,19 +1,29 @@
 from collections import OrderedDict
 from typing import List

+import torch
+import torch.nn.functional as F
 from torch import nn, Tensor

-from .common import DecoderModule
+from .common import SegmentationDecoderModule

-__all__ = ["HRNetDecoder"]
+__all__ = ["HRNetSegmentationDecoder"]


-class HRNetDecoder(DecoderModule):
-    def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0):
+class HRNetSegmentationDecoder(SegmentationDecoderModule):
+    def __init__(
+        self,
+        feature_maps: List[int],
+        output_channels: int,
+        dropout=0.0,
+        interpolation_mode="nearest",
+        align_corners=None,
+    ):
         super().__init__()
+        self.interpolation_mode = interpolation_mode
+        self.align_corners = align_corners

-        features = feature_maps[-1]
-
+        features = sum(feature_maps)
         self.embedding = nn.Sequential(
             OrderedDict(
                 [
@@ -37,5 +47,15 @@ def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0):
         )

     def forward(self, features: List[Tensor]):
-        embedding = self.embedding(features[-1])
+        x_size = features[0].size()[2:]
+
+        resized_feature_maps = [features[0]]
+        for feature_map in features[1:]:
+            feature_map = F.interpolate(
+                feature_map, size=x_size, mode=self.interpolation_mode, align_corners=self.align_corners
+            )
+            resized_feature_maps.append(feature_map)
+
+        feature_map = torch.cat(resized_feature_maps, dim=1)
+        embedding = self.embedding(feature_map)
         return self.logits(embedding)
```
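Instead of decoding only the last branch, the decoder now upsamples every HRNet branch to the finest resolution and concatenates them, which is why the embedding width changes from `feature_maps[-1]` to `sum(feature_maps)`. A standalone sketch of that fusion, using hypothetical HRNetV2-W18-style channel counts:

```python
import torch
import torch.nn.functional as F

# Dummy four-branch feature pyramid (channel counts are illustrative).
features = [
    torch.rand(1, 18, 64, 64),
    torch.rand(1, 36, 32, 32),
    torch.rand(1, 72, 16, 16),
    torch.rand(1, 144, 8, 8),
]

# Upsample every branch to the finest branch's resolution and concatenate
# along channels: the embedding then sees 18 + 36 + 72 + 144 = 270 channels.
x_size = features[0].size()[2:]
resized = [features[0]] + [
    F.interpolate(f, size=x_size, mode="nearest") for f in features[1:]
]
fused = torch.cat(resized, dim=1)
print(fused.shape)  # torch.Size([1, 270, 64, 64])
```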

pytorch_toolbelt/modules/encoders/unet.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -16,16 +16,15 @@ def __init__(self, input_channels=3, features=32, num_layers=4, growth_factor=2,
         super().__init__(feature_maps, strides, layers=list(range(num_layers)))

         input_filters = input_channels
-        output_filters = feature_maps[0]
         self.num_layers = num_layers
         for layer in range(num_layers):
-            block = UnetEncoderBlock(input_filters, output_filters, abn_block=abn_block)
-
+            block = UnetEncoderBlock(input_filters, feature_maps[layer], abn_block=abn_block)
+            input_filters = feature_maps[layer]
             self.add_module(f"layer{layer}", block)

     @property
     def encoder_layers(self):
-        return [self[f"layer{layer}"] for layer in range(self.num_layers)]
+        return [self.__getattr__(f"layer{layer}") for layer in range(self.num_layers)]

     def change_input_channels(self, input_channels: int, mode="auto"):
         self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode)
```
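The fix chains the blocks' channels correctly: previously `input_filters` was never advanced, so every block after the first was built with the wrong input width. A plain-Python trace of the corrected wiring, assuming the hypothetical `feature_maps` produced by the defaults `features=32, num_layers=4, growth_factor=2`:

```python
# Hypothetical feature_maps for features=32, num_layers=4, growth_factor=2.
feature_maps = [32, 64, 128, 256]

input_filters = 3  # RGB input
for layer, out_filters in enumerate(feature_maps):
    print(f"layer{layer}: {input_filters} -> {out_filters} channels")
    input_filters = out_filters  # next block consumes this block's output
# layer0: 3 -> 32, layer1: 32 -> 64, layer2: 64 -> 128, layer3: 128 -> 256
```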

pytorch_toolbelt/utils/catalyst/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -4,3 +4,4 @@
 from .metrics import *
 from .opl import *
 from .visualization import *
+from .utils import *
```

pytorch_toolbelt/utils/catalyst/metrics.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -227,7 +227,7 @@ def binary_dice_iou_score(
     if mode == "dice":
         score = (2.0 * intersection) / (cardinality + eps)
     else:
-        score = intersection / (cardinality + eps)
+        score = intersection / (cardinality - intersection + eps)

     has_targets = torch.sum(y_true) > 0
     has_predicted = torch.sum(y_pred) > 0
```
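The old IoU branch divided by |A| + |B| instead of |A ∪ B| = |A| + |B| − |A ∩ B|. A small worked check, assuming `cardinality` is the sum of both masks' totals as in the dice branch:

```python
import torch

y_pred = torch.tensor([1.0, 1.0, 0.0, 0.0])
y_true = torch.tensor([1.0, 0.0, 1.0, 0.0])

eps = 1e-7
intersection = torch.sum(y_pred * y_true)            # |A ∩ B| = 1
cardinality = torch.sum(y_pred) + torch.sum(y_true)  # |A| + |B| = 4

dice = (2.0 * intersection) / (cardinality + eps)        # 2/4 = 0.50
iou = intersection / (cardinality - intersection + eps)  # 1/3 ≈ 0.33

# The old code returned intersection / cardinality = 0.25 here,
# which is not the Jaccard index.
```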
pytorch_toolbelt/utils/catalyst/utils.py

Lines changed: 45 additions & 0 deletions (new file)

```diff
@@ -0,0 +1,45 @@
+from typing import Dict
+
+import torch
+
+__all__ = ["clean_checkpoint", "report_checkpoint"]
+
+
+def clean_checkpoint(src_fname, dst_fname):
+    """
+    Remove optimizer, scheduler and criterion states from checkpoint
+    :param src_fname: Source checkpoint filename
+    :param dst_fname: Target checkpoint filename (can be the same)
+    """
+    checkpoint = torch.load(src_fname, map_location="cpu")
+
+    keys = ["criterion_state_dict", "optimizer_state_dict", "scheduler_state_dict"]
+
+    for key in keys:
+        if key in checkpoint:
+            del checkpoint[key]
+
+    torch.save(checkpoint, dst_fname)
+
+
+def report_checkpoint(checkpoint: Dict):
+    """
+    Print checkpoint metrics and epoch number
+    :param checkpoint:
+    """
+    print("Epoch :", checkpoint["epoch"])
+
+    skip_fields = [
+        "_base/lr",
+        "_base/momentum",
+        "_timers/data_time",
+        "_timers/model_time",
+        "_timers/batch_time",
+        "_timers/_fps",
+    ]
+    print(
+        "Metrics (Train):", [(k, v) for k, v in checkpoint["epoch_metrics"]["train"].items() if k not in skip_fields]
+    )
+    print(
+        "Metrics (Valid):", [(k, v) for k, v in checkpoint["epoch_metrics"]["valid"].items() if k not in skip_fields]
+    )
```
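A short usage sketch for the two new helpers (the checkpoint paths are illustrative):

```python
import torch

from pytorch_toolbelt.utils.catalyst import clean_checkpoint, report_checkpoint

# Strip optimizer/scheduler/criterion state to shrink a training
# checkpoint before shipping it for inference.
clean_checkpoint("checkpoints/best.pth", "checkpoints/best_slim.pth")

# Inspect what the checkpoint achieved, minus timer/lr noise.
checkpoint = torch.load("checkpoints/best_slim.pth", map_location="cpu")
report_checkpoint(checkpoint)  # prints epoch and train/valid metrics
```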
