5
5
__all__ = ["HRNetW18Encoder" , "HRNetW32Encoder" , "HRNetW48Encoder" , "TimmHRNetW18SmallV2Encoder" ]
6
6
7
7
8
- class HRNetW18Encoder (GenericTimmEncoder ):
9
- def __init__ (
10
- self , pretrained = True , use_incre_features : bool = True , layers = None , first_conv_stride_one : bool = False
11
- ):
12
- from timm .models import hrnet
13
-
14
- encoder = hrnet .hrnet_w18 (
15
- pretrained = pretrained ,
16
- feature_location = "incre" if use_incre_features else "" ,
17
- features_only = True ,
18
- out_indices = (0 , 1 , 2 , 3 , 4 ),
19
- )
8
+ class HRNetTimmEncoder (GenericTimmEncoder ):
9
+ def __init__ (self , encoder , first_conv_stride_one , layers ):
20
10
if first_conv_stride_one :
21
11
encoder .conv1 .stride = (1 , 1 )
22
12
23
13
super ().__init__ (encoder , layers )
24
-
25
14
if first_conv_stride_one :
26
- self ._output_strides = [ s // 2 for s in self ._output_strides ]
15
+ self ._output_strides = ( s // 2 for s in self ._output_strides )
27
16
28
17
def forward (self , x ):
29
18
y = self .encoder .forward (x )
@@ -34,29 +23,40 @@ def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
34
23
return self
35
24
36
25
37
- class HRNetW32Encoder (GenericTimmEncoder ):
38
- def __init__ (self , pretrained = True , use_incre_features : bool = True , layers = None ):
26
+ class HRNetW18Encoder (HRNetTimmEncoder ):
27
+ def __init__ (
28
+ self , pretrained = True , use_incre_features : bool = True , layers = None , first_conv_stride_one : bool = False
29
+ ):
39
30
from timm .models import hrnet
40
31
41
- encoder = hrnet .hrnet_w32 (
32
+ encoder = hrnet .hrnet_w18 (
42
33
pretrained = pretrained ,
43
34
feature_location = "incre" if use_incre_features else "" ,
44
35
features_only = True ,
45
36
out_indices = (0 , 1 , 2 , 3 , 4 ),
46
37
)
47
- super ().__init__ (encoder , layers )
38
+ super ().__init__ (encoder , first_conv_stride_one = first_conv_stride_one , layers = layers )
48
39
49
- def forward (self , x ):
50
- y = self .encoder .forward (x )
51
- return _take (y , self ._layers )
52
40
53
- def change_input_channels (self , input_channels : int , mode = "auto" , ** kwargs ):
54
- self .encoder .conv1 = make_n_channel_input (self .encoder .conv1 , input_channels , mode , ** kwargs )
55
- return self
41
+ class HRNetW32Encoder (HRNetTimmEncoder ):
42
+ def __init__ (
43
+ elf , pretrained = True , use_incre_features : bool = True , layers = None , first_conv_stride_one : bool = False
44
+ ):
45
+ from timm .models import hrnet
56
46
47
+ encoder = hrnet .hrnet_w32 (
48
+ pretrained = pretrained ,
49
+ feature_location = "incre" if use_incre_features else "" ,
50
+ features_only = True ,
51
+ out_indices = (0 , 1 , 2 , 3 , 4 ),
52
+ )
53
+ super ().__init__ (encoder , first_conv_stride_one = first_conv_stride_one , layers = layers )
57
54
58
- class HRNetW48Encoder (GenericTimmEncoder ):
59
- def __init__ (self , pretrained = True , use_incre_features : bool = True , layers = None ):
55
+
56
+ class HRNetW48Encoder (HRNetTimmEncoder ):
57
+ def __init__ (
58
+ elf , pretrained = True , use_incre_features : bool = True , layers = None , first_conv_stride_one : bool = False
59
+ ):
60
60
from timm .models import hrnet
61
61
62
62
encoder = hrnet .hrnet_w48 (
@@ -65,19 +65,13 @@ def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None
65
65
features_only = True ,
66
66
out_indices = (0 , 1 , 2 , 3 , 4 ),
67
67
)
68
- super ().__init__ (encoder , layers )
69
-
70
- def forward (self , x ):
71
- y = self .encoder .forward (x )
72
- return _take (y , self ._layers )
73
-
74
- def change_input_channels (self , input_channels : int , mode = "auto" , ** kwargs ):
75
- self .encoder .conv1 = make_n_channel_input (self .encoder .conv1 , input_channels , mode , ** kwargs )
76
- return self
68
+ super ().__init__ (encoder , first_conv_stride_one = first_conv_stride_one , layers = layers )
77
69
78
70
79
- class TimmHRNetW18SmallV2Encoder (GenericTimmEncoder ):
80
- def __init__ (self , pretrained = True , use_incre_features : bool = True , layers = None , activation = ACT_RELU ):
71
+ class TimmHRNetW18SmallV2Encoder (HRNetTimmEncoder ):
72
+ def __init__ (
73
+ self , elf , pretrained = True , use_incre_features : bool = True , layers = None , first_conv_stride_one : bool = False
74
+ ):
81
75
from timm .models import hrnet
82
76
83
77
encoder = hrnet .hrnet_w18_small_v2 (
@@ -86,8 +80,4 @@ def __init__(self, pretrained=True, use_incre_features: bool = True, layers=None
86
80
features_only = True ,
87
81
out_indices = (0 , 1 , 2 , 3 , 4 ),
88
82
)
89
- super ().__init__ (encoder , layers )
90
-
91
- def change_input_channels (self , input_channels : int , mode = "auto" , ** kwargs ):
92
- self .encoder .conv1 = make_n_channel_input (self .encoder .conv1 , input_channels , mode , ** kwargs )
93
- return self
83
+ super ().__init__ (encoder , first_conv_stride_one = first_conv_stride_one , layers = layers )
0 commit comments