Skip to content

Commit 71a8c2e

Browse files
committed
Fix output stride in layer1
1 parent 052ed6b commit 71a8c2e

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

pytorch_toolbelt/modules/encoders.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
efficient_net_b5,
3333
efficient_net_b7,
3434
)
35-
from pytorch_toolbelt.modules.backbone.inceptionv4 import InceptionV4
35+
from pytorch_toolbelt.modules.backbone.inceptionv4 import InceptionV4, \
36+
inceptionv4
3637
from pytorch_toolbelt.modules.backbone.mobilenetv3 import MobileNetV3
3738
from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, \
3839
WiderResNetA2
@@ -866,19 +867,20 @@ def __init__(self, layers=None, **kwargs):
866867

867868

868869
class InceptionV4Encoder(EncoderModule):
869-
def __init__(self, inceptionv4: InceptionV4, layers=None, **kwargs):
870-
channels = [64, 384, 384, 1024, 1536]
871-
strides = [2, 4, 8, 16, 32]
870+
def __init__(self, pretrained=True, layers=None, **kwargs):
871+
backbone = inceptionv4(pretrained="imagenet" if pretrained else None)
872+
channels = [64, 192, 384, 1024, 1536]
873+
strides = [2, 4, 8, 16, 32] # Note output strides are approximate
872874
if layers is None:
873875
layers = [1, 2, 3, 4]
874-
features = inceptionv4.features
876+
features = backbone.features
875877
super().__init__(channels, strides, layers)
876878

877-
self.layer0 = nn.Sequential(features[0:0 + 3])
878-
self.layer1 = nn.Sequential(features[3:3 + 3])
879-
self.layer2 = nn.Sequential(features[6:6 + 4])
880-
self.layer3 = nn.Sequential(features[10:10 + 8])
881-
self.layer4 = nn.Sequential(features[18:18 + 4])
879+
self.layer0 = features[0:3]
880+
self.layer1 = features[3:5]
881+
self.layer2 = features[5:10]
882+
self.layer3 = features[10:18]
883+
self.layer4 = features[18:22]
882884

883885
self._output_strides = _take(strides, layers)
884886
self._output_filters = _take(channels, layers)

0 commit comments

Comments
 (0)