
Commit 1054a4f

Merge pull request #43 from BloodAxe/develop
Pytorch-toolbelt 0.3.2
2 parents b8796b8 + 11693ca commit 1054a4f


71 files changed: +2467 / -1408 lines

.deepsource.toml

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+version = 1
+
+test_patterns = [
+  "tests/**",
+  "test_*.py"
+]
+
+[[analyzers]]
+name = "python"
+enabled = true
+
+[analyzers.meta]
+runtime_version = "3.x.x"

README.md

Lines changed: 117 additions & 31 deletions
@@ -2,7 +2,7 @@

[![Build Status](https://travis-ci.org/BloodAxe/pytorch-toolbelt.svg?branch=develop)](https://travis-ci.org/BloodAxe/pytorch-toolbelt)
[![Documentation Status](https://readthedocs.org/projects/pytorch-toolbelt/badge/?version=latest)](https://pytorch-toolbelt.readthedocs.io/en/latest/?badge=latest)
-
+[![DeepSource](https://static.deepsource.io/deepsource-badge-light-mini.svg)](https://deepsource.io/gh/BloodAxe/pytorch-toolbelt/?ref=repository-badge)

A `pytorch-toolbelt` is a Python library with a set of bells and whistles for PyTorch for fast R&D prototyping and Kaggle farming:

@@ -25,57 +25,135 @@ During 2018 I achieved a [Kaggle Master](https://www.kaggle.com/bloodaxe) badge
Very often I found myself re-using most of the old pipelines over and over again.
At some point it crystallized into this repository.

-This lib is not meant to replace catalyst / ignite / fast.ai. Instead it's designed to complement them.
+This lib is not meant to replace the catalyst / ignite / fast.ai high-level frameworks. Instead, it's designed to complement them.

# Installation

`pip install pytorch_toolbelt`

-# Showcase
+# How do I ...
+
+## Model creation

-## Encoder-decoder models construction
+### Create Encoder-Decoder U-Net model

+Below is a code snippet that creates a vanilla U-Net model for binary segmentation.
+By design, both the encoder and the decoder produce a list of tensors, ordered from fine (high-resolution, indexed `0`) to coarse (low-resolution) feature maps.
+Access to all intermediate feature maps is beneficial if you want to apply deep supervision losses on them, or to build an encoder-decoder model for an object detection task, where intermediate feature maps are required.
+
```python
+from torch import nn
from pytorch_toolbelt.modules import encoders as E
from pytorch_toolbelt.modules import decoders as D

-class FPNSegmentationModel(nn.Module):
-    def __init__(self, encoder: E.EncoderModule, num_classes, fpn_features=128):
-        self.encoder = encoder
-        self.decoder = D.FPNDecoder(encoder.output_filters, fpn_features=fpn_features)
-        self.fuse = D.FPNFuse()
-        input_channels = sum(self.decoder.output_filters)
-        self.logits = nn.Conv2d(input_channels, num_classes, kernel_size=1)
-
-    def forward(self, input):
-        features = self.encoder(input)
-        features = self.decoder(features)
-        features = self.fuse(features)
-        logits = self.logits(features)
-        return logits
-
-def fpn_resnext50(num_classes):
-    encoder = E.SEResNeXt50Encoder()
-    return FPNSegmentationModel(encoder, num_classes)
-
-def fpn_mobilenet(num_classes):
-    encoder = E.MobilenetV2Encoder()
-    return FPNSegmentationModel(encoder, num_classes)
+class UNet(nn.Module):
+    def __init__(self, input_channels, num_classes):
+        super().__init__()
+        self.encoder = E.UnetEncoder(in_channels=input_channels, out_channels=32, growth_factor=2)
+        self.decoder = D.UNetDecoder(self.encoder.channels, decoder_features=32)
+        self.logits = nn.Conv2d(self.decoder.channels[0], num_classes, kernel_size=1)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return self.logits(x[0])
```
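
For illustration, a minimal usage sketch of the model above, assuming a 3-channel input (shapes are illustrative):

```python
import torch

# Instantiate the U-Net above for RGB input and binary segmentation
model = UNet(input_channels=3, num_classes=1)

# Forward pass on a dummy batch; the output is a single-channel logit map
mask_logits = model(torch.rand(2, 3, 256, 256))
```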

-## Compose multiple losses
+### Create Encoder-Decoder FPN model with pretrained encoder
+
+Similarly to the previous example, you can change the decoder to an FPN with concatenation.
+
+```python
+from torch import nn
+from pytorch_toolbelt.modules import encoders as E
+from pytorch_toolbelt.modules import decoders as D
+
+class SEResNeXt50FPN(nn.Module):
+    def __init__(self, num_classes, fpn_channels):
+        super().__init__()
+        self.encoder = E.SEResNeXt50Encoder()
+        self.decoder = D.FPNCatDecoder(self.encoder.channels, fpn_channels)
+        self.logits = nn.Conv2d(self.decoder.channels[0], num_classes, kernel_size=1)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return self.logits(x[0])
+```
+
+### Change number of input channels for the Encoder
+
+All encoders from `pytorch_toolbelt` support changing the number of input channels. Simply call `encoder.change_input_channels(num_channels)` and the first convolution layer will be changed.
+Whenever possible, the existing weights of the convolutional layer will be re-used (if the new number of channels is greater than the default, the new weight tensor will be padded with randomly-initialized weights).
+The method returns `self`, so the call can be chained.
+
+```python
+from pytorch_toolbelt.modules import encoders as E
+
+encoder = E.SEResnet101Encoder()
+encoder = encoder.change_input_channels(6)
+```
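
Since `change_input_channels` returns `self`, the snippet above can also be written as a single chained call (a minimal sketch):

```python
from pytorch_toolbelt.modules import encoders as E

# Construction and input-channel patching in one chained expression
encoder = E.SEResnet101Encoder().change_input_channels(6)
```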
+
+## Misc
+
+### Count number of parameters in encoder/decoder and other modules
+
+When designing a model and optimizing the number of features in a neural network, I found it quite useful to print the number of parameters in high-level blocks (like `encoder` and `decoder`).
+Here is how to do it with `pytorch_toolbelt`:
+
+```python
+from torch import nn
+from pytorch_toolbelt.modules import encoders as E
+from pytorch_toolbelt.modules import decoders as D
+from pytorch_toolbelt.utils import count_parameters
+
+class SEResNeXt50FPN(nn.Module):
+    def __init__(self, num_classes, fpn_channels):
+        super().__init__()
+        self.encoder = E.SEResNeXt50Encoder()
+        self.decoder = D.FPNCatDecoder(self.encoder.channels, fpn_channels)
+        self.logits = nn.Conv2d(self.decoder.channels[0], num_classes, kernel_size=1)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return self.logits(x[0])
+
+net = SEResNeXt50FPN(1, 128)
+print(count_parameters(net))
+# Prints {'total': 34232561, 'trainable': 34232561, 'encoder': 25510896, 'decoder': 8721536, 'logits': 129}
+```
+
+### Compose multiple losses
+
+There are multiple ways to combine multiple losses, and high-level DL frameworks like Catalyst offer far more flexible ways to achieve this, but here is my 100%-pure PyTorch implementation:

```python
from pytorch_toolbelt import losses as L

+# Creates a loss function that is a weighted sum of focal loss
+# and lovasz loss with weights 1.0 and 0.5 accordingly.
loss = L.JointLoss(L.FocalLoss(), 1.0, L.LovaszLoss(), 0.5)
```

-## Test-time augmentation
+
+## TTA / Inferencing
+
+### Apply Test-time augmentation (TTA) for the model
+
+Test-time augmentation (TTA) can be used in both training and testing phases.

```python
from pytorch_toolbelt.inference import tta

+model = UNet(input_channels=3, num_classes=1)
+
# Truly functional TTA for image classification using horizontal flips:
logits = tta.fliplr_image2label(model, input)

@@ -87,11 +165,19 @@ tta_model = tta.TTAWrapper(model, tta.fivecrop_image2label, crop_size=512)
logits = tta_model(input)
```

-## Inference on huge images:
+### Inference on huge images
+
+Quite often, there is a need to perform image segmentation on an enormously large image (5000px and more). There are a few problems with such big pixel arrays:
+1. There are limits on the maximum size of CUDA tensors (concrete numbers depend on the driver and GPU version).
+2. Heavy CNN architectures may easily eat up all available GPU memory when inferencing relatively small 1024x1024 images, leaving no room for larger image resolutions.
+
+One of the solutions is to slice the input image into (optionally overlapping) tiles, feed each tile through the model, and merge the results back together.
+This way you can guarantee an upper limit on GPU RAM usage while keeping the ability to process arbitrary-sized images on the GPU.

```python
import numpy as np
-import torch
+from torch.utils.data import DataLoader
import cv2

from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
@@ -102,7 +188,7 @@ image = cv2.imread('really_huge_image.jpg')
model = get_model(...)

# Cut large image into overlapping tiles
-tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight='pyramid')
+tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256))

# HWC -> CHW. Optionally, do normalization here
tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]
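
The diff shows only the changed lines of this example. For context, a sketch of how such a tiled-inference pipeline typically continues, using the `DataLoader` and `CudaTileMerger` imported above; attribute and method names such as `tiler.crops`, `integrate_batch` and `merge` are assumptions to verify against this version of the library:

```python
# Accumulate per-tile predictions into a single full-size canvas on the GPU
merger = CudaTileMerger(tiler.target_shape, channels=1, weight=tiler.weight)

# Batch tiles together with their crop coordinates and run the model on each batch
for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
    pred_batch = model(tiles_batch.float().cuda())
    merger.integrate_batch(pred_batch, coords_batch)

# Normalize the accumulated predictions into the final full-resolution mask
merged_mask = merger.merge()
```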

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
from __future__ import absolute_import

-__version__ = "0.3.1"
+__version__ = "0.3.2"
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+from torch import nn, Tensor
+from typing import List, Union
+
+__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput"]
+
+
+class ApplySoftmaxTo(nn.Module):
+    def __init__(self, model, output_key: Union[str, List[str]] = "logits", dim=1):
+        super().__init__()
+        output_key = output_key if isinstance(output_key, (list, tuple)) else [output_key]
+        # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
+        self.output_keys = set(output_key)
+        self.model = model
+        self.dim = dim
+
+    def forward(self, input):
+        output = self.model(input)
+        for key in self.output_keys:
+            output[key] = output[key].softmax(dim=self.dim)
+        return output
+
+
+class ApplySigmoidTo(nn.Module):
+    def __init__(self, model, output_key: Union[str, List[str]] = "logits"):
+        super().__init__()
+        output_key = output_key if isinstance(output_key, (list, tuple)) else [output_key]
+        # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
+        self.output_keys = set(output_key)
+        self.model = model
+
+    def forward(self, input):  # skipcq: PYL-W0221
+        output = self.model(input)
+        for key in self.output_keys:
+            output[key] = output[key].sigmoid()
+        return output
+
+
+class Ensembler(nn.Module):
+    """
+    Computes the sum of outputs from several models, with optional arithmetic averaging.
+    """
+
+    def __init__(self, models: List[nn.Module], average=True, outputs=None):
+        """
+
+        :param models: List of models to ensemble
+        :param average: If True, the summed outputs are averaged across models
+        :param outputs: Names of model outputs to average and return from Ensembler.
+            If None, all outputs from the first model will be used.
+        """
+        super().__init__()
+        self.outputs = outputs
+        self.models = nn.ModuleList(models)
+        self.average = average
+
+    def forward(self, x):  # skipcq: PYL-W0221
+        output_0 = self.models[0](x)
+        num_models = len(self.models)
+
+        if self.outputs:
+            keys = self.outputs
+        else:
+            keys = output_0.keys()
+
+        for index in range(1, num_models):
+            output_i = self.models[index](x)
+
+            # Sum outputs
+            for key in keys:
+                output_0[key] += output_i[key]
+
+        if self.average:
+            for key in keys:
+                output_0[key] /= num_models
+
+        return output_0
+
+
+class PickModelOutput(nn.Module):
+    """
+    Assuming you have a model that outputs a dictionary, this module returns only a given element by its key
+    """
+
+    def __init__(self, model: nn.Module, key: str):
+        super().__init__()
+        self.model = model
+        self.target_key = key
+
+    def forward(self, input) -> Tensor:
+        output = self.model(input)
+        return output[self.target_key]
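
A minimal usage sketch for the classes above (the dummy model is hypothetical, for illustration only):

```python
import torch
from torch import nn


class DummySegModel(nn.Module):
    """A toy model returning a dictionary output, as Ensembler expects."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        return {"logits": self.conv(x)}


# Average the "logits" output over two models, then apply sigmoid to the averaged tensor
ensemble = ApplySigmoidTo(Ensembler([DummySegModel(), DummySegModel()]), output_key="logits")
probs = ensemble(torch.rand(2, 3, 64, 64))["logits"]

# Or unwrap the dictionary output and return a plain tensor
logits_only = PickModelOutput(Ensembler([DummySegModel(), DummySegModel()]), key="logits")
```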
