Commit 8bc1cd1

Merge pull request #69 from BloodAxe/develop
Release of pytorch-toolbelt 0.5
2 parents: 4a24e63 + b67d410


42 files changed: +2129, -466 lines

.appveyor.yml

Lines changed: 0 additions & 19 deletions
This file was deleted.

.deepsource.toml

Lines changed: 0 additions & 13 deletions
This file was deleted.

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@ var/
 .idea/
 .pytest_cache/
 /tests/tta_eval.csv
+/tests/tmp.onnx

README.md

Lines changed: 25 additions & 2 deletions
@@ -1,6 +1,29 @@
-# Pytorch-toolbelt
+# Important Update
+
+![ukraine-flag](docs/480px-Flag_of_Ukraine.jpg)
+
+On February 24th, 2022, Russia declared war and invaded peaceful Ukraine.
+After the annexation of Crimea and the occupation of the Donbas region, Putin's regime decided to destroy Ukrainian nationality.
+Ukrainians show fierce resistance and demonstrate to the entire world what it's like to fight for the nation's independence.
+
+Ukraine's government launched a website to help russian mothers, wives & sisters find their beloved ones killed or captured in Ukraine - https://200rf.com & https://t.me/rf200_now (Telegram channel).
+Our goal is to inform those still in Russia & Belarus, so they refuse to assault Ukraine.
+
+Help us get maximum exposure to what is happening in Ukraine, violence, and inhuman acts of terror that the "Russian World" has brought to Ukraine.
+This is a comprehensive Wiki on how you can help end this war: https://how-to-help-ukraine-now.super.site/
 
-[![Documentation Status](https://readthedocs.org/projects/pytorch-toolbelt/badge/?version=latest)](https://pytorch-toolbelt.readthedocs.io/en/latest/?badge=latest)
+Official channels
+* [Official account of the Parliament of Ukraine](https://t.me/verkhovnaradaofukraine)
+* [Ministry of Defence](https://www.facebook.com/MinistryofDefence.UA)
+* [Office of the president](https://www.facebook.com/president.gov.ua)
+* [Cabinet of Ministers of Ukraine](https://www.facebook.com/KabminUA)
+* [Center of strategic communications](https://www.facebook.com/StratcomCentreUA)
+* [Minister of Foreign Affairs of Ukraine](https://twitter.com/DmytroKuleba)
+
+Glory to Ukraine!
+
+
+# Pytorch-toolbelt
 
 A `pytorch-toolbelt` is a Python library with a set of bells and whistles for PyTorch for fast R&D prototyping and Kaggle farming:

docs/480px-Flag_of_Ukraine.jpg

18.4 KB

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import
 
-__version__ = "0.4.4"
+__version__ = "0.5.0"

pytorch_toolbelt/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from .classification import *
 from .segmentation import *
 from .wrappers import *
+from .mean_std import *

pytorch_toolbelt/datasets/mean_std.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import numpy as np
+from typing import Optional, Tuple
+
+__all__ = ["DatasetMeanStdCalculator"]
+
+
+class DatasetMeanStdCalculator:
+    __slots__ = ["global_mean", "global_var", "n_items", "num_channels", "global_max", "global_min"]
+
+    """
+    Class to calculate running mean and std of the dataset. It helps when whole dataset does not fit entirely in RAM.
+    """
+
+    def __init__(self, num_channels: int = 3):
+        """
+        Create a new instance of DatasetMeanStdCalculator
+
+        Args:
+            num_channels: Number of channels in the image. Default value is 3
+        """
+        super(DatasetMeanStdCalculator, self).__init__()
+        self.num_channels = num_channels
+        self.global_mean = None
+        self.global_var = None
+        self.global_max = None
+        self.global_min = None
+        self.n_items = 0
+        self.reset()
+
+    def reset(self):
+        self.global_mean = np.zeros(self.num_channels, dtype=np.float64)
+        self.global_var = np.zeros(self.num_channels, dtype=np.float64)
+        self.global_max = np.ones_like(self.global_mean) * float("-inf")
+        self.global_min = np.ones_like(self.global_mean) * float("+inf")
+        self.n_items = 0
+
+    def accumulate(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> None:
+        """
+        Compute mean and std of a single image and integrates it into global statistics
+        Args:
+            image: Input image (Must be HWC, with number of channels C equal to self.num_channels)
+            mask: Optional mask to include only certain parts of image from statistics computation.
+                Only non-zero elements will be included,
+        """
+        if len(image.shape) == 2:
+            image = np.expand_dims(image, axis=-1)
+
+        if self.num_channels != image.shape[2]:
+            raise RuntimeError(f"Number of channels in image must be {self.num_channels}, got {image.shape[2]}.")
+        image = image.reshape((-1, self.num_channels))
+
+        if mask is not None:
+            mask = mask.reshape((mask.shape[0] * mask.shape[1], 1))
+            image = image[mask]
+
+        # In case the whole image is masked out, we exclude it entirely
+        if len(image) == 0:
+            return
+
+        mean = np.mean(image, axis=0)
+        std = np.std(image, axis=0)
+
+        self.global_mean += np.squeeze(mean)
+        self.global_var += np.squeeze(std) ** 2
+        self.global_max = np.maximum(self.global_max, np.max(image, axis=0))
+        self.global_min = np.minimum(self.global_min, np.min(image, axis=0))
+        self.n_items += 1
+
+    def compute(self) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Compute dataset-level mean & std
+
+        Returns:
+            Tuple of global [mean, std] per channel
+        """
+        return self.global_mean / self.n_items, np.sqrt(self.global_var / self.n_items)

pytorch_toolbelt/inference/ensembling.py

Lines changed: 18 additions & 9 deletions
@@ -1,14 +1,20 @@
 import torch
 from torch import nn, Tensor
-from typing import List, Union, Iterable, Optional, Dict
+from typing import List, Union, Iterable, Optional, Dict, Tuple
 
 __all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput", "SelectByIndex"]
 
 from pytorch_toolbelt.inference.tta import _deaugment_averaging
 
 
 class ApplySoftmaxTo(nn.Module):
-    def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", dim=1, temperature=1):
+    output_keys: Tuple
+    temperature: float
+    dim: int
+
+    def __init__(
+        self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", dim: int = 1, temperature: float = 1
+    ):
         """
         Apply softmax activation on given output(s) of the model
         :param model: Model to wrap
@@ -17,39 +23,42 @@ def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits
         :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
         """
         super().__init__()
-        output_key = output_key if isinstance(output_key, (list, tuple)) else [output_key]
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
-        self.output_keys = set(output_key)
+        output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
+        self.output_keys = output_key
         self.model = model
         self.dim = dim
         self.temperature = temperature
 
     def forward(self, *input, **kwargs):
         output = self.model(*input, **kwargs)
         for key in self.output_keys:
-            output[key] = output[key].mul(self.temperature).softmax(dim=1)
+            output[key] = output[key].mul(self.temperature).softmax(dim=self.dim)
         return output
 
 
 class ApplySigmoidTo(nn.Module):
-    def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", temperature=1):
+    output_keys: Tuple
+    temperature: float
+
+    def __init__(self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", temperature=1):
         """
         Apply sigmoid activation on given output(s) of the model
         :param model: Model to wrap
         :param output_key: string or list of strings, indicating to what outputs sigmoid activation should be applied.
         :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
         """
         super().__init__()
-        output_key = output_key if isinstance(output_key, (list, tuple)) else [output_key]
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
-        self.output_keys = set(output_key)
+        output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
+        self.output_keys = output_key
         self.model = model
         self.temperature = temperature
 
     def forward(self, *input, **kwargs):  # skipcq: PYL-W0221
         output = self.model(*input, **kwargs)
         for key in self.output_keys:
-            output[key] = output[key].mul(self.temperature).sigmoid()
+            output[key] = output[key].mul(self.temperature).sigmoid_()
         return output
 
 
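ApplySoftmaxTo and ApplySigmoidTo now normalize their keys into a tuple and respect the dim argument in forward. A small sketch with a toy dict-returning model (the toy model is illustrative, not part of the library); keys are passed as lists here, since under the new isinstance(output_key, Iterable) check a bare string would itself be iterated character by character:

import torch
from torch import nn
from pytorch_toolbelt.inference.ensembling import ApplySigmoidTo, ApplySoftmaxTo


class ToyModel(nn.Module):
    # Illustrative model returning a dict of raw outputs, as the wrappers expect.
    def forward(self, x):
        return {"logits": x.mean(dim=(2, 3)), "mask": x}


model = ApplySigmoidTo(ToyModel(), output_key=["logits"])   # sigmoid on the "logits" head
model = ApplySoftmaxTo(model, output_key=["mask"], dim=1)   # softmax over channels of "mask"

with torch.no_grad():
    outputs = model(torch.randn(2, 4, 8, 8))

print(outputs["logits"].min() >= 0)         # sigmoid applied (in-place) to logits
print(outputs["mask"].sum(dim=1)[0, 0, 0])  # softmax over dim=1 sums to ~1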

pytorch_toolbelt/inference/functional.py

Lines changed: 3 additions & 1 deletion
@@ -140,7 +140,9 @@ def torch_transpose2(x: Tensor):
     return x.transpose(3, 2)
 
 
-def pad_image_tensor(image_tensor: Tensor, pad_size: Union[int, Tuple[int, int]] = 32):
+def pad_image_tensor(
+    image_tensor: Tensor, pad_size: Union[int, Tuple[int, int]] = 32
+) -> Tuple[Tensor, Tuple[int, int, int, int]]:
     """Pad input tensor to make it's height and width dividable by @pad_size
 
     :param image_tensor: 4D image tensor of shape NCHW
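The added return annotation documents that pad_image_tensor returns both the padded tensor and the padding that was applied, so it can later be undone. A short sketch assuming the companion unpad_image_tensor from the same module (not shown in this diff):

import torch
from pytorch_toolbelt.inference.functional import pad_image_tensor, unpad_image_tensor

x = torch.randn(1, 3, 100, 150)  # NCHW input whose H and W are not multiples of 32

# Pad so that height and width become divisible by pad_size; pad holds the applied margins.
x_padded, pad = pad_image_tensor(x, pad_size=32)
print(x_padded.shape)  # expected torch.Size([1, 3, 128, 160])

# Undo the padding after inference to restore the original spatial size.
x_restored = unpad_image_tensor(x_padded, pad)
assert x_restored.shape == x.shape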

pytorch_toolbelt/inference/tiles.py

Lines changed: 11 additions & 1 deletion
@@ -334,14 +334,24 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords):
         if len(batch) != len(crop_coords):
             raise ValueError("Number of images in batch does not correspond to number of coordinates")
 
-        batch = batch.to(device=self.image.device)
+        if batch.device != self.image.device:
+            batch = batch.to(device=self.image.device)
+
         for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
             self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight
             self.norm_mask[:, y : y + tile_height, x : x + tile_width] += self.weight
 
+    @property
+    def device(self) -> torch.device:
+        return self.image.device
+
     def merge(self) -> torch.Tensor:
         return self.image / self.norm_mask
 
+    def merge_(self) -> torch.Tensor:
+        self.image /= self.norm_mask
+        return self.image
+
 
 @pytorch_toolbelt_deprecated("This class is deprecated and will be removed in 0.5.0. Please use TileMerger instead.")
 class CudaTileMerger(TileMerger):
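TileMerger now exposes the device of its accumulator buffers, skips redundant device transfers in integrate_batch, and adds an in-place merge_() that normalizes the accumulated image without allocating a second full-size tensor. A usage sketch following the ImageSlicer workflow from the project README (the random image and the single-channel "prediction" are illustrative; the constructor arguments are passed positionally as target shape, channels, weight and should be checked against the docs):

import numpy as np
import torch
from pytorch_toolbelt.inference.tiles import ImageSlicer, TileMerger

image = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)  # illustrative input

tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid")
merger = TileMerger(tiler.target_shape, 1, tiler.weight)  # target shape, channels, weight

for tile, coords in zip(tiler.split(image), tiler.crops):
    # A real pipeline would run a segmentation model here; a channel mean stands in for it.
    pred = torch.from_numpy(tile).float().mean(dim=2, keepdim=True).permute(2, 0, 1)
    merger.integrate_batch(pred.unsqueeze(0), [coords])

print(merger.device)      # new property: device of the accumulator buffers
merged = merger.merge_()  # new in-place variant: divides by the norm mask and returns the buffer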
