Skip to content

Commit 628ebda

Browse files
committed
support other value of ignore label
1 parent 297e7bf commit 628ebda

File tree

4 files changed

+17
-15
lines changed

4 files changed

+17
-15
lines changed

lib/base_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
2424
self.mode = mode
2525
self.trans_func = trans_func
2626

27+
self.lb_ignore = -100
2728
self.lb_map = None
2829

2930
with open(annpath, 'r') as fr:

lib/coco.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@
4848
class CocoStuff(BaseDataset):
4949

5050
def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
51-
super(CocoStuff, self).__init__(dataroot, annpath, trans_func, mode)
51+
super(CocoStuff, self).__init__(
52+
dataroot, annpath, trans_func, mode)
5253
self.n_cats = 171 # 91 stuff, 91 thing, 11 of thing have no annos
53-
self.lb_ignore = 255
5454

5555
## label mapping, remove non-existing labels
5656
missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]

lib/ohem_ce_loss.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,30 @@
1010
# import ohem_cpp
1111
# class OhemCELoss(nn.Module):
1212
#
13-
# def __init__(self, thresh, ignore_lb=255):
13+
# def __init__(self, thresh, lb_ignore=255):
1414
# super(OhemCELoss, self).__init__()
1515
# self.score_thresh = thresh
16-
# self.ignore_lb = ignore_lb
17-
# self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='mean')
16+
# self.lb_ignore = lb_ignore
17+
# self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='mean')
1818
#
1919
# def forward(self, logits, labels):
20-
# n_min = labels[labels != self.ignore_lb].numel() // 16
20+
# n_min = labels[labels != self.lb_ignore].numel() // 16
2121
# labels = ohem_cpp.score_ohem_label(
22-
# logits, labels, self.ignore_lb, self.score_thresh, n_min).detach()
22+
# logits, labels, self.lb_ignore, self.score_thresh, n_min).detach()
2323
# loss = self.criteria(logits, labels)
2424
# return loss
2525

2626

2727
class OhemCELoss(nn.Module):
2828

29-
def __init__(self, thresh, ignore_lb=255):
29+
def __init__(self, thresh, lb_ignore=255):
3030
super(OhemCELoss, self).__init__()
3131
self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda()
32-
self.ignore_lb = ignore_lb
33-
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
32+
self.lb_ignore = lb_ignore
33+
self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')
3434

3535
def forward(self, logits, labels):
36-
n_min = labels[labels != self.ignore_lb].numel() // 16
36+
n_min = labels[labels != self.lb_ignore].numel() // 16
3737
loss = self.criteria(logits, labels).view(-1)
3838
loss_hard = loss[loss > self.thresh]
3939
if loss_hard.numel() < n_min:

tools/train_amp.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def parse_args():
5252

5353

5454

55-
def set_model():
55+
def set_model(lb_ignore=255):
5656
logger = logging.getLogger()
5757
net = model_factory[cfg.model_type](cfg.n_cats)
5858
if not args.finetune_from is None:
@@ -61,8 +61,9 @@ def set_model():
6161
if cfg.use_sync_bn: net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
6262
net.cuda()
6363
net.train()
64-
criteria_pre = OhemCELoss(0.7)
65-
criteria_aux = [OhemCELoss(0.7) for _ in range(cfg.num_aux_heads)]
64+
criteria_pre = OhemCELoss(0.7, lb_ignore)
65+
criteria_aux = [OhemCELoss(0.7, lb_ignore)
66+
for _ in range(cfg.num_aux_heads)]
6667
return net, criteria_pre, criteria_aux
6768

6869

@@ -126,7 +127,7 @@ def train():
126127
dl = get_data_loader(cfg, mode='train', distributed=is_dist)
127128

128129
## model
129-
net, criteria_pre, criteria_aux = set_model()
130+
net, criteria_pre, criteria_aux = set_model(dl.dataset.lb_ignore)
130131

131132
## optimizer
132133
optim = set_optimizer(net)

0 commit comments

Comments
 (0)