Skip to content

Commit 4be0954

Browse files
committed
Add options to disable stride on B4
1 parent e41b4fe commit 4be0954

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

pytorch_toolbelt/modules/encoders/timm/efficient_net.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,22 +188,33 @@ def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
188188

189189

190190
class TimmB4Encoder(EncoderModule):
191-
def __init__(self, pretrained=True, layers=[1, 2, 3, 4], activation: str = ACT_SILU, no_stride=False):
191+
def __init__(
192+
self,
193+
pretrained=True,
194+
layers=[1, 2, 3, 4],
195+
activation: str = ACT_SILU,
196+
no_stride_s32=False,
197+
no_stride_s16=False,
198+
):
192199
from timm.models.efficientnet import tf_efficientnet_b4_ns
193200

194201
act_layer = get_activation_block(activation)
195202
encoder = tf_efficientnet_b4_ns(
196203
pretrained=pretrained, features_only=True, act_layer=act_layer, drop_path_rate=0.2
197204
)
198205
strides = [2, 4, 8, 16, 32]
199-
if no_stride:
200-
encoder.blocks[5][0].conv_dw.stride = (1, 1)
201-
encoder.blocks[5][0].conv_dw.dilation = (2, 2)
202206

207+
if no_stride_s16:
203208
encoder.blocks[3][0].conv_dw.stride = (1, 1)
204209
encoder.blocks[3][0].conv_dw.dilation = (2, 2)
205210
strides[3] = 8
206-
strides[4] = 8
211+
strides[4] = 16
212+
213+
if no_stride_s32:
214+
encoder.blocks[5][0].conv_dw.stride = (1, 1)
215+
encoder.blocks[5][0].conv_dw.dilation = (2, 2)
216+
strides[4] = strides[3]
217+
207218
super().__init__([24, 32, 56, 160, 448], strides, layers)
208219
self.encoder = encoder
209220

0 commit comments

Comments
 (0)