Skip to content

Commit 45fafad

Browse files
committed
Update HRNet to add first_conv_stride_one parameter in all variants
1 parent 2ec2d52 commit 45fafad

File tree

1 file changed

+32
-42
lines changed
  • pytorch_toolbelt/modules/encoders/timm

1 file changed

+32
-42
lines changed

pytorch_toolbelt/modules/encoders/timm/hrnet.py

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,14 @@
55
__all__ = ["HRNetW18Encoder", "HRNetW32Encoder", "HRNetW48Encoder", "TimmHRNetW18SmallV2Encoder"]
66

77

8-
class HRNetW18Encoder(GenericTimmEncoder):
9-
def __init__(
10-
self, pretrained=True, use_incre_features: bool = True, layers=None, first_conv_stride_one: bool = False
11-
):
12-
from timm.models import hrnet
13-
14-
encoder = hrnet.hrnet_w18(
15-
pretrained=pretrained,
16-
feature_location="incre" if use_incre_features else "",
17-
features_only=True,
18-
out_indices=(0, 1, 2, 3, 4),
19-
)
8+
class HRNetTimmEncoder(GenericTimmEncoder):
9+
def __init__(self, encoder, first_conv_stride_one, layers):
2010
if first_conv_stride_one:
2111
encoder.conv1.stride = (1, 1)
2212

2313
super().__init__(encoder, layers)
24-
2514
if first_conv_stride_one:
26-
self._output_strides = [s // 2 for s in self._output_strides]
15+
self._output_strides = (s // 2 for s in self._output_strides)
2716

2817
def forward(self, x):
2918
y = self.encoder.forward(x)
@@ -34,29 +23,40 @@ def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
3423
return self
3524

3625

37-
class HRNetW32Encoder(GenericTimmEncoder):
38-
def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None):
26+
class HRNetW18Encoder(HRNetTimmEncoder):
27+
def __init__(
28+
self, pretrained=True, use_incre_features: bool = True, layers=None, first_conv_stride_one: bool = False
29+
):
3930
from timm.models import hrnet
4031

41-
encoder = hrnet.hrnet_w32(
32+
encoder = hrnet.hrnet_w18(
4233
pretrained=pretrained,
4334
feature_location="incre" if use_incre_features else "",
4435
features_only=True,
4536
out_indices=(0, 1, 2, 3, 4),
4637
)
47-
super().__init__(encoder, layers)
38+
super().__init__(encoder, first_conv_stride_one=first_conv_stride_one, layers=layers)
4839

49-
def forward(self, x):
50-
y = self.encoder.forward(x)
51-
return _take(y, self._layers)
5240

53-
def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
54-
self.encoder.conv1 = make_n_channel_input(self.encoder.conv1, input_channels, mode, **kwargs)
55-
return self
41+
class HRNetW32Encoder(HRNetTimmEncoder):
42+
def __init__(
43+
elf, pretrained=True, use_incre_features: bool = True, layers=None, first_conv_stride_one: bool = False
44+
):
45+
from timm.models import hrnet
5646

47+
encoder = hrnet.hrnet_w32(
48+
pretrained=pretrained,
49+
feature_location="incre" if use_incre_features else "",
50+
features_only=True,
51+
out_indices=(0, 1, 2, 3, 4),
52+
)
53+
super().__init__(encoder, first_conv_stride_one=first_conv_stride_one, layers=layers)
5754

58-
class HRNetW48Encoder(GenericTimmEncoder):
59-
def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None):
55+
56+
class HRNetW48Encoder(HRNetTimmEncoder):
57+
def __init__(
58+
elf, pretrained=True, use_incre_features: bool = True, layers=None, first_conv_stride_one: bool = False
59+
):
6060
from timm.models import hrnet
6161

6262
encoder = hrnet.hrnet_w48(
@@ -65,19 +65,13 @@ def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None
6565
features_only=True,
6666
out_indices=(0, 1, 2, 3, 4),
6767
)
68-
super().__init__(encoder, layers)
69-
70-
def forward(self, x):
71-
y = self.encoder.forward(x)
72-
return _take(y, self._layers)
73-
74-
def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
75-
self.encoder.conv1 = make_n_channel_input(self.encoder.conv1, input_channels, mode, **kwargs)
76-
return self
68+
super().__init__(encoder, first_conv_stride_one=first_conv_stride_one, layers=layers)
7769

7870

79-
class TimmHRNetW18SmallV2Encoder(GenericTimmEncoder):
80-
def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None, activation=ACT_RELU):
71+
class TimmHRNetW18SmallV2Encoder(HRNetTimmEncoder):
72+
def __init__(
73+
self, elf, pretrained=True, use_incre_features: bool = True, layers=None, first_conv_stride_one: bool = False
74+
):
8175
from timm.models import hrnet
8276

8377
encoder = hrnet.hrnet_w18_small_v2(
@@ -86,8 +80,4 @@ def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None
8680
features_only=True,
8781
out_indices=(0, 1, 2, 3, 4),
8882
)
89-
super().__init__(encoder, layers)
90-
91-
def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
92-
self.encoder.conv1 = make_n_channel_input(self.encoder.conv1, input_channels, mode, **kwargs)
93-
return self
83+
super().__init__(encoder, first_conv_stride_one=first_conv_stride_one, layers=layers)

0 commit comments

Comments
 (0)