Skip to content

Commit f3acfca

Browse files
authored
Merge pull request #57 from BloodAxe/develop
PyTorch Toolbelt 0.4.3
2 parents a04e28b + d8e2a30 commit f3acfca

24 files changed

+700
-536
lines changed

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = "0.4.2"
3+
__version__ = "0.4.3"

pytorch_toolbelt/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .common import *
22
from .classification import *
33
from .segmentation import *
4+
from .wrappers import *

pytorch_toolbelt/datasets/common.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1+
import cv2
2+
13
__all__ = [
2-
"IGNORE_LABEL",
34
"INPUT_IMAGE_ID_KEY",
45
"INPUT_IMAGE_KEY",
56
"INPUT_INDEX_KEY",
6-
"INPUT_MASK_16_KEY",
7-
"INPUT_MASK_32_KEY",
8-
"INPUT_MASK_4_KEY",
9-
"INPUT_MASK_64_KEY",
10-
"INPUT_MASK_8_KEY",
117
"OUTPUT_EMBEDDINGS_KEY",
128
"OUTPUT_LOGITS_KEY",
139
"OUTPUT_MASK_16_KEY",
@@ -19,15 +15,19 @@
1915
"OUTPUT_MASK_KEY",
2016
"TARGET_CLASS_KEY",
2117
"TARGET_LABELS_KEY",
18+
"TARGET_MASK_16_KEY",
2219
"TARGET_MASK_2_KEY",
20+
"TARGET_MASK_32_KEY",
21+
"TARGET_MASK_4_KEY",
22+
"TARGET_MASK_64_KEY",
23+
"TARGET_MASK_8_KEY",
2324
"TARGET_MASK_KEY",
2425
"TARGET_MASK_WEIGHT_KEY",
25-
"UNLABELED_SAMPLE",
2626
"name_for_stride",
2727
"read_image_rgb",
2828
]
2929

30-
# Smaller masks for deep supervision
30+
3131
def name_for_stride(name, stride: int):
    """
    Build the key under which a value for a particular stride is stored.

    Args:
        name: Base key (for example "true_mask")
        stride: Stride that is appended to the base key

    Returns:
        String of the form "<name>_<stride>"
    """
    return "{}_{}".format(name, stride)
3333

@@ -36,18 +36,17 @@ def name_for_stride(name, stride: int):
3636
INPUT_IMAGE_KEY = "image"
3737
INPUT_IMAGE_ID_KEY = "image_id"
3838

39-
TARGET_MASK_KEY = "true_mask"
4039
TARGET_MASK_WEIGHT_KEY = "true_weights"
4140
TARGET_CLASS_KEY = "true_class"
4241
TARGET_LABELS_KEY = "true_labels"
4342

44-
43+
TARGET_MASK_KEY = "true_mask"
4544
TARGET_MASK_2_KEY = name_for_stride(TARGET_MASK_KEY, 2)
46-
INPUT_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
47-
INPUT_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
48-
INPUT_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
49-
INPUT_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
50-
INPUT_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)
45+
TARGET_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
46+
TARGET_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
47+
TARGET_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
48+
TARGET_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
49+
TARGET_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)
5150

5251
OUTPUT_MASK_KEY = "pred_mask"
5352
OUTPUT_MASK_2_KEY = name_for_stride(OUTPUT_MASK_KEY, 2)
@@ -60,9 +59,6 @@ def name_for_stride(name, stride: int):
6059
OUTPUT_LOGITS_KEY = "pred_logits"
6160
OUTPUT_EMBEDDINGS_KEY = "pred_embeddings"
6261

63-
UNLABELED_SAMPLE = 127
64-
IGNORE_LABEL = 255
65-
6662

6763
def read_image_rgb(fname: str):
6864
image = cv2.imread(fname)[..., ::-1]

pytorch_toolbelt/datasets/segmentation.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
TARGET_MASK_WEIGHT_KEY,
1616
TARGET_MASK_KEY,
1717
name_for_stride,
18-
UNLABELED_SAMPLE,
1918
)
2019
from ..utils import fs, image_to_tensor
2120

