20 changes: 17 additions & 3 deletions object_detection_fastai/models/RetinaNet.py
@@ -19,7 +19,7 @@ class RetinaNet(nn.Module):
     "Implements RetinaNet from https://arxiv.org/abs/1708.02002"

     def __init__(self, encoder: nn.Module, n_classes, final_bias:float=0., n_conv:float=4,
-                 chs=256, n_anchors=9, flatten=True, sizes=None):
+                 chs=256, n_anchors=9, flatten=True, sizes=None, n_upsample_layers=5):
         super().__init__()
         self.n_classes, self.flatten = n_classes, flatten
         imsize = (256, 256)
@@ -30,8 +30,12 @@ def __init__(self, encoder: nn.Module, n_classes, final_bias:float=0., n_conv:f
         self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)
         self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)
         self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))
-        self.merges = nn.ModuleList([LateralUpsampleMerge(chs, szs[1], hook)
-                                     for szs, hook in zip(sfs_szs[-2:-4:-1], hooks[-2:-4:-1])])
+
+        last_index = self._get_last_upsample_index(n_upsample_layers)
+        szs_last_idx = self._get_size_change_layers_indexes(sfs_szs)
+        self.merges = nn.ModuleList([LateralUpsampleMerge(chs, sfs_szs[idx][1], hooks[idx])
+                                     for idx in szs_last_idx[:last_index:-1]])
+
         self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])
         self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs, n_conv=n_conv)
         self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs, n_conv=n_conv)
@@ -67,7 +71,17 @@ def _conv2d_relu(self, ni:int, nf:int, ks:int=3, stride:int=1,
         layers = [conv2d(ni, nf, ks=ks, stride=stride, padding=padding, bias=bias), nn.ReLU()]
         if bn: layers.append(nn.BatchNorm2d(nf))
         return nn.Sequential(*layers)
+
+    def _get_size_change_layers_indexes(self, sfs_szs):
+        t = torch.tensor(sfs_szs)
+        return (t[1:, 2] != t[:-1, 2]).nonzero().view(-1).numpy()  # indexes of the layers where the feature-map size changes
+
+    def _get_last_upsample_index(self, n_upsample_layers):
+        n_upsample_layers -= 3  # subtract the upsample layers that are not built by the merges
+        n_upsample_layers *= -1  # negate so the slice counts backwards from the end
+        n_upsample_layers -= 1
+        return n_upsample_layers

     def forward(self, x):
         c5 = self.encoder(x)
         p_states = [self.c5top5(c5.clone()), self.c5top6(c5)]
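A minimal, self-contained sketch of what the two new helpers compute, useful for checking the slicing behaviour of the merge construction. It is not part of the PR; the sfs_szs shapes and the module-level function names below are made up purely for illustration.

import torch

def get_size_change_layers_indexes(sfs_szs):
    t = torch.tensor(sfs_szs)
    # indexes of the layers whose spatial size (dim 2) differs from the previous layer
    return (t[1:, 2] != t[:-1, 2]).nonzero().view(-1).numpy()

def get_last_upsample_index(n_upsample_layers):
    n_upsample_layers -= 3   # upsample levels built outside the merge path
    n_upsample_layers *= -1  # negate so the slice counts backwards from the end
    n_upsample_layers -= 1
    return n_upsample_layers

# hypothetical (batch, channels, height, width) activation shapes from the encoder hooks
sfs_szs = [(1, 64, 128, 128), (1, 64, 64, 64), (1, 128, 32, 32),
           (1, 256, 16, 16), (1, 512, 8, 8)]

szs_last_idx = get_size_change_layers_indexes(sfs_szs)     # array([0, 1, 2, 3])
last_index = get_last_upsample_index(n_upsample_layers=5)  # -3
print(szs_last_idx[:last_index:-1])                        # [3 2]: the two deepest size-change layers, deepest first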