Commit 3ef7f37

Merge pull request #12 from BloodAxe/develop

0.0.9

2 parents: 41ea208 + 3c179e6

11 files changed: +169, -84 lines

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 from __future__ import absolute_import
 
-__version__ = '0.0.8'
+__version__ = '0.0.9'
pytorch_toolbelt/modules/activations.py (new file)

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+from functools import partial
+
+from torch import nn
+from torch.nn import functional as F
+
+
+def swish(x):
+    return x * x.sigmoid()
+
+
+def hard_sigmoid(x, inplace=False):
+    return F.relu6(x + 3, inplace) / 6
+
+
+def hard_swish(x, inplace=False):
+    return x * hard_sigmoid(x, inplace)
+
+
+class HardSigmoid(nn.Module):
+    def __init__(self, inplace=False):
+        super(HardSigmoid, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return hard_sigmoid(x, inplace=self.inplace)
+
+
+class Swish(nn.Module):
+    def __init__(self, inplace=False):
+        super(Swish, self).__init__()
+
+    def forward(self, x):
+        return swish(x)
+
+
+class HardSwish(nn.Module):
+    def __init__(self, inplace=False):
+        super(HardSwish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return hard_swish(x, inplace=self.inplace)
+
+
+def get_activation_module(activation_name: str, **kwargs) -> nn.Module:
+    if activation_name.lower() == 'relu':
+        return partial(nn.ReLU, **kwargs)
+
+    if activation_name.lower() == 'relu6':
+        return partial(nn.ReLU6, **kwargs)
+
+    if activation_name.lower() == 'leaky_relu':
+        return partial(nn.LeakyReLU, **kwargs)
+
+    if activation_name.lower() == 'elu':
+        return partial(nn.ELU, **kwargs)
+
+    if activation_name.lower() == 'selu':
+        return partial(nn.SELU, **kwargs)
+
+    if activation_name.lower() == 'celu':
+        return partial(nn.CELU, **kwargs)
+
+    if activation_name.lower() == 'glu':
+        return partial(nn.GLU, **kwargs)
+
+    if activation_name.lower() == 'prelu':
+        return partial(nn.PReLU, **kwargs)
+
+    if activation_name.lower() == 'hard_sigmoid':
+        return partial(HardSigmoid, **kwargs)
+
+    if activation_name.lower() == 'hard_swish':
+        return partial(HardSwish, **kwargs)
+
+    raise ValueError(f'Activation \'{activation_name}\' is not supported')
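
For reference, a minimal usage sketch of the new activation factory (not part of the commit; it assumes the module is importable as pytorch_toolbelt.modules.activations, and note that the factory returns a functools.partial over the module class, which still has to be instantiated):

import torch
from pytorch_toolbelt.modules.activations import get_activation_module, hard_swish

act_cls = get_activation_module('hard_swish', inplace=False)  # partial(HardSwish, inplace=False)
act = act_cls()                                               # instantiate the nn.Module

x = torch.linspace(-4, 4, steps=9)
assert torch.allclose(act(x), hard_swish(x))                  # both compute x * relu6(x + 3) / 6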

pytorch_toolbelt/modules/backbone/mobilenet.py

Lines changed: 20 additions & 13 deletions

@@ -1,25 +1,29 @@
+from __future__ import absolute_import
+
 import torch.nn as nn
 import math
 
+from ..activations import get_activation_module
+
 
-def conv_bn(inp, oup, stride):
+def conv_bn(inp, oup, stride, activation: nn.Module):
     return nn.Sequential(
         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
         nn.BatchNorm2d(oup),
-        nn.ReLU6(inplace=True)
+        activation(inplace=True)
     )
 
 
-def conv_1x1_bn(inp, oup):
+def conv_1x1_bn(inp, oup, activation: nn.Module):
     return nn.Sequential(
         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
         nn.BatchNorm2d(oup),
-        nn.ReLU6(inplace=True)
+        activation(inplace=True)
     )
 
 
 class InvertedResidual(nn.Module):
-    def __init__(self, inp, oup, stride, expand_ratio):
+    def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module):
         super(InvertedResidual, self).__init__()
         self.stride = stride
         assert stride in [1, 2]

@@ -32,7 +36,7 @@ def __init__(self, inp, oup, stride, expand_ratio):
             # dw
             nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
             nn.BatchNorm2d(hidden_dim),
-            nn.ReLU6(inplace=True),
+            activation(inplace=True),
             # pw-linear
             nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
             nn.BatchNorm2d(oup),

@@ -42,11 +46,11 @@ def __init__(self, inp, oup, stride, expand_ratio):
             # pw
             nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
             nn.BatchNorm2d(hidden_dim),
-            nn.ReLU6(inplace=True),
+            activation(inplace=True),
             # dw
             nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
             nn.BatchNorm2d(hidden_dim),
-            nn.ReLU6(inplace=True),
+            activation(inplace=True),
             # pw-linear
             nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
             nn.BatchNorm2d(oup),

@@ -60,8 +64,11 @@ def forward(self, x):
 
 
 class MobileNetV2(nn.Module):
-    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
+    def __init__(self, n_class=1000, input_size=224, width_mult=1., activation='relu6'):
         super(MobileNetV2, self).__init__()
+
+        act = get_activation_module(activation)
+
         block = InvertedResidual
         input_channel = 32
         last_channel = 1280

@@ -80,7 +87,7 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
         assert input_size % 32 == 0
         input_channel = int(input_channel * width_mult)
         self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
-        self.layer0 = conv_bn(3, input_channel, 2)
+        self.layer0 = conv_bn(3, input_channel, 2, act)
 
         # building inverted residual blocks
         for layer_index, (t, c, n, s) in enumerate(interverted_residual_setting):

@@ -89,16 +96,16 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
             blocks = []
             for i in range(n):
                 if i == 0:
-                    blocks.append(block(input_channel, output_channel, s, expand_ratio=t))
+                    blocks.append(block(input_channel, output_channel, s, expand_ratio=t, activation=act))
                 else:
-                    blocks.append(block(input_channel, output_channel, 1, expand_ratio=t))
+                    blocks.append(block(input_channel, output_channel, 1, expand_ratio=t, activation=act))
 
                 input_channel = output_channel
 
             self.add_module(f'layer{layer_index + 1}', nn.Sequential(*blocks))
 
         # building last several layers
-        self.final_layer = conv_1x1_bn(input_channel, self.last_channel)
+        self.final_layer = conv_1x1_bn(input_channel, self.last_channel, activation=act)
 
         # building classifier
         self.classifier = nn.Sequential(
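A short usage sketch of the new constructor argument (an illustration, not part of the commit; the forward call and the 224x224 input are assumptions based on the standard MobileNetV2 interface):

import torch
from pytorch_toolbelt.modules.backbone.mobilenet import MobileNetV2

# 'relu6' keeps the original behaviour; any name understood by get_activation_module works here.
model = MobileNetV2(n_class=1000, activation='hard_swish')
logits = model(torch.randn(1, 3, 224, 224))  # assumed default 224x224 RGB input
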
pytorch_toolbelt/modules/backbone/mobilenetv3.py

Lines changed: 24 additions & 53 deletions

@@ -6,43 +6,15 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D
-from pytorch_toolbelt.modules import Identity
-
-
-def swish(x):
-    return x * x.sigmoid()
-
-
-def hard_sigmoid(x, inplace=False):
-    return F.relu6(x + 3, inplace) / 6
-
-
-def hard_swish(x, inplace=False):
-    return x * hard_sigmoid(x, inplace)
-
-
-class HardSigmoid(nn.Module):
-    def __init__(self, inplace=False):
-        super(HardSigmoid, self).__init__()
-        self.inplace = inplace
-
-    def forward(self, x):
-        return hard_sigmoid(x, inplace=self.inplace)
-
-
-class HardSwish(nn.Module):
-    def __init__(self, inplace=False):
-        super(HardSwish, self).__init__()
-        self.inplace = inplace
-
-    def forward(self, x):
-        return hard_swish(x, inplace=self.inplace)
+# from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D
+from pytorch_toolbelt.modules.activations import HardSwish, HardSigmoid
+from pytorch_toolbelt.modules.identity import Identity
 
 
 def _make_divisible(v, divisor, min_value=None):
     """
     Ensure that all layers have a channel number that is divisible by 8
+
     It can be seen here:
     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
     :param v:

@@ -59,9 +31,9 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
 
 
-# https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
 class SqEx(nn.Module):
-    """Squeeze-Excitation block, implemented in ONNX & CoreML friendly way
+    """Squeeze-Excitation block. Implemented in ONNX & CoreML friendly way.
+    Original implementation: https://github.com/jonnedtc/Squeeze-Excitation-PyTorch/blob/master/networks.py
     """
 
     def __init__(self, n_features, reduction=4):

@@ -89,24 +61,26 @@ def __init__(self, inplanes, outplanes, expplanes, k=3, stride=1, drop_prob=0, n
         super(LinearBottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, expplanes, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(expplanes)
-        self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
-                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
-        # TODO: first doesn't have act?
+        self.db1 = nn.Dropout2d(drop_prob)
+        # self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
+        #                               stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
+        self.act1 = activation(**act_params)  # first does have act according to MobileNetV2
 
         self.conv2 = nn.Conv2d(expplanes, expplanes, kernel_size=k, stride=stride, padding=k // 2, bias=False,
                                groups=expplanes)
         self.bn2 = nn.BatchNorm2d(expplanes)
-        self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
-                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
+        self.db2 = nn.Dropout2d(drop_prob)
+        # self.db2 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
+        #                               stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
         self.act2 = activation(**act_params)
 
         self.se = SqEx(expplanes) if SE else Identity()
 
         self.conv3 = nn.Conv2d(expplanes, outplanes, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(outplanes)
-        self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
-                                      stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
-        self.act3 = activation(**act_params)
+        self.db3 = nn.Dropout2d(drop_prob)
+        # self.db3 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0.,
+        #                               stop_value=drop_prob, nr_steps=num_steps, start_step=start_step)
 
         self.stride = stride
         self.expplanes = expplanes

@@ -119,6 +93,7 @@ def forward(self, x):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.db1(out)
+        out = self.act1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)

@@ -130,10 +105,9 @@
         out = self.conv3(out)
         out = self.bn3(out)
         out = self.db3(out)
-        out = self.act3(out)
 
         if self.stride == 1 and self.inplanes == self.outplanes:  # TODO: or add 1x1?
-            out = out + residual  # No inplace if there is in-place activation before
+            out += residual  # No inplace if there is in-place activation before
 
         return out
 

@@ -187,7 +161,6 @@ def __init__(self, inplanes, num_classes, expplanes1, expplanes2):
         self.avgpool = nn.AdaptiveAvgPool2d(1)
 
         self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(expplanes2)
         self.act2 = HardSwish(inplace=True)
 
         self.dropout = nn.Dropout(p=0.2, inplace=True)

@@ -207,7 +180,6 @@ def forward(self, x):
         out = self.avgpool(out)
 
         out = self.conv2(out)
-        out = self.bn2(out)
         out = self.act2(out)
 
         # flatten for input to fully-connected layer

@@ -246,16 +218,16 @@ def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num
             [80, 184, 80, 1, 3, drop_prob, False, HardSwish],  # -> 14x14
             [80, 480, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
             [112, 672, 112, 1, 3, drop_prob, True, HardSwish],  # -> 14x14
-            [112, 672, 160, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
-            [160, 672, 160, 2, 5, drop_prob, True, HardSwish],  # -> 7x7 #TODO
+            [112, 672, 160, 2, 5, drop_prob, True, HardSwish],  # -> 7x7
+            [160, 672, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
             [160, 960, 160, 1, 5, drop_prob, True, HardSwish],  # -> 7x7
         ]
         self.bottlenecks_setting_small = [
             # in, exp, out, s, k, dp, se, act
-            [16, 64, 24, 2, 3, 0, True, nn.ReLU],  # -> 56x56 #TODO
-            [24, 72, 24, 2, 3, 0, False, nn.ReLU],  # -> 28x28
-            [24, 88, 40, 1, 3, 0, False, nn.ReLU],  # -> 28x28
-            [40, 96, 40, 2, 5, 0, True, HardSwish],  # -> 14x14 #TODO
+            [16, 64, 16, 2, 3, 0, True, nn.ReLU],  # -> 56x56
+            [16, 72, 24, 2, 3, 0, False, nn.ReLU],  # -> 28x28
+            [24, 88, 24, 1, 3, 0, False, nn.ReLU],  # -> 28x28
+            [24, 96, 40, 2, 5, 0, True, HardSwish],  # -> 14x14
             [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
             [40, 240, 40, 1, 5, drop_prob, True, HardSwish],  # -> 14x14
             [40, 120, 48, 1, 5, drop_prob, True, HardSwish],  # -> 14x14

@@ -290,7 +262,6 @@ def __init__(self, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num
 
     def _make_bottlenecks(self):
         layers = []
-
         modules = OrderedDict()
         stage_name = "Bottleneck"
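The _make_divisible helper whose docstring is touched above rounds scaled channel counts to hardware-friendly multiples; a small illustration of the expected behaviour (an assumption: the function follows the usual TensorFlow-slim rounding rule it is credited to, i.e. round to the nearest multiple of the divisor without dropping more than 10% below the input):

from pytorch_toolbelt.modules.backbone.mobilenetv3 import _make_divisible

print(_make_divisible(30, 8))   # 32, nearest multiple of 8
print(_make_divisible(17, 8))   # 16
print(_make_divisible(7, 8))    # 8, clamped up to min_value=divisor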

pytorch_toolbelt/modules/dropblock.py

Lines changed: 5 additions & 5 deletions

@@ -41,14 +41,13 @@ def forward(self, x):
             mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).to(x)
 
             # compute block mask
-            block_mask = self._compute_block_mask(mask)
+            block_mask, keeped = self._compute_block_mask(mask)
 
             # apply block mask
             out = x * block_mask[:, None, :, :]
 
             # scale output
-            out = out * block_mask.numel() / block_mask.sum()
-
+            out = out * (block_mask.numel() / keeped).to(out)
             return out
 
     def _compute_block_mask(self, mask):

@@ -60,9 +59,10 @@ def _compute_block_mask(self, mask):
         if self.block_size % 2 == 0:
             block_mask = block_mask[:, :, :-1, :-1]
 
+        keeped = block_mask.numel() - block_mask.sum().to(torch.float32)  # prevent overflow in float16
         block_mask = 1 - block_mask.squeeze(1)
 
-        return block_mask
+        return block_mask, keeped
 
     def _compute_gamma(self, x):
         return self.drop_prob / (self.block_size ** 2)

@@ -146,7 +146,7 @@ def forward(self, x):
 
     def step(self):
        idx = self.i.item()
-        if idx > self.start_step and idx < self.start_step + self.nr_steps:
+        if self.start_step < idx < self.start_step + self.nr_steps:
            self.dropblock.drop_prob += self.step_size
 
        self.i += 1
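For intuition, the new scaling keeps the expected activation magnitude constant by dividing by the number of surviving positions rather than by block_mask.sum(); a self-contained sketch of that normalization on a hypothetical toy mask (not the library's class):

import torch

# Toy "keep" mask after block dropping: 1 = kept, 0 = dropped.
block_mask = torch.tensor([[1., 1., 0., 0.],
                           [1., 1., 0., 0.],
                           [1., 1., 1., 1.],
                           [1., 1., 1., 1.]])

x = torch.ones(1, 1, 4, 4)
keeped = block_mask.numel() - (1 - block_mask).sum()  # surviving activations (12 of 16)
out = x * block_mask * (block_mask.numel() / keeped)   # rescale the kept activations

print(out.sum())  # ~16: same total activation mass as the undropped input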

pytorch_toolbelt/modules/encoders.py

Lines changed: 2 additions & 2 deletions

@@ -280,9 +280,9 @@ def encoder_layers(self):
 
 
 class MobilenetV2Encoder(EncoderModule):
-    def __init__(self, layers=[2, 3, 5, 7]):
+    def __init__(self, layers=[2, 3, 5, 7], activation='relu6'):
         super().__init__([32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], layers)
-        encoder = MobileNetV2()
+        encoder = MobileNetV2(activation=activation)
 
         self.layer0 = encoder.layer0
         self.layer1 = encoder.layer1
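
A minimal sketch of how the pass-through activation argument is meant to be used (an illustration; the input size and the list-of-feature-maps return value are assumptions about EncoderModule, which is not shown in this diff):

import torch
from pytorch_toolbelt.modules.encoders import MobilenetV2Encoder

encoder = MobilenetV2Encoder(layers=[2, 3, 5, 7], activation='hard_swish')
features = encoder(torch.randn(1, 3, 224, 224))  # assumed: one feature map per selected layer
print([f.shape for f in features])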
