Skip to content

Commit ee72463

Browse files
authored
Merge pull request #74 from BloodAxe/develop
Release 0.5.1
2 parents 8bc1cd1 + 8faed01 commit ee72463

File tree

23 files changed

+581
-75
lines changed

23 files changed

+581
-75
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ jobs:
2525
matrix.operating-system == 'ubuntu-latest' ||
2626
matrix.operating-system == 'windows-latest'
2727
run: >
28-
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu
28+
pip install torch==1.10.1+cpu torchvision==0.11.2+cpu
2929
-f https://download.pytorch.org/whl/torch_stable.html
3030
- name: Install PyTorch on MacOS
3131
if: matrix.operating-system == 'macos-latest'
32-
run: pip install torch==1.8.1 torchvision==0.9.1
32+
run: pip install torch==1.10.1 torchvision==0.11.2
3333
- name: Install dependencies
3434
run: pip install .[${{ matrix.pytorch-toolbelt-version }}]
3535
- name: Install linters
@@ -59,6 +59,6 @@ jobs:
5959
- name: Update pip
6060
run: python -m pip install --upgrade pip
6161
- name: Install Black
62-
run: pip install black==20.8b1
62+
run: pip install black==22.3.0
6363
- name: Run Black
6464
run: black --config=black.toml --check .

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ var/
2020
.pytest_cache/
2121
/tests/tta_eval.csv
2222
/tests/tmp.onnx
23+
/tests/test_plot_confusion_matrix.png

notebooks/tiled_inference.ipynb

Lines changed: 368 additions & 0 deletions
Large diffs are not rendered by default.

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = "0.5.0"
3+
__version__ = "0.5.1"

pytorch_toolbelt/datasets/segmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __getitem__(self, index):
139139

140140
if self.need_supervision_masks:
141141
for i in range(1, 6):
142-
stride = 2 ** i
142+
stride = 2**i
143143
mask = block_reduce(mask, (2, 2), partial(_block_reduce_dominant_label))
144144
sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(mask)
145145

pytorch_toolbelt/inference/functional.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,27 @@
77
from ..utils.support import pytorch_toolbelt_deprecated
88

99
__all__ = [
10+
"geometric_mean",
11+
"harmonic_mean",
12+
"logodd_mean",
13+
"pad_image_tensor",
14+
"torch_fliplr",
15+
"torch_flipud",
1016
"torch_none",
17+
"torch_rot180",
18+
"torch_rot270",
1119
"torch_rot90",
12-
"torch_rot90_cw",
1320
"torch_rot90_ccw",
14-
"torch_transpose_rot90_cw",
15-
"torch_transpose_rot90_ccw",
1621
"torch_rot90_ccw_transpose",
22+
"torch_rot90_cw",
1723
"torch_rot90_cw_transpose",
18-
"torch_rot180",
19-
"torch_rot270",
20-
"torch_fliplr",
21-
"torch_flipud",
2224
"torch_transpose",
2325
"torch_transpose2",
2426
"torch_transpose_",
25-
"pad_image_tensor",
27+
"torch_transpose_rot90_ccw",
28+
"torch_transpose_rot90_cw",
2629
"unpad_image_tensor",
2730
"unpad_xyxy_bboxes",
28-
"geometric_mean",
29-
"harmonic_mean",
3031
]
3132

3233

@@ -240,3 +241,23 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
240241
x = torch.mean(x, dim=dim)
241242
x = torch.reciprocal(x.clamp_min(eps))
242243
return x
244+
245+
246+
def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
247+
"""
248+
Compute log-odd mean along given dimension.
249+
logodd = log(p / (1 - p))
250+
251+
This implementation assume values are in range [0, 1] (Probabilities)
252+
Args:
253+
x: Input tensor of arbitrary shape
254+
dim: Dimension to reduce
255+
256+
Returns:
257+
Tensor
258+
"""
259+
x = x.clamp(min=eps, max=1.0 - eps)
260+
x = torch.log(x / (1 - x))
261+
x = torch.mean(x, dim=dim)
262+
x = torch.exp(x) / (1 + torch.exp(x))
263+
return x

pytorch_toolbelt/inference/tiles.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class ImageSlicer:
6060
Helper class to slice image into tiles and merge them back
6161
"""
6262

63-
def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"):
63+
def __init__(self, image_shape: Tuple[int, int], tile_size, tile_step=0, image_margin=0, weight="mean"):
6464
"""
6565
6666
:param image_shape: Shape of the source image (H, W)
@@ -122,12 +122,6 @@ def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="
122122
else:
123123
margin_left = margin_right = margin_top = margin_bottom = image_margin
124124

125-
if (self.image_width + margin_left + margin_right) % self.tile_size[1] != 0:
126-
raise ValueError()
127-
128-
if (self.image_height + margin_top + margin_bottom) % self.tile_size[0] != 0:
129-
raise ValueError()
130-
131125
self.margin_left = margin_left
132126
self.margin_right = margin_right
133127
self.margin_top = margin_top
@@ -337,6 +331,10 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords):
337331
if batch.device != self.image.device:
338332
batch = batch.to(device=self.image.device)
339333

