@@ -188,22 +188,33 @@ def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
188
188
189
189
190
190
class TimmB4Encoder (EncoderModule ):
191
- def __init__ (self , pretrained = True , layers = [1 , 2 , 3 , 4 ], activation : str = ACT_SILU , no_stride = False ):
191
+ def __init__ (
192
+ self ,
193
+ pretrained = True ,
194
+ layers = [1 , 2 , 3 , 4 ],
195
+ activation : str = ACT_SILU ,
196
+ no_stride_s32 = False ,
197
+ no_stride_s16 = False ,
198
+ ):
192
199
from timm .models .efficientnet import tf_efficientnet_b4_ns
193
200
194
201
act_layer = get_activation_block (activation )
195
202
encoder = tf_efficientnet_b4_ns (
196
203
pretrained = pretrained , features_only = True , act_layer = act_layer , drop_path_rate = 0.2
197
204
)
198
205
strides = [2 , 4 , 8 , 16 , 32 ]
199
- if no_stride :
200
- encoder .blocks [5 ][0 ].conv_dw .stride = (1 , 1 )
201
- encoder .blocks [5 ][0 ].conv_dw .dilation = (2 , 2 )
202
206
207
+ if no_stride_s16 :
203
208
encoder .blocks [3 ][0 ].conv_dw .stride = (1 , 1 )
204
209
encoder .blocks [3 ][0 ].conv_dw .dilation = (2 , 2 )
205
210
strides [3 ] = 8
206
- strides [4 ] = 8
211
+ strides [4 ] = 16
212
+
213
+ if no_stride_s32 :
214
+ encoder .blocks [5 ][0 ].conv_dw .stride = (1 , 1 )
215
+ encoder .blocks [5 ][0 ].conv_dw .dilation = (2 , 2 )
216
+ strides [4 ] = strides [3 ]
217
+
207
218
super ().__init__ ([24 , 32 , 56 , 160 , 448 ], strides , layers )
208
219
self .encoder = encoder
209
220
0 commit comments