22-
__all__ = ["mask_to_bce_target", "mask_to_ce_target", "SegmentationDataset", "compute_weight_mask"]
21+
__all__ = ["mask_to_bce_target", "mask_to_ce_target", "read_binary_mask", "SegmentationDataset", "compute_weight_mask"]
2322

2423

2524
def mask_to_bce_target(mask):
@@ -62,8 +61,21 @@ def _block_reduce_dominant_label(x: np.ndarray, axis):
6261

6362

6463
def read_binary_mask(mask_fname: str) -> np.ndarray:
65-
mask = cv2.imread(mask_fname, cv2.IMREAD_COLOR)
66-
return cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY, dst=mask)
64+
"""
65+
Read image as binary mask, all non-zero values are treated as positive labels and converted to 1
66+
Args:
67+
mask_fname: Image with mask
68+
69+
Returns:
70+
Numpy array with {0,1} values
71+
"""
72+
73+
mask = cv2.imread(mask_fname, cv2.IMREAD_GRAYSCALE)
74+
if mask is None:
75+
raise FileNotFoundError(f"Cannot find {mask_fname}")
76+
77+
cv2.threshold(mask, thresh=0, maxval=1, type=cv2.THRESH_BINARY, dst=mask)
78+
return mask
6779

6880

6981
class SegmentationDataset(Dataset):
@@ -81,11 +93,16 @@ def __init__(
8193
need_weight_mask=False,
8294
need_supervision_masks=False,
8395
make_mask_target_fn: Callable = mask_to_ce_target,
96+
image_ids: Optional[List[str]] = None,
8497
):
8598
if mask_filenames is not None and len(image_filenames) != len(mask_filenames):
8699
raise ValueError("Number of images does not corresponds to number of targets")
87100

88-
self.image_ids = [fs.id_from_fname(fname) for fname in image_filenames]
101+
if image_ids is None:
102+
self.image_ids = [fs.id_from_fname(fname) for fname in image_filenames]
103+
else:
104+
self.image_ids = image_ids
105+
89106
self.need_weight_mask = need_weight_mask
90107
self.need_supervision_masks = need_supervision_masks
91108