334+
# Ensure that input batch dtype match the target dtyle of the accumulator
335+
if batch.dtype != self.image.dtype:
336+
batch = batch.type_as(self.image)
337+
340338
for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
341339
self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight
342340
self.norm_mask[:, y : y + tile_height, x : x + tile_width] += self.weight

pytorch_toolbelt/inference/tta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
7979
x = F.geometric_mean(x, dim=0)
8080
elif reduction in {"hmean", "harmonic_mean"}:
8181
x = F.harmonic_mean(x, dim=0)
82+
elif reduction == "logodd":
83+
x = F.logodd_mean(x, dim=0)
8284
elif callable(reduction):
8385
x = reduction(x, dim=0)
8486
elif reduction in {None, "None", "none"}:

pytorch_toolbelt/modules/activations.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,15 @@ def instantiate_activation_block(activation_name: str, **kwargs) -> nn.Module:
264264

265265
act_params = {}
266266

267-
if "inplace" in kwargs and activation_name in {ACT_RELU, ACT_RELU6, ACT_LEAKY_RELU, ACT_SELU, ACT_CELU, ACT_ELU}:
267+
if "inplace" in kwargs and activation_name in {
268+
ACT_RELU,
269+
ACT_RELU6,
270+
ACT_LEAKY_RELU,
271+
ACT_SELU,
272+
ACT_SILU,
273+
ACT_CELU,
274+
ACT_ELU,
275+
}:
268276
act_params["inplace"] = kwargs["inplace"]
269277

270278
if "slope" in kwargs and activation_name in {ACT_LEAKY_RELU}:

pytorch_toolbelt/modules/decoders/unet_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels
100100
super().__init__()
101101

102102
if not isinstance(decoder_features, list):
103-
decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))]
103+
decoder_features = [decoder_features * (2**i) for i in range(len(feature_maps))]
104104

105105
blocks = []
106106
for block_index, in_enc_features in enumerate(feature_maps[:-1]):

pytorch_toolbelt/modules/dropblock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _compute_block_mask(self, mask):
6868
return block_mask, keeped
6969

7070
def _compute_gamma(self, x):
71-
return self.drop_prob / (self.block_size ** 2)
71+
return self.drop_prob / (self.block_size**2)
7272

7373

7474
class DropBlock3D(DropBlock2D):
@@ -131,7 +131,7 @@ def _compute_block_mask(self, mask):
131131
return block_mask
132132

133133
def _compute_gamma(self, x):
134-
return self.drop_prob / (self.block_size ** 3)
134+
return self.drop_prob / (self.block_size**3)
135135

136136

137137
class DropBlockScheduled(nn.Module):

pytorch_toolbelt/modules/encoders/swin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at
9494
self.window_size = window_size # Wh, Ww
9595
self.num_heads = num_heads
9696
head_dim = dim // num_heads
97-
self.scale = qk_scale or head_dim ** -0.5
97+
self.scale = qk_scale or head_dim**-0.5
9898

