Skip to content

Commit 5bad076

Browse files
authored
Merge pull request #17 from BloodAxe/develop
Release 0.1.2
2 parents 281b143 + d086e02 commit 5bad076

File tree

4 files changed

+18
-47
lines changed

4 files changed

+18
-47
lines changed

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66

77
A `pytorch-toolbelt` is a Python library with a set of bells and whistles for PyTorch for fast R&D prototyping and Kaggle farming:
88

9+
## What's inside
10+
911
* Easy model building using flexible encoder-decoder architecture.
10-
* Modules: CoordConv, SCSE, Hypercolumn, Depthwise separable convolution and more
12+
* Modules: CoordConv, SCSE, Hypercolumn, Depthwise separable convolution and more.
1113
* GPU-friendly test-time augmentation TTA for segmentation and classification
1214
* GPU-friendly inference on huge (5000x5000) images
1315
* Every-day common routines (fix/restore random seed, filesystem utils, metrics)
14-
* Fancy losses: Focal, Lovasz, Jaccard and Dice losses, Wing Loss
16+
* Losses: BinaryFocalLoss, Focal, ReducedFocal, Lovasz, Jaccard and Dice losses, Wing Loss and more.
17+
* Extras for [Catalyst](https://github.com/catalyst-team/catalyst) library (Visualization of batch predictions, additional metrics)
18+
19+
Showcase: [Catalyst, Albumentations, Pytorch Toolbelt example: Semantic Segmentation @ CamVid](https://colab.research.google.com/drive/1OUPJYU7TzH5Vz1si6FBkooackuIlzaGr#scrollTo=GUWuiO5K3aUm)
1520

1621
# Why
1722

@@ -119,4 +124,5 @@ merged_mask = tiler.crop_to_orignal_size(merged_mask)
119124

120125
## Advanced examples
121126

122-
1. [Inria Sattelite Segmentation](https://github.com/BloodAxe/Catalyst-Inria-Segmentation-Example)
127+
1. [Inria Sattelite Segmentation](https://github.com/BloodAxe/Catalyst-Inria-Segmentation-Example)
128+
1. [CamVid Semantic Segmentation](https://github.com/BloodAxe/Catalyst-CamVid-Segmentation-Example)

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = '0.1.1'
3+
__version__ = '0.1.2'

pytorch_toolbelt/inference/tta.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Tuple, List
88

99
from torch import Tensor, nn
10+
from torch.nn.functional import interpolate
1011

1112
from . import functional as F
1213

@@ -236,11 +237,11 @@ def forward(self, input: Tensor) -> Tensor:
236237

237238
for scale in self.scale_levels:
238239
dst_size = int(h * scale), int(w * scale)
239-
input_scaled = F.interpolate(input, dst_size, mode='bilinear',
240-
align_corners=True)
240+
input_scaled = interpolate(input, dst_size, mode='bilinear',
241+
align_corners=True)
241242
output_scaled = self.model(input_scaled)
242-
output_scaled = F.interpolate(output_scaled, out_size,
243-
mode='bilinear', align_corners=True)
243+
output_scaled = interpolate(output_scaled, out_size,
244+
mode='bilinear', align_corners=True)
244245
output += output_scaled
245246

246247
return output / (1 + len(self.scale_levels))

pytorch_toolbelt/losses/focal.py

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from functools import partial
22

3-
import torch
4-
from torch import nn
3+
from torch.nn.modules.loss import _Loss
54

65
from .functional import sigmoid_focal_loss, reduced_focal_loss
7-
from torch.nn.modules.loss import _Loss
86

97
__all__ = ['BinaryFocalLoss', 'FocalLoss']
108

@@ -57,52 +55,18 @@ def forward(self, label_input, label_target):
5755
"""
5856
num_classes = label_input.size(1)
5957
loss = 0
60-
label_target = label_target.view(-1)
61-
label_input = label_input.view(-1, num_classes)
6258

6359
# Filter anchors with -1 label from loss computation
6460
if self.ignore is not None:
6561
not_ignored = label_target != self.ignore
6662

6763
for cls in range(num_classes):
68-
cls_label_target = (label_target == (cls + 0)).long()
69-
cls_label_input = label_input[..., cls]
64+
cls_label_target = (label_target == cls).long()
65+
cls_label_input = label_input[:, cls, ...]
7066

7167
if self.ignore is not None:
7268
cls_label_target = cls_label_target[not_ignored]
7369
cls_label_input = cls_label_input[not_ignored]
7470

7571
loss += sigmoid_focal_loss(cls_label_input, cls_label_target, gamma=self.gamma, alpha=self.alpha)
7672
return loss
77-
78-
# Needs testing
79-
# class SoftmaxFocalLoss(nn.Module):
80-
# def __init__(self, gamma=2, eps=1e-7):
81-
# super(SoftmaxFocalLoss, self).__init__()
82-
# self.gamma = gamma
83-
# self.eps = eps
84-
#
85-
# @staticmethod
86-
# def _one_hot(index, classes):
87-
# size = index.size() + (classes,)
88-
# view = index.size() + (1,)
89-
#
90-
# mask = torch.Tensor(*size).fill_(0)
91-
# index = index.view(*view)
92-
# ones = 1.
93-
#
94-
# if isinstance(index, Variable):
95-
# ones = Variable(torch.Tensor(index.size()).fill_(1))
96-
# mask = Variable(mask, volatile=index.volatile)
97-
#
98-
# return mask.scatter_(1, index, ones)
99-
#
100-
# def forward(self, input, target):
101-
# y = one_hot(target, input.size(-1))
102-
# logit = F.softmax(input, dim=-1)
103-
# logit = logit.clamp(self.eps, 1. - self.eps)
104-
#
105-
# loss = -1 * y * torch.log(logit) # cross entropy
106-
# loss = loss * (1 - logit) ** self.gamma # focal loss
107-
#
108-
# return loss.sum()

0 commit comments

Comments
 (0)