@@ -100,39 +117,31 @@ def __init__(
100117
def __len__(self):
101118
return len(self.images)
102119

103-
def set_target(self, index: int, value: np.ndarray):
104-
mask_fname = self.masks[index]
105-
106-
value = (value * 255).astype(np.uint8)
107-
cv2.imwrite(mask_fname, value)
108-
109120
def __getitem__(self, index):
110121
image = self.read_image(self.images[index])
111-
122+
data = {"image": image}
112123
if self.masks is not None:
113-
mask = self.read_mask(self.masks[index])
114-
else:
115-
mask = np.ones((image.shape[0], image.shape[1], 1), dtype=np.uint8) * UNLABELED_SAMPLE
124+
data["mask"] = self.read_mask(self.masks[index])
116125

117-
data = self.transform(image=image, mask=mask)
126+
data = self.transform(**data)
118127

119128
image = data["image"]
120-
mask = data["mask"]
121-
122129
sample = {
123130
INPUT_INDEX_KEY: index,
124131
INPUT_IMAGE_ID_KEY: self.image_ids[index],
125132
INPUT_IMAGE_KEY: image_to_tensor(image),
126-
TARGET_MASK_KEY: self.make_target(mask),
127133
}
128134

129-
if self.need_weight_mask:
130-
sample[TARGET_MASK_WEIGHT_KEY] = image_to_tensor(compute_weight_mask(mask)).float()
131-
132-
if self.need_supervision_masks:
133-
for i in range(1, 5):
134-
stride = 2 ** i
135-
mask = block_reduce(mask, (2, 2), partial(_block_reduce_dominant_label))
136-
sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(mask)
135+
if self.masks is not None:
136+
mask = data["mask"]
137+
sample[TARGET_MASK_KEY] = self.make_target(mask)
138+
if self.need_weight_mask:
139+
sample[TARGET_MASK_WEIGHT_KEY] = image_to_tensor(compute_weight_mask(mask)).float()
140+
141+
if self.need_supervision_masks:
142+
for i in range(1, 6):
143+
stride = 2 ** i
144+
mask = block_reduce(mask, (2, 2), partial(_block_reduce_dominant_label))
145+
sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(mask)
137146

138147
return sample

pytorch_toolbelt/datasets/wrappers.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import random
2+
from typing import Any
3+
4+
from torch.utils.data import Dataset
5+
import numpy as np
6+
7+
__all__ = ["RandomSubsetDataset", "RandomSubsetWithMaskDataset"]
8+
9+
10+
class RandomSubsetDataset(Dataset):
    """
    Wrapper to get desired number of samples from underlying dataset.

    Every access returns a sample of the wrapped dataset picked at a random
    index; the index requested by the caller is ignored. ``len()`` of the
    wrapper reports ``num_samples`` regardless of the wrapped dataset size.
    """

    def __init__(self, dataset, num_samples: int):
        """
        Args:
            dataset: Wrapped dataset; must support ``len()`` and integer indexing
            num_samples: Length reported by this wrapper
        """
        self.dataset = dataset
        self.num_samples = num_samples

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, _) -> Any:
        # The requested index is deliberately ignored: a random one is drawn instead.
        index = random.randrange(len(self.dataset))
        return self.dataset[index]

    def __iter__(self):
        # __getitem__ never raises IndexError, so without an explicit __iter__
        # Python's fallback sequence iteration (`for x in ds`, `list(ds)`)
        # would never terminate. Yield exactly len(self) random samples.
        for index in range(len(self)):
            yield self[index]
25+
26+
27+
class RandomSubsetWithMaskDataset(Dataset):
    """
    Wrapper to get desired number of samples from underlying dataset only considering
    samples P for which mask[P] equals True.

    Every access returns a sample drawn at random from the allowed positions;
    the index requested by the caller is ignored. ``len()`` of the wrapper
    reports ``num_samples``.
    """

    def __init__(self, dataset: Dataset, mask: np.ndarray, num_samples: int):
        """
        Args:
            dataset: Wrapped dataset; must support ``len()`` and integer indexing
            mask: Boolean 1-D numpy array with one element per dataset sample
            num_samples: Length reported by this wrapper

        Raises:
            ValueError: If mask is not a boolean 1-D numpy array of the same length
                as the dataset, or if it contains no True values
        """
        # np.bool_ is used instead of the `np.bool` alias: the alias was deprecated
        # in NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        if (
            not isinstance(mask, np.ndarray)
            or mask.dtype != np.bool_
            or mask.ndim != 1
            or len(mask) != len(dataset)
        ):
            raise ValueError("Mask must be boolean 1-D numpy array")

        if not mask.any():
            raise ValueError("Mask must have at least one positive value")

        self.dataset = dataset
        self.mask = mask
        self.num_samples = num_samples
        # Positions of True entries, i.e. dataset indexes that may be sampled
        self.indexes = np.flatnonzero(self.mask)

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, _) -> Any:
        # The requested index is deliberately ignored. The drawn value is a numpy
        # integer scalar, so convert it to a plain int before passing it on.
        index = int(random.choice(self.indexes))
        return self.dataset[index]

    def __iter__(self):
        # __getitem__ never raises IndexError, so without an explicit __iter__
        # Python's fallback sequence iteration (`for x in ds`, `list(ds)`)
        # would never terminate. Yield exactly len(self) random samples.
        for index in range(len(self)):
            yield self[index]

pytorch_toolbelt/inference/ensembling.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import torch
12
from torch import nn, Tensor
2-
from typing import List, Union
3+
from typing import List, Union, Iterable, Optional
34

45
__all__ = ["ApplySoftmaxTo", "ApplySigmoidTo", "Ensembler", "PickModelOutput"]
56

