     efficient_net_b5,
     efficient_net_b7,
 )
-from pytorch_toolbelt.modules.backbone.inceptionv4 import InceptionV4
+from pytorch_toolbelt.modules.backbone.inceptionv4 import InceptionV4, \
+    inceptionv4
 from pytorch_toolbelt.modules.backbone.mobilenetv3 import MobileNetV3
 from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, \
     WiderResNetA2
@@ -866,19 +867,20 @@ def __init__(self, layers=None, **kwargs):
 
 
 class InceptionV4Encoder(EncoderModule):
-    def __init__(self, inceptionv4: InceptionV4, layers=None, **kwargs):
-        channels = [64, 384, 384, 1024, 1536]
-        strides = [2, 4, 8, 16, 32]
+    def __init__(self, pretrained=True, layers=None, **kwargs):
+        backbone = inceptionv4(pretrained="imagenet" if pretrained else None)
+        channels = [64, 192, 384, 1024, 1536]
+        strides = [2, 4, 8, 16, 32]  # Note output strides are approximate
         if layers is None:
             layers = [1, 2, 3, 4]
-        features = inceptionv4.features
+        features = backbone.features
         super().__init__(channels, strides, layers)
 
-        self.layer0 = nn.Sequential(features[0:0 + 3])
-        self.layer1 = nn.Sequential(features[3:3 + 3])
-        self.layer2 = nn.Sequential(features[6:6 + 4])
-        self.layer3 = nn.Sequential(features[10:10 + 8])
-        self.layer4 = nn.Sequential(features[18:18 + 4])
+        self.layer0 = features[0:3]
+        self.layer1 = features[3:5]
+        self.layer2 = features[5:10]
+        self.layer3 = features[10:18]
+        self.layer4 = features[18:22]
 
         self._output_strides = _take(strides, layers)
         self._output_filters = _take(channels, layers)
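
For context, a minimal usage sketch of the encoder after this change. The import path and the assumption that EncoderModule returns one feature map per requested layer are not part of this diff and may differ in the actual library:

import torch
from pytorch_toolbelt.modules.encoders import InceptionV4Encoder  # assumed import path

# The encoder now builds its own InceptionV4 backbone internally;
# pretrained=False skips downloading the "imagenet" weights.
encoder = InceptionV4Encoder(pretrained=False, layers=[1, 2, 3, 4])

x = torch.randn(1, 3, 299, 299)  # InceptionV4 was originally trained on 299x299 crops
feature_maps = encoder(x)

for fm in feature_maps:
    print(fm.shape)  # one tensor per requested layer, channel counts as in `channels`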