Skip to content

Commit a04e28b

Browse files
authored
Merge pull request #55 from BloodAxe/develop
Pytorch-Toolbelt 0.4.2
2 parents 3eaeea6 + eec7f94 commit a04e28b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2284
-875
lines changed

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,6 @@ logits = tta.fliplr_image2label(model, input)
160160
# Truly functional TTA for image segmentation using D4 augmentation:
161161
logits = tta.d4_image2mask(model, input)
162162

163-
# TTA using wrapper module:
164-
tta_model = tta.TTAWrapper(model, tta.fivecrop_image2label, crop_size=512)
165-
logits = tta_model(input)
166163
```
167164

168165
### Inference on huge images:

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = "0.4.1"
3+
__version__ = "0.4.2"

pytorch_toolbelt/datasets/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .common import *
2+
from .classification import *
3+
from .segmentation import *
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Optional, List
2+
3+
import albumentations as A
4+
import torch
5+
from torch.utils.data import Dataset
6+
7+
from .common import read_image_rgb, INPUT_IMAGE_KEY, INPUT_IMAGE_ID_KEY, INPUT_INDEX_KEY, TARGET_CLASS_KEY
8+
from ..utils import fs, image_to_tensor
9+
10+
__all__ = ["ClassificationDataset", "label_to_tensor"]
11+
12+
13+
def label_to_tensor(x):
    """
    Convert a label into a 64-bit integer tensor.

    :param x: Any value accepted by ``torch.tensor`` (scalar, list, array, ...).
    :return: Tensor built from ``x`` and cast to ``torch.long``.
    """
    label = torch.tensor(x)
    return label.to(dtype=torch.long)
15+
16+
17+
class ClassificationDataset(Dataset):
    """
    Map-style dataset for image classification.

    Each item is a dictionary with the sample index, the image id derived from the
    file name, the transformed image converted with ``image_to_tensor`` and, when
    labels were supplied, the target produced by ``make_target_fn``.
    """

    def __init__(
        self,
        image_filenames: List[str],
        labels: Optional[List[str]],
        transform: A.Compose,
        read_image_fn=read_image_rgb,
        make_target_fn=label_to_tensor,
    ):
        """
        :param image_filenames: Paths of the images.
        :param labels: One label per image, or None for an unlabeled dataset.
        :param transform: Callable invoked as ``transform(image=...)``; its result must have an "image" entry.
        :param read_image_fn: Function that loads an image given its path.
        :param make_target_fn: Function that converts a single label into the target value.
        :raises ValueError: If labels are given and their count differs from the number of images.
        """
        has_labels = labels is not None
        if has_labels and len(labels) != len(image_filenames):
            raise ValueError("Number of images does not corresponds to number of targets")

        self.image_ids = list(map(fs.id_from_fname, image_filenames))
        self.labels = labels
        self.images = image_filenames
        self.read_image = read_image_fn
        self.transform = transform
        self.make_target = make_target_fn

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, transform and package the sample stored at ``index``."""
        raw_image = self.read_image(self.images[index])
        transformed = self.transform(image=raw_image)["image"]

        sample = {
            INPUT_INDEX_KEY: index,
            INPUT_IMAGE_ID_KEY: self.image_ids[index],
            INPUT_IMAGE_KEY: image_to_tensor(transformed),
        }

        # Unlabeled dataset: the sample carries no target entry.
        if self.labels is None:
            return sample

        sample[TARGET_CLASS_KEY] = self.make_target(self.labels[index])
        return sample

pytorch_toolbelt/datasets/common.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
__all__ = [
2+
"IGNORE_LABEL",
3+
"INPUT_IMAGE_ID_KEY",
4+
"INPUT_IMAGE_KEY",
5+
"INPUT_INDEX_KEY",
6+
"INPUT_MASK_16_KEY",
7+
"INPUT_MASK_32_KEY",
8+
"INPUT_MASK_4_KEY",
9+
"INPUT_MASK_64_KEY",
10+
"INPUT_MASK_8_KEY",
11+
"OUTPUT_EMBEDDINGS_KEY",
12+
"OUTPUT_LOGITS_KEY",
13+
"OUTPUT_MASK_16_KEY",
14+
"OUTPUT_MASK_2_KEY",
15+
"OUTPUT_MASK_32_KEY",
16+
"OUTPUT_MASK_4_KEY",
17+
"OUTPUT_MASK_64_KEY",
18+
"OUTPUT_MASK_8_KEY",
19+
"OUTPUT_MASK_KEY",
20+
"TARGET_CLASS_KEY",
21+
"TARGET_LABELS_KEY",
22+
"TARGET_MASK_2_KEY",
23+
"TARGET_MASK_KEY",
24+
"TARGET_MASK_WEIGHT_KEY",
25+
"UNLABELED_SAMPLE",
26+
"name_for_stride",
27+
"read_image_rgb",
28+
]
29+
30+
# Smaller masks for deep supervision
31+
def name_for_stride(name, stride: int):
    """
    Build the key of a strided variant of ``name``.

    :param name: Base key.
    :param stride: Stride value appended after an underscore.
    :return: The string ``"<name>_<stride>"``.
    """
    return "{}_{}".format(name, stride)
33+
34+
35+
# Keys of the per-sample dictionaries built by ClassificationDataset / SegmentationDataset.
INPUT_INDEX_KEY = "index"
36+
INPUT_IMAGE_KEY = "image"
37+
INPUT_IMAGE_ID_KEY = "image_id"
38+
39+
# Keys under which targets are stored in a sample.
TARGET_MASK_KEY = "true_mask"
40+
TARGET_MASK_WEIGHT_KEY = "true_weights"
41+
TARGET_CLASS_KEY = "true_class"
42+
TARGET_LABELS_KEY = "true_labels"
43+
44+
45+
# Strided mask keys: name_for_stride appends "_<stride>" to the base key.
# NOTE: the INPUT_MASK_*_KEY names are derived from TARGET_MASK_KEY, so they evaluate to "true_mask_<stride>".
TARGET_MASK_2_KEY = name_for_stride(TARGET_MASK_KEY, 2)
46+
INPUT_MASK_4_KEY = name_for_stride(TARGET_MASK_KEY, 4)
47+
INPUT_MASK_8_KEY = name_for_stride(TARGET_MASK_KEY, 8)
48+
INPUT_MASK_16_KEY = name_for_stride(TARGET_MASK_KEY, 16)
49+
INPUT_MASK_32_KEY = name_for_stride(TARGET_MASK_KEY, 32)
50+
INPUT_MASK_64_KEY = name_for_stride(TARGET_MASK_KEY, 64)
51+
52+
# NOTE(review): the OUTPUT_* keys are not referenced by the visible dataset code;
# presumably they name entries of a model's output dictionary - confirm.
OUTPUT_MASK_KEY = "pred_mask"
53+
OUTPUT_MASK_2_KEY = name_for_stride(OUTPUT_MASK_KEY, 2)
54+
OUTPUT_MASK_4_KEY = name_for_stride(OUTPUT_MASK_KEY, 4)
55+
OUTPUT_MASK_8_KEY = name_for_stride(OUTPUT_MASK_KEY, 8)
56+
OUTPUT_MASK_16_KEY = name_for_stride(OUTPUT_MASK_KEY, 16)
57+
OUTPUT_MASK_32_KEY = name_for_stride(OUTPUT_MASK_KEY, 32)
58+
OUTPUT_MASK_64_KEY = name_for_stride(OUTPUT_MASK_KEY, 64)
59+
60+
OUTPUT_LOGITS_KEY = "pred_logits"
61+
OUTPUT_EMBEDDINGS_KEY = "pred_embeddings"
62+
63+
# Fill value SegmentationDataset uses for the placeholder mask when no mask files are given.
UNLABELED_SAMPLE = 127
64+
# NOTE(review): presumably the label value to be excluded from loss/metrics; not used in the visible code - confirm.
IGNORE_LABEL = 255
65+
66+
67+
def read_image_rgb(fname: str):
    """
    Read an image from disk and return it with the channel axis reversed (BGR -> RGB).

    :param fname: Path to the image file.
    :return: The array returned by ``cv2.imread`` with its last axis reversed.
    :raises IOError: If OpenCV cannot read the file.
    """
    # Imported locally: this module has no top-level ``import cv2``, so the name would otherwise be undefined.
    import cv2

    image = cv2.imread(fname)
    # cv2.imread reports failure by returning None; check before slicing,
    # because ``None[..., ::-1]`` raises TypeError and the IOError would never be reached.
    if image is None:
        raise IOError("Cannot read " + fname)
    return image[..., ::-1]
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from functools import partial
2+
from typing import Optional, List, Callable
3+
4+
import albumentations as A
5+
import cv2
6+
import numpy as np
7+
from skimage.measure import block_reduce
8+
from torch.utils.data import Dataset
9+
10+
from .common import (
11+
read_image_rgb,
12+
INPUT_IMAGE_KEY,
13+
INPUT_IMAGE_ID_KEY,
14+
INPUT_INDEX_KEY,
15+
TARGET_MASK_WEIGHT_KEY,
16+
TARGET_MASK_KEY,
17+
name_for_stride,
18+
UNLABELED_SAMPLE,
19+
)
20+
from ..utils import fs, image_to_tensor
21+
22+
__all__ = ["mask_to_bce_target", "mask_to_ce_target", "SegmentationDataset", "compute_weight_mask"]
23+
24+
25+
def mask_to_bce_target(mask):
    """
    Convert a mask into a float tensor target.

    The mask is passed through ``image_to_tensor`` with ``dummy_channels_dim=True``
    and the result is cast to float.
    NOTE(review): presumably meant for binary-cross-entropy style losses - confirm.
    """
    target = image_to_tensor(mask, dummy_channels_dim=True)
    return target.float()
27+
28+
29+
def mask_to_ce_target(mask):
    """
    Convert a mask into an integer (long) tensor target.

    The mask is passed through ``image_to_tensor`` with ``dummy_channels_dim=False``
    and the result is cast to long.
    NOTE(review): presumably meant for cross-entropy style losses - confirm.
    """
    target = image_to_tensor(mask, dummy_channels_dim=False)
    return target.long()
31+
32+
33+
def compute_weight_mask(mask: np.ndarray, edge_weight=4) -> np.ndarray:
    """
    Compute a per-pixel weight map that emphasises the borders of foreground regions.

    Foreground is every pixel with ``mask > 0``. Pixels in the band between the 5x5 erosion
    and the 5x5 dilation of the foreground get weight ``edge_weight + 1``, all other pixels
    get 1, and the map is then smoothed with a 5x5 Gaussian blur. A mask without any
    foreground yields a map of ones.

    :param mask: Label mask of shape (H, W) or (H, W, C); extra axes are collapsed with "any".
    :param edge_weight: Extra weight added to border pixels.
    :return: float32 weight map of shape (H, W).
    """
    # ``structure=`` is the keyword of scipy.ndimage's binary morphology functions;
    # the skimage.morphology functions of the same name do not accept it.
    from scipy.ndimage import binary_dilation, binary_erosion

    binary_mask = mask > 0
    if binary_mask.ndim > 2:
        # Collapse channel axes so the 2D structuring element applies and the output is (H, W).
        binary_mask = binary_mask.any(axis=tuple(range(2, binary_mask.ndim)))

    weight_mask = np.ones(mask.shape[:2], dtype=np.float32)

    if binary_mask.any():
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        structure = np.ones((5, 5), dtype=bool)
        dilated = binary_dilation(binary_mask, structure=structure)
        eroded = binary_erosion(binary_mask, structure=structure)

        outer_band = dilated & ~binary_mask  # background pixels next to the foreground
        inner_band = binary_mask & ~eroded  # foreground pixels next to the background

        weight_mask = (outer_band | inner_band).astype(np.float32) * edge_weight + 1
        weight_mask = cv2.GaussianBlur(weight_mask, ksize=(5, 5), sigmaX=5)
    return weight_mask
49+
50+
51+
def _block_reduce_dominant_label(x: np.ndarray, axis):
52+
try:
53+
# minlength is +1 to num classes because we must account for IGNORE_LABEL
54+
minlength = np.max(x) + 1
55+
bincount_fn = partial(np.bincount, minlength=minlength)
56+
counts = np.apply_along_axis(bincount_fn, -1, x.reshape((x.shape[0], x.shape[1], -1)))
57+
reduced = np.argmax(counts, axis=-1)
58+
return reduced
59+
except Exception as e:
60+
print(e)
61+
print("shape", x.shape, "axis", axis)
62+
63+
64+
def read_binary_mask(mask_fname: str) -> np.ndarray:
    """
    Read a mask image and binarize it so that every non-zero value becomes 255.

    :param mask_fname: Path to the mask image.
    :return: The thresholded mask (the image is read with ``cv2.IMREAD_COLOR``).
    :raises IOError: If OpenCV cannot read the file.
    """
    mask = cv2.imread(mask_fname, cv2.IMREAD_COLOR)
    # cv2.imread reports failure by returning None rather than raising.
    if mask is None:
        raise IOError("Cannot read " + mask_fname)

    # cv2.threshold returns a (retval, dst) pair; only the thresholded image is the result.
    _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY, dst=mask)
    return binary_mask
67+
68+
69+
class SegmentationDataset(Dataset):
    """
    Map-style dataset producing image / mask samples for segmentation.

    Each item is a dictionary with the sample index, the image id derived from the file
    name, the transformed image tensor and the mask target. Optionally it also contains
    a border weight map and dominant-label masks reduced to strides 2, 4, 8 and 16.
    """

    def __init__(
        self,
        image_filenames: List[str],
        mask_filenames: Optional[List[str]],
        transform: A.Compose,
        read_image_fn: Callable = read_image_rgb,
        read_mask_fn: Callable = cv2.imread,
        need_weight_mask=False,
        need_supervision_masks=False,
        make_mask_target_fn: Callable = mask_to_ce_target,
    ):
        """
        :param image_filenames: Paths of the images.
        :param mask_filenames: One mask path per image, or None for an unlabeled dataset.
        :param transform: Callable invoked as ``transform(image=..., mask=...)``;
            its result must have "image" and "mask" entries.
        :param read_image_fn: Function that loads an image given its path.
        :param read_mask_fn: Function that loads a mask given its path.
        :param need_weight_mask: Add the output of ``compute_weight_mask`` to every sample.
        :param need_supervision_masks: Add reduced-resolution masks to every sample.
        :param make_mask_target_fn: Function that converts a mask into the target value.
        :raises ValueError: If masks are given and their count differs from the number of images.
        """
        has_masks = mask_filenames is not None
        if has_masks and len(mask_filenames) != len(image_filenames):
            raise ValueError("Number of images does not corresponds to number of targets")

        self.image_ids = list(map(fs.id_from_fname, image_filenames))
        self.need_weight_mask = need_weight_mask
        self.need_supervision_masks = need_supervision_masks

        self.images, self.masks = image_filenames, mask_filenames
        self.read_image, self.read_mask = read_image_fn, read_mask_fn

        self.transform = transform
        self.make_target = make_mask_target_fn

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def set_target(self, index: int, value: np.ndarray):
        """Scale ``value`` by 255, cast it to uint8 and write it to the mask file of sample ``index``."""
        destination = self.masks[index]
        scaled = (value * 255).astype(np.uint8)
        cv2.imwrite(destination, scaled)

    def __getitem__(self, index):
        """Load, transform and package the sample stored at ``index``."""
        image = self.read_image(self.images[index])

        if self.masks is None:
            # No mask files: use a placeholder mask filled with UNLABELED_SAMPLE.
            height, width = image.shape[0], image.shape[1]
            mask = np.ones((height, width, 1), dtype=np.uint8) * UNLABELED_SAMPLE
        else:
            mask = self.read_mask(self.masks[index])

        transformed = self.transform(image=image, mask=mask)
        image, mask = transformed["image"], transformed["mask"]

        sample = {
            INPUT_INDEX_KEY: index,
            INPUT_IMAGE_ID_KEY: self.image_ids[index],
            INPUT_IMAGE_KEY: image_to_tensor(image),
            TARGET_MASK_KEY: self.make_target(mask),
        }

        if self.need_weight_mask:
            # Computed from the full-resolution mask, before any reduction below.
            weights = compute_weight_mask(mask)
            sample[TARGET_MASK_WEIGHT_KEY] = image_to_tensor(weights).float()

        if self.need_supervision_masks:
            reduced = mask
            # Each pass halves the previous mask, giving strides 2, 4, 8 and 16.
            for stride in (2, 4, 8, 16):
                reduced = block_reduce(reduced, (2, 2), _block_reduce_dominant_label)
                sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(reduced)

        return sample

0 commit comments

Comments
 (0)