Skip to content

Commit 89c8be3

Browse files
committed
Merge branch 'develop'
2 parents 2b535b7 + 2db8629 commit 89c8be3

File tree

8 files changed

+155
-29
lines changed

8 files changed

+155
-29
lines changed

README.md

Lines changed: 100 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,111 @@ A `pytorch-toolbelt` is a Python library with a set of bells and whistles for Py
88

99
* Easy model building using flexible encoder-decoder architecture.
1010
* Modules: CoordConv, SCSE, Hypercolumn, Depthwise separable convolution and more
11-
* GPU-friendly test-time augmentation
11+
* GPU-friendly test-time augmentation TTA for segmentation and classification
1212
* GPU-friendly inference on huge (5000x5000) images
13-
* Every-day common routines (fix/restore random seed, filesystem utils)
14-
* Fancy losses: Focal, Lovasz, Jaccard and Dice losses
13+
* Every-day common routines (fix/restore random seed, filesystem utils, metrics)
14+
* Fancy losses: Focal, Lovasz, Jaccard and Dice losses, Wing Loss
1515

16-
# Quick start
16+
# Why
1717

18-
`TODO: Implement`
18+
Honest answer is "I needed a convenient way to re-use code for my Kaggle career".
19+
During 2018 I achieved a [Kaggle Master](https://www.kaggle.com/bloodaxe) badge and this been a long path.
20+
Very often I found myself re-using most of the old pipelines over and over again.
21+
At some point it crystallized into this repository.
22+
23+
This lib is not meant to replace catalyst / ignite / fast.ai. Instead it's designed to complement them.
1924

2025
# Installation
2126

22-
`TODO: Implement`
27+
`pip install pytorch_toolbelt`
28+
29+
# Showcase
30+
31+
## Encoder-decoder models construction
32+
33+
```python
34+
from pytorch_toolbelt.modules import encoders as E
35+
from pytorch_toolbelt.modules import decoders as D
36+
37+
class FPNSegmentationModel(nn.Module):
38+
def __init__(self, encoder:E.EncoderModule, num_classes, fpn_features=128):
39+
self.encoder = encoder
40+
self.decoder = D.FPNDecoder(encoder.output_filters, fpn_features=fpn_features)
41+
self.fuse = D.FPNFuse()
42+
input_channels = sum(self.decoder.output_filters)
43+
self.logits = nn.Conv2d(input_channels, num_classes,kernel_size=1)
44+
45+
def forward(self, input):
46+
features = self.encoder(input)
47+
features = self.decoder(features)
48+
features = self.fuse(features)
49+
logits = self.logits(features)
50+
return logits
51+
52+
def fpn_resnext50(num_classes):
53+
encoder = E.SEResNeXt50Encoder()
54+
return FPNSegmentationModel(encoder, num_classes)
55+
56+
def fpn_mobilenet(num_classes):
57+
encoder = E.MobilenetV2Encoder()
58+
return FPNSegmentationModel(encoder, num_classes)
59+
```
60+
61+
## Compose multiple losses
62+
63+
```python
64+
from pytorch_toolbelt import losses as L
65+
66+
loss = L.JointLoss(L.FocalLoss(), 1.0, L.LovaszLoss(), 0.5)
67+
```
68+
69+
## Test-time augmentation
70+
71+
```python
72+
from pytorch_toolbelt.inference import tta
73+
74+
# Truly functional TTA for image classification using horizontal flips:
75+
logits = tta.fliplr_image2label(model, input)
76+
77+
# Truly functional TTA for image segmentation using D4 augmentation:
78+
logits = tta.d4_image2mask(model, input)
79+
80+
# TTA using wrapper module:
81+
tta_model = tta.TTAWrapper(model, tta.fivecrop_image2label, crop_size=512)
82+
logits = tta_model(input)
83+
```
84+
85+
## Inference on huge images:
86+
87+
```python
88+
import numpy as np
89+
import torch
90+
import cv2
91+
92+
from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
93+
from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, to_numpy
94+
95+
96+
image = cv2.imread('really_huge_image.jpg')
97+
model = get_model(...)
98+
99+
# Cut large image into overlapping tiles
100+
tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight='pyramid')
101+
102+
# HCW -> CHW. Optionally, do normalization here
103+
tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]
104+
105+
# Allocate a CUDA buffer for holding entire mask
106+
merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)
107+
108+
# Run predictions for tiles and accumulate them
109+
for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
110+
tiles_batch = tiles_batch.float().cuda()
111+
pred_batch = model(tiles_batch)
23112

24-
# Documentation
113+
merger.integrate_batch(pred_batch, coords_batch)
25114

26-
`TODO: Implement`
115+
# Normalize accumulated mask and convert back to numpy
116+
merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
117+
merged_mask = tiler.crop_to_orignal_size(merged_mask)
118+
```

examples/segmentation-inria/models/factory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tqdm import tqdm
1313

1414
from pytorch_toolbelt.inference.tiles import CudaTileMerger, ImageSlicer
15-
from pytorch_toolbelt.inference.tta import tta_fliplr_image2mask, tta_d4_image2mask
15+
from pytorch_toolbelt.inference.tta import fliplr_image2mask, d4_image2mask
1616
from pytorch_toolbelt.losses.focal import BinaryFocalLoss
1717
from pytorch_toolbelt.losses.jaccard import BinaryJaccardLogLoss
1818
from pytorch_toolbelt.losses.lovasz import BinaryLovaszLoss
@@ -94,7 +94,7 @@ def __init__(self, model):
9494
self.model = model
9595

9696
def forward(self, x):
97-
return tta_d4_image2mask(self.model, x)
97+
return d4_image2mask(self.model, x)
9898

9999

100100
class TTAWrapperD4(nn.Module):
@@ -103,7 +103,7 @@ def __init__(self, model):
103103
self.model = model
104104

105105
def forward(self, x):
106-
return tta_fliplr_image2mask(self.model, x)
106+
return fliplr_image2mask(self.model, x)
107107

108108

109109
def predict(model: nn.Module, image: np.ndarray, image_size, tta=None, normalize=A.Normalize(), batch_size=1, activation='sigmoid') -> np.ndarray:

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = '0.0.3'
3+
__version__ = '0.0.4'

pytorch_toolbelt/inference/tiles.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,16 @@ def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight='
119119
self.margin_top = image_margin
120120
self.margin_bottom = image_margin
121121

122-
self.crops = []
123-
self.bbox_crops = []
122+
crops = []
123+
bbox_crops = []
124124

125125
for y in range(0, self.image_height + self.margin_top + self.margin_bottom - self.tile_size[0] + 1, self.tile_step[0]):
126126
for x in range(0, self.image_width + self.margin_left + self.margin_right - self.tile_size[1] + 1, self.tile_step[1]):
127-
self.crops.append((x, y, self.tile_size[1], self.tile_size[0]))
128-
self.bbox_crops.append((x - self.margin_left, y - self.margin_top, self.tile_size[1], self.tile_size[0]))
127+
crops.append((x, y, self.tile_size[1], self.tile_size[0]))
128+
bbox_crops.append((x - self.margin_left, y - self.margin_top, self.tile_size[1], self.tile_size[0]))
129+
130+
self.crops = np.array(crops)
131+
self.bbox_crops = np.array(bbox_crops)
129132

130133
def split(self, image, border_type=cv2.BORDER_CONSTANT, value=0):
131134
assert image.shape[0] == self.image_height

pytorch_toolbelt/inference/tta.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
Despite this is called test-time augmentation, these method can be used at training time as well since all
44
transformation written in PyTorch and respect gradients flow.
55
"""
6+
from functools import partial
67
from typing import Tuple
78

89
from torch import Tensor, nn
910
from . import functional as F
1011

12+
__all__ = ['d4_image2label', 'd4_image2mask', 'fivecrop_image2label', 'fliplr_image2mask',
13+
'fliplr_image2label', 'TTAWrapper']
1114

12-
def tta_fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
15+
16+
def fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
1317
"""Test-time augmentation for image classification that averages predictions
1418
for input image and vertically flipped one.
1519
@@ -22,7 +26,7 @@ def tta_fliplr_image2label(model: nn.Module, image: Tensor) -> Tensor:
2226
return output * one_over_2
2327

2428

25-
def tta_fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
29+
def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
2630
"""Test-time augmentation for image classification that takes five crops out of input tensor (4 on corners and central)
2731
and averages predictions from them.
2832
@@ -66,7 +70,7 @@ def tta_fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple)
6670
return output * one_over_5
6771

6872

69-
def tta_fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
73+
def fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
7074
"""Test-time augmentation for image segmentation that averages predictions
7175
for input image and vertically flipped one.
7276
@@ -81,7 +85,7 @@ def tta_fliplr_image2mask(model: nn.Module, image: Tensor) -> Tensor:
8185
return output * one_over_2
8286

8387

84-
def tta_d4_image2label(model: nn.Module, image: Tensor) -> Tensor:
88+
def d4_image2label(model: nn.Module, image: Tensor) -> Tensor:
8589
"""Test-time augmentation for image classification that averages predictions
8690
of all D4 augmentations applied to input image.
8791
@@ -105,7 +109,7 @@ def tta_d4_image2label(model: nn.Module, image: Tensor) -> Tensor:
105109
return output * one_over_8
106110

107111

108-
def tta_d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
112+
def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
109113
"""Test-time augmentation for image classification that averages predictions
110114
of all D4 augmentations applied to input image.
111115
@@ -129,3 +133,13 @@ def tta_d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
129133

130134
one_over_8 = float(1.0 / 8.0)
131135
return output * one_over_8
136+
137+
138+
class TTAWrapper(nn.Module):
139+
def __init__(self, model, tta_function, **kwargs):
140+
super().__init__()
141+
self.model = model
142+
self.tta = partial(tta_function, **kwargs)
143+
144+
def forward(self, *input):
145+
return self.tta(self.model, *input)

pytorch_toolbelt/modules/decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def forward(self, features):
5454

5555
class FPNDecoder(DecoderModule):
5656
def __init__(self, features,
57-
prediction_block: nn.Module,
57+
prediction_block=FPNPredictionBlock,
5858
bottleneck=FPNBottleneckBlock,
5959
fpn_features=128,
6060
prediction_features=128,

pytorch_toolbelt/optimization/functional.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
def get_lr_decay_parameters(parameters, learning_rate, groups: dict):
2-
custom_lr_parameters = dict((group_name, {'params': [], 'lr': learning_rate * lr_factor}) for (group_name, lr_factor) in groups.items())
2+
custom_lr_parameters = dict((group_name, {'params': [], 'lr': learning_rate * lr_factor})
3+
for (group_name, lr_factor) in groups.items())
34
custom_lr_parameters['default'] = {'params': [], 'lr': learning_rate}
45

56
for parameter_name, parameter in parameters:

tests/test_tiles.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22
import torch
3+
from torch import nn
4+
from torch.utils.data import DataLoader
35

46
from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
57
from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, rgb_image_from_tensor, to_numpy
@@ -24,19 +26,33 @@ def test_tiles_split_merge_2():
2426
np.testing.assert_equal(merged, image)
2527

2628

29+
@torch.no_grad()
2730
def test_tiles_split_merge_cuda():
2831
if not torch.cuda.is_available():
2932
return
33+
34+
class MaxChannelIntensity(nn.Module):
35+
def __init__(self):
36+
super().__init__()
37+
38+
def forward(self, input):
39+
max_channel, _ = torch.max(input, dim=1, keepdim=True)
40+
return max_channel
41+
3042
image = np.random.random((5000, 5000, 3)).astype(np.uint8)
3143
tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight='pyramid')
32-
tiles = tiler.split(image)
44+
tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]
3345

34-
merger = CudaTileMerger(tiler.target_shape, 3, tiler.weight)
35-
for tile, coord in zip(tiles, tiler.crops):
36-
batch = tensor_from_rgb_image(tile).unsqueeze(0).float().cuda()
37-
merger.integrate_batch(batch, [coord])
46+
model = MaxChannelIntensity().eval().cuda()
47+
48+
merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)
49+
for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
50+
tiles_batch = tiles_batch.float().cuda()
51+
pred_batch = model(tiles_batch)
52+
53+
merger.integrate_batch(pred_batch, coords_batch)
3854

3955
merged = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
4056
merged = tiler.crop_to_orignal_size(merged)
4157

42-
np.testing.assert_equal(merged, image)
58+
np.testing.assert_equal(merged, image.max(axis=2, keepdims=True))

0 commit comments

Comments
 (0)