7+
from pytorch_toolbelt.inference.tta import _deaugment_averaging
8+
69

710
class ApplySoftmaxTo(nn.Module):
811
def __init__(self, model: nn.Module, output_key: Union[str, List[str]] = "logits", dim=1, temperature=1):
@@ -55,40 +58,35 @@ class Ensembler(nn.Module):
5558
Compute sum (or average) of outputs of several models.
5659
"""
5760

58-
def __init__(self, models: List[nn.Module], average=True, outputs=None):
61+
def __init__(self, models: List[nn.Module], reduction: str = "mean", outputs: Optional[Iterable[str]] = None):
5962
"""
6063
6164
:param models:
62-
:param average:
65+
:param reduction: Reduction key ('mean', 'sum', 'gmean', 'hmean' or None)
6366
:param outputs: Name of model outputs to average and return from Ensembler.
6467
If None, all outputs from the first model will be used.
6568
"""
6669
super().__init__()
6770
self.outputs = outputs
6871
self.models = nn.ModuleList(models)
69-
self.average = average
72+
self.reduction = reduction
7073

7174
def forward(self, *input, **kwargs): # skipcq: PYL-W0221
72-
output_0 = self.models[0](*input, **kwargs)
73-
num_models = len(self.models)
75+
outputs = [model(*input, **kwargs) for model in self.models]
7476

7577
if self.outputs:
7678
keys = self.outputs
7779
else:
78-
keys = output_0.keys()
79-
80-
for index in range(1, num_models):
81-
output_i = self.models[index](*input, **kwargs)
82-
83-
# Sum outputs
84-
for key in keys:
85-
output_0[key].add_(output_i[key])
80+
keys = outputs[0].keys()
8681

87-
if self.average:
88-
for key in keys:
89-
output_0[key].mul_(1.0 / num_models)
82+
averaged_output = {}
83+
for key in keys:
84+
predictions = [output[key] for output in outputs]
85+
predictions = torch.stack(predictions)
86+
predictions = _deaugment_averaging(predictions, self.reduction)
87+
averaged_output[key] = predictions
9088

91-
return output_0
89+
return averaged_output
9290

9391

9492
class PickModelOutput(nn.Module):

pytorch_toolbelt/inference/functional.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import Sized, Iterable
1+
from collections.abc import Sized, Iterable
22
from typing import Union, Tuple
33

44
import torch
@@ -25,6 +25,8 @@
2525
"pad_image_tensor",
2626
"unpad_image_tensor",
2727
"unpad_xyxy_bboxes",
28+
"geometric_mean",
29+
"harmonic_mean",
2830
]
2931

3032

@@ -205,3 +207,34 @@ def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1):
205207
pad = pad.unsqueeze(dim)
206208

207209
return bboxes_tensor - pad
210+
211+
212+
def geometric_mean(x: Tensor, dim: int) -> Tensor:
    """
    Geometric mean along one dimension, evaluated as exp(mean(log(x))).

    Inputs are expected to be probabilities, i.e. values in the (0...1) range.

    Args:
        x: Input tensor of arbitrary shape
        dim: Dimension to reduce

    Returns:
        Tensor with dimension `dim` reduced
    """
    log_values = torch.log(x)
    mean_of_logs = torch.mean(log_values, dim=dim)
    return torch.exp(mean_of_logs)
224+
225+
226+
def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
    """
    Harmonic mean along one dimension, evaluated as 1 / mean(1 / x).

    Inputs are expected to be probabilities, i.e. values in the (0...1) range.

    Args:
        x: Input tensor of arbitrary shape
        dim: Dimension to reduce
        eps: Lower bound applied before each reciprocal to avoid division by zero

    Returns:
        Tensor with dimension `dim` reduced
    """
    inverse = x.clamp_min(eps).reciprocal()
    mean_inverse = inverse.mean(dim=dim)
    return mean_inverse.clamp_min(eps).reciprocal()

0 commit comments

Comments
 (0)