9999
# define a parameter table of relative position bias
100100
self.relative_position_bias_table = nn.Parameter(
@@ -587,7 +587,7 @@ def __init__(
587587
self.layers = nn.ModuleList()
588588
for i_layer in range(self.num_layers):
589589
layer = BasicLayer(
590-
dim=int(embed_dim * 2 ** i_layer),
590+
dim=int(embed_dim * 2**i_layer),
591591
depth=depths[i_layer],
592592
num_heads=num_heads[i_layer],
593593
window_size=window_size,
@@ -604,7 +604,7 @@ def __init__(
604604
)
605605
self.layers.append(layer)
606606

607-
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
607+
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
608608
self.num_features = num_features
609609

610610
# add a norm layer for each output

pytorch_toolbelt/modules/encoders/unet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def __init__(
2727
if pool_block is None:
2828
pool_block = partial(nn.MaxPool2d, kernel_size=2, stride=2)
2929

30-
feature_maps = [out_channels * (growth_factor ** i) for i in range(num_layers)]
31-
strides = [2 ** i for i in range(num_layers)]
30+
feature_maps = [out_channels * (growth_factor**i) for i in range(num_layers)]
31+
strides = [2**i for i in range(num_layers)]
3232
super().__init__(feature_maps, strides, layers=list(range(num_layers)))
3333

3434
input_filters = in_channels

pytorch_toolbelt/modules/ocnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward(self, x): # skipcq: PYL-W0221
6262
key = self.f_key(x).view(batch_size, self.key_channels, -1)
6363

6464
sim_map = torch.matmul(query, key)
65-
sim_map = (self.key_channels ** -0.5) * sim_map
65+
sim_map = (self.key_channels**-0.5) * sim_map
6666
sim_map = F.softmax(sim_map, dim=-1)
6767

6868
context = torch.matmul(sim_map, value)
@@ -300,7 +300,7 @@ def forward(self, x):
300300
key_local = key_local.contiguous().view(batch_size, self.key_channels, -1)
301301

302302
sim_map = torch.matmul(query_local, key_local)
303-
sim_map = (self.key_channels ** -0.5) * sim_map
303+
sim_map = (self.key_channels**-0.5) * sim_map
304304
sim_map = F.softmax(sim_map, dim=-1)
305305

306306
context_local = torch.matmul(sim_map, value_local)

pytorch_toolbelt/modules/upsample.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def icnr_init(tensor: torch.Tensor, upscale_factor=2, initializer=nn.init.kaimin
5050
.. _Checkerboard artifact free sub-pixel convolution:
5151
https://arxiv.org/abs/1707.02937
5252
"""
53-
new_shape = [int(tensor.shape[0] / (upscale_factor ** 2))] + list(tensor.shape[1:])
53+
new_shape = [int(tensor.shape[0] / (upscale_factor**2))] + list(tensor.shape[1:])
5454
subkernel = torch.zeros(new_shape)
5555
subkernel = initializer(subkernel)
5656
subkernel = subkernel.transpose(0, 1)
5757

5858
subkernel = subkernel.contiguous().view(subkernel.shape[0], subkernel.shape[1], -1)
5959

60-
kernel = subkernel.repeat(1, 1, upscale_factor ** 2)
60+
kernel = subkernel.repeat(1, 1, upscale_factor**2)
6161

6262
transposed_shape = [tensor.shape[1]] + [tensor.shape[0]] + list(tensor.shape[2:])
6363
kernel = kernel.contiguous().view(transposed_shape)
@@ -77,7 +77,7 @@ class DepthToSpaceUpsample2d(nn.Module):
7777

7878
def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
7979
super().__init__()
80-
n = 2 ** scale_factor
80+
n = 2**scale_factor
8181
self.conv = nn.Conv2d(in_channels, out_channels * n, kernel_size=3, padding=1, bias=False)
8282
self.out_channels = out_channels
8383
self.shuffle = nn.PixelShuffle(upscale_factor=scale_factor)

pytorch_toolbelt/optimization/functional.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def get_optimizable_parameters(model: nn.Module) -> Iterator[nn.Parameter]:
4646
return filter(lambda x: x.requires_grad, model.parameters())
4747

4848

49-
def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True):
49+
def freeze_model(
50+
module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True
51+
) -> nn.Module:
5052
"""
5153
Change 'requires_grad' value for module and it's child modules and
5254
optionally freeze batchnorm modules.
@@ -70,3 +72,5 @@ def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, fr
7072
for m in module.modules():
7173
if isinstance(m, bn_types):
7274
module.track_running_stats = not freeze_bn
75+
76+
return module

pytorch_toolbelt/optimization/lr_schedules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def get_lr(self):
8080
def compute_lr(base_lr):
8181
return (
8282
self.eta_min
83-
+ (base_lr * self.gamma ** self.last_epoch - self.eta_min)
83+
+ (base_lr * self.gamma**self.last_epoch - self.eta_min)
8484
* (1 + math.cos(math.pi * self.last_epoch / self.T_max))
8585
/ 2
8686
)
@@ -110,7 +110,7 @@ def get_lr(self):
110110

111111
return [
112112
self.eta_min
113-
+ (base_lr * self.gamma ** self.last_epoch - self.eta_min)
113+
+ (base_lr * self.gamma**self.last_epoch - self.eta_min)
114114
* (1 + math.cos(math.pi * self.T_cur / self.T_i))
115115
/ 2
116116
for base_lr in self.base_lrs

pytorch_toolbelt/utils/catalyst/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def _f1_from_confusion_matrix(
234234
true_sum = np.array([true_sum.sum()])
235235

236236
# Finally, we have all our sufficient statistics. Divide! #
237-
beta2 = beta ** 2
237+
beta2 = beta**2
238238

239239
# Divide, and on zero-division, set scores and/or warn according to
240240
# zero_division:

pytorch_toolbelt/utils/fs.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"find_images_in_dir",
1717
"find_in_dir",
1818
"find_in_dir_glob",
19+
"find_subdirectories_in_dir",
1920
"has_ext",
2021
"has_image_ext",
2122
"id_from_fname",
@@ -45,6 +46,19 @@ def find_in_dir(dirname: str) -> List[str]:
4546
return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname))]
4647

4748

49+
def find_subdirectories_in_dir(dirname: str) -> List[str]:
50+
"""
51+
Retrieve list of subdirectories (non-recursive) in the given directory.
52+
Args:
53+
dirname: Target directory name
54+
55+
Returns:
56+
Sorted list of absolute paths to directories
57+
"""
58+
all_entries = find_in_dir(dirname)
59+
return [entry for entry in all_entries if os.path.isdir(entry)]
60+
61+
4862
def find_in_dir_with_ext(dirname: str, extensions: Union[str, List[str]]) -> List[str]:
4963
return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname)) if has_ext(fname, extensions)]
5064

@@ -107,7 +121,7 @@ def read_rgb_image(fname: Union[str, Path]) -> np.ndarray:
107121
if type(fname) != str:
108122
fname = str(fname)
109123

110-
image = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
124+
image = cv2.imread(fname, cv2.IMREAD_COLOR)
111125
if image is None:
112126
raise IOError(f'Cannot read image "{fname}"')
113127

0 commit comments

Comments
 (0)