
Commit 7320185

Merge pull request #35 from BloodAxe/develop
Release 0.3.0
2 parents ec2bfbd + f7b83ef commit 7320185

Note: some diff content is hidden; large commits do not display every changed file below.

86 files changed: +4,780 additions, -2,342 deletions

.appveyor.yml

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@ cache:
 environment:
 
   matrix:
-    - PYTHON: 'C:\Python27-x64'
-    - PYTHON: 'C:\Python35-x64'
     - PYTHON: 'C:\Python36-x64'
     - PYTHON: 'C:\Python37-x64'

CREDITS.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+This file contains links to repositories, source code of which may be partially used in this repository. Mind giving them kudos on GitHub!
+
+1. https://github.com/Cadene/pretrained-models.pytorch
+1. https://blog.ceshine.net/post/pytorch-memory-swish/
+1. https://github.com/digantamisra98/Mish
+1. https://github.com/mapillary/inplace_abn
+1. https://github.com/PkuRainBow/OCNet.pytorch

black.toml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+# Example configuration for Black.
+
+# NOTE: you have to use single-quoted strings in TOML for regular expressions.
+# It's the equivalent of r-strings in Python. Multiline strings are treated as
+# verbose regular expressions by Black. Use [ ] to denote a significant space
+# character.
+
+[tool.black]
+line-length = 119
+target-version = ['py36', 'py37', 'py38']
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.eggs
+  | \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''

demo/demo_losses.py

Lines changed: 1 addition & 5 deletions
@@ -16,18 +16,14 @@ def main():
         # "dice_log": L.BinaryDiceLogLoss(),
         # "sdice": L.BinarySymmetricDiceLoss(),
         # "sdice_log": L.BinarySymmetricDiceLoss(log_loss=True),
-
         "bce+lovasz": L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss()),
         # "lovasz": L.BinaryLovaszLoss(),
         # "bce+jaccard": L.JointLoss(BCEWithLogitsLoss(),
         #                            L.BinaryJaccardLoss(), 1, 0.5),
-
         # "bce+log_jaccard": L.JointLoss(BCEWithLogitsLoss(),
         #                                L.BinaryJaccardLogLoss(), 1, 0.5),
-
         # "bce+log_dice": L.JointLoss(BCEWithLogitsLoss(),
         #                             L.BinaryDiceLogLoss(), 1, 0.5)
-
         # "reduced_focal": L.BinaryFocalLoss(reduced=True)
     }

@@ -55,5 +51,5 @@ def main():
     f.show()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
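The only entry the demo keeps enabled is the BCE + Lovasz combination. A minimal sketch of building and evaluating that criterion outside the demo, assuming BinaryLovaszLoss accepts raw logits and same-shaped binary targets; the tensor shapes and the 1.0/0.5 weights are illustrative, mirroring the commented-out entries above:

```python
import torch
from torch.nn import BCEWithLogitsLoss
import pytorch_toolbelt.losses as L

# Weighted sum of two criteria; the positional weights follow the pattern of the
# commented-out entries above (1.0 and 0.5 are illustrative, not recommended values).
criterion = L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss(), 1.0, 0.5)

logits = torch.randn(4, 1, 64, 64, requires_grad=True)  # raw model outputs
targets = (torch.rand(4, 1, 64, 64) > 0.5).float()      # binary ground truth

loss = criterion(logits, targets)
loss.backward()
print(loss.item())
```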

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import
 
-__version__ = "0.2.1"
+__version__ = "0.3.0"

pytorch_toolbelt/inference/functional.py

Lines changed: 2 additions & 8 deletions
@@ -62,11 +62,7 @@ def pad_image_tensor(image_tensor: Tensor, pad_size: int = 32):
     :return: Tuple of output tensor and pad params. Second argument can be used to reverse pad operation of model output
     """
     rows, cols = image_tensor.size(2), image_tensor.size(3)
-    if (
-        isinstance(pad_size, Sized)
-        and isinstance(pad_size, Iterable)
-        and len(pad_size) == 2
-    ):
+    if isinstance(pad_size, Sized) and isinstance(pad_size, Iterable) and len(pad_size) == 2:
         pad_height, pad_width = [int(val) for val in pad_size]
     elif isinstance(pad_size, int):
         pad_height = pad_width = pad_size

@@ -109,9 +105,7 @@ def unpad_image_tensor(image_tensor, pad):
 
 def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1):
     pad_left, pad_right, pad_top, pad_btm = pad
-    pad = torch.tensor(
-        [pad_left, pad_top, pad_left, pad_top], dtype=bboxes_tensor.dtype
-    ).to(bboxes_tensor.device)
+    pad = torch.tensor([pad_left, pad_top, pad_left, pad_top], dtype=bboxes_tensor.dtype).to(bboxes_tensor.device)
 
     if dim == -1:
         dim = len(bboxes_tensor.size()) - 1
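A minimal round-trip sketch of the two helpers touched above, based on the signatures visible in the hunk headers; the input size and the exact behavior of the returned pad params are assumptions taken from the docstring rather than verified against the full file:

```python
import torch
from pytorch_toolbelt.inference import functional as Fi

x = torch.rand(1, 3, 500, 700)                     # H and W are not multiples of 32
padded, pad = Fi.pad_image_tensor(x, pad_size=32)  # pad params let us undo the padding later
print(padded.shape)                                # should be padded up to multiples of 32, e.g. (1, 3, 512, 704)

restored = Fi.unpad_image_tensor(padded, pad)
assert restored.shape == x.shape
```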

pytorch_toolbelt/inference/tiles.py

Lines changed: 24 additions & 56 deletions
@@ -1,11 +1,11 @@
 """Implementation of tile-based inference allowing to predict huge images that does not fit into GPU memory entirely
 in a sliding-window fashion and merging prediction mask back to full-resolution.
 """
+import math
 from typing import List
 
-import numpy as np
 import cv2
-import math
+import numpy as np
 import torch
 
 

@@ -28,14 +28,18 @@ def compute_pyramid_patch_weight_loss(width, height) -> np.ndarray:
     Dc = np.zeros((width, height))
     De = np.zeros((width, height))
 
-    for i in range(width):
-        for j in range(height):
-            Dc[i, j] = np.sqrt(np.square(i - xc + 0.5) + np.square(j - yc + 0.5))
-            De_l = np.sqrt(np.square(i - xl + 0.5) + np.square(j - j + 0.5))
-            De_r = np.sqrt(np.square(i - xr + 0.5) + np.square(j - j + 0.5))
-            De_b = np.sqrt(np.square(i - i + 0.5) + np.square(j - yb + 0.5))
-            De_t = np.sqrt(np.square(i - i + 0.5) + np.square(j - yt + 0.5))
-            De[i, j] = np.min([De_l, De_r, De_b, De_t])
+    Dcx = np.square(np.arange(width) - xc + 0.5)
+    Dcy = np.square(np.arange(height) - yc + 0.5)
+    Dc = np.sqrt(Dcx[np.newaxis].transpose() + Dcy)
+
+    De_l = np.square(np.arange(width) - xl + 0.5) + np.square(0.5)
+    De_r = np.square(np.arange(width) - xr + 0.5) + np.square(0.5)
+    De_b = np.square(0.5) + np.square(np.arange(height) - yb + 0.5)
+    De_t = np.square(0.5) + np.square(np.arange(height) - yt + 0.5)
+
+    De_x = np.sqrt(np.minimum(De_l, De_r))
+    De_y = np.sqrt(np.minimum(De_b, De_t))
+    De = np.minimum(De_x[np.newaxis].transpose(), De_y)
 
     alpha = (width * height) / np.sum(np.divide(De, np.add(Dc, De)))
     W = alpha * np.divide(De, np.add(Dc, De))

@@ -47,9 +51,7 @@ class ImageSlicer:
     Helper class to slice image into tiles and merge them back
     """
 
-    def __init__(
-        self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"
-    ):
+    def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"):
         """
 
         :param image_shape: Shape of the source image (H, W)

@@ -75,21 +77,14 @@ def __init__(
 
         weights = {"mean": self._mean, "pyramid": self._pyramid}
 
-        self.weight = (
-            weight
-            if isinstance(weight, np.ndarray)
-            else weights[weight](self.tile_size)
-        )
+        self.weight = weight if isinstance(weight, np.ndarray) else weights[weight](self.tile_size)
 
         if self.tile_step[0] < 1 or self.tile_step[0] > self.tile_size[0]:
             raise ValueError()
         if self.tile_step[1] < 1 or self.tile_step[1] > self.tile_size[1]:
             raise ValueError()
 
-        overlap = [
-            self.tile_size[0] - self.tile_step[0],
-            self.tile_size[1] - self.tile_step[1],
-        ]
+        overlap = [self.tile_size[0] - self.tile_step[0], self.tile_size[1] - self.tile_step[1]]
 
         self.margin_left = 0
         self.margin_right = 0

@@ -111,14 +106,10 @@ def __init__(
                 self.margin_bottom = extra_h - self.margin_top
 
         else:
-            if (self.image_width - overlap[1] + 2 * image_margin) % self.tile_step[
-                1
-            ] != 0:
+            if (self.image_width - overlap[1] + 2 * image_margin) % self.tile_step[1] != 0:
                 raise ValueError()
 
-            if (self.image_height - overlap[0] + 2 * image_margin) % self.tile_step[
-                0
-            ] != 0:
+            if (self.image_height - overlap[0] + 2 * image_margin) % self.tile_step[0] != 0:
                 raise ValueError()
 
             self.margin_left = image_margin

@@ -130,32 +121,13 @@ def __init__(
         bbox_crops = []
 
         for y in range(
-            0,
-            self.image_height
-            + self.margin_top
-            + self.margin_bottom
-            - self.tile_size[0]
-            + 1,
-            self.tile_step[0],
+            0, self.image_height + self.margin_top + self.margin_bottom - self.tile_size[0] + 1, self.tile_step[0]
         ):
             for x in range(
-                0,
-                self.image_width
-                + self.margin_left
-                + self.margin_right
-                - self.tile_size[1]
-                + 1,
-                self.tile_step[1],
+                0, self.image_width + self.margin_left + self.margin_right - self.tile_size[1] + 1, self.tile_step[1]
             ):
                 crops.append((x, y, self.tile_size[1], self.tile_size[0]))
-                bbox_crops.append(
-                    (
-                        x - self.margin_left,
-                        y - self.margin_top,
-                        self.tile_size[1],
-                        self.tile_size[0],
-                    )
-                )
+                bbox_crops.append((x - self.margin_left, y - self.margin_top, self.tile_size[1], self.tile_size[0]))
 
         self.crops = np.array(crops)
         self.bbox_crops = np.array(bbox_crops)

@@ -189,9 +161,7 @@ def split(self, image, border_type=cv2.BORDER_CONSTANT, value=0):
 
         return tiles
 
-    def cut_patch(
-        self, image: np.ndarray, slice_index, border_type=cv2.BORDER_CONSTANT, value=0
-    ):
+    def cut_patch(self, image: np.ndarray, slice_index, border_type=cv2.BORDER_CONSTANT, value=0):
         assert image.shape[0] == self.image_height
         assert image.shape[1] == self.image_width
 

@@ -298,9 +268,7 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords):
         :param crop_coords: Corresponding tile crops w.r.t to original image
         """
        if len(batch) != len(crop_coords):
-            raise ValueError(
-                "Number of images in batch does not correspond to number of coordinates"
-            )
+            raise ValueError("Number of images in batch does not correspond to number of coordinates")
 
        for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
            self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight
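The second hunk replaces the per-pixel double loop in compute_pyramid_patch_weight_loss with broadcast NumPy operations. A quick, self-contained check that the two formulations agree; the anchor points xc, yc, xl, xr, yb, yt are free variables of that function, so the values below are hypothetical and only serve the comparison:

```python
import numpy as np

width, height = 7, 5
# Hypothetical anchor points; in the library they are computed elsewhere in the function.
xc, yc = width * 0.5, height * 0.5
xl, xr = width * 0.25, width * 0.75
yb, yt = height * 0.25, height * 0.75

# Old formulation: explicit double loop over every (i, j) cell.
Dc_loop = np.zeros((width, height))
De_loop = np.zeros((width, height))
for i in range(width):
    for j in range(height):
        Dc_loop[i, j] = np.sqrt(np.square(i - xc + 0.5) + np.square(j - yc + 0.5))
        De_l = np.sqrt(np.square(i - xl + 0.5) + np.square(0.5))
        De_r = np.sqrt(np.square(i - xr + 0.5) + np.square(0.5))
        De_b = np.sqrt(np.square(0.5) + np.square(j - yb + 0.5))
        De_t = np.sqrt(np.square(0.5) + np.square(j - yt + 0.5))
        De_loop[i, j] = np.min([De_l, De_r, De_b, De_t])

# New formulation: 1-D distance profiles broadcast into the 2-D grids.
Dcx = np.square(np.arange(width) - xc + 0.5)
Dcy = np.square(np.arange(height) - yc + 0.5)
Dc_vec = np.sqrt(Dcx[:, np.newaxis] + Dcy)

De_l = np.square(np.arange(width) - xl + 0.5) + np.square(0.5)
De_r = np.square(np.arange(width) - xr + 0.5) + np.square(0.5)
De_b = np.square(0.5) + np.square(np.arange(height) - yb + 0.5)
De_t = np.square(0.5) + np.square(np.arange(height) - yt + 0.5)
De_vec = np.minimum(np.sqrt(np.minimum(De_l, De_r))[:, np.newaxis],
                    np.sqrt(np.minimum(De_b, De_t)))

assert np.allclose(Dc_loop, Dc_vec) and np.allclose(De_loop, De_vec)
```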

pytorch_toolbelt/inference/tta.py

Lines changed: 10 additions & 28 deletions
@@ -71,21 +71,11 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
     center_crop_y = (image_height - crop_height) // 2
     center_crop_x = (image_width - crop_width) // 2
 
-    crop_cc = image[
-        ...,
-        center_crop_y : center_crop_y + crop_height,
-        center_crop_x : center_crop_x + crop_width,
-    ]
+    crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width]
     assert crop_cc.size(2) == crop_height
     assert crop_cc.size(3) == crop_width
 
-    output = (
-        model(crop_tl)
-        + model(crop_tr)
-        + model(crop_bl)
-        + model(crop_br)
-        + model(crop_cc)
-    )
+    output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(crop_br) + model(crop_cc)
     one_over_5 = float(1.0 / 5.0)
     return output * one_over_5
 

@@ -125,11 +115,7 @@ def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Tensor:
     center_crop_y = (image_height - crop_height) // 2
     center_crop_x = (image_width - crop_width) // 2
 
-    crop_cc = image[
-        ...,
-        center_crop_y : center_crop_y + crop_height,
-        center_crop_x : center_crop_x + crop_width,
-    ]
+    crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width]
     assert crop_cc.size(2) == crop_height
     assert crop_cc.size(3) == crop_width
 

@@ -202,11 +188,10 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
     output = model(image)
 
     for aug, deaug in zip(
-        [F.torch_rot90, F.torch_rot180, F.torch_rot270],
-        [F.torch_rot270, F.torch_rot180, F.torch_rot90],
+        [F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90]
     ):
         x = deaug(model(aug(image)))
-        output = output + x
+        output += x
 
     image = F.torch_transpose(image)
 

@@ -215,10 +200,11 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor:
         [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90],
     ):
         x = deaug(model(aug(image)))
-        output = output + F.torch_transpose(x)
+        output += F.torch_transpose(x)
 
     one_over_8 = float(1.0 / 8.0)
-    return output * one_over_8
+    output *= one_over_8
+    return output
 
 
 class TTAWrapper(nn.Module):

@@ -258,13 +244,9 @@ def forward(self, input: Tensor) -> Tensor:
 
         for scale in self.scale_levels:
             dst_size = int(h * scale), int(w * scale)
-            input_scaled = interpolate(
-                input, dst_size, mode="bilinear", align_corners=True
-            )
+            input_scaled = interpolate(input, dst_size, mode="bilinear", align_corners=False)
             output_scaled = self.model(input_scaled)
-            output_scaled = interpolate(
-                output_scaled, out_size, mode="bilinear", align_corners=True
-            )
+            output_scaled = interpolate(output_scaled, out_size, mode="bilinear", align_corners=False)
             output += output_scaled
 
         return output / (1 + len(self.scale_levels))
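A minimal usage sketch for d4_image2mask, whose signature appears in the hunks above. The convolution standing in for a segmentation model and the input size are hypothetical; the image is kept square so that rotated and transposed predictions stay shape-compatible when averaged:

```python
import torch
from torch import nn
from pytorch_toolbelt.inference import tta

model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # stand-in for a real segmentation model
model.eval()

image = torch.rand(2, 3, 64, 64)  # square spatial size keeps all 8 D4 views aligned

with torch.no_grad():
    mask = tta.d4_image2mask(model, image)  # average over the 8 symmetries of the square

print(mask.shape)  # expected: torch.Size([2, 1, 64, 64])
```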

pytorch_toolbelt/losses/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,8 +1,10 @@
 from __future__ import absolute_import
 
+from .dice import *
 from .focal import *
 from .jaccard import *
-from .dice import *
-from .lovasz import *
 from .joint_loss import *
+from .lovasz import *
+from .soft_bce import *
+from .soft_ce import *
 from .wing_loss import *
