Skip to content

Commit 8faed01

Browse files
authored
Merge pull request #73 from BloodAxe/feature/0.5.1
Feature/0.5.1
2 parents bdbdbdc + 399fa45 commit 8faed01

File tree

5 files changed

+78
-32
lines changed

5 files changed

+78
-32
lines changed

pytorch_toolbelt/inference/functional.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,27 @@
77
from ..utils.support import pytorch_toolbelt_deprecated
88

99
__all__ = [
10+
"geometric_mean",
11+
"harmonic_mean",
12+
"logodd_mean",
13+
"pad_image_tensor",
14+
"torch_fliplr",
15+
"torch_flipud",
1016
"torch_none",
17+
"torch_rot180",
18+
"torch_rot270",
1119
"torch_rot90",
12-
"torch_rot90_cw",
1320
"torch_rot90_ccw",
14-
"torch_transpose_rot90_cw",
15-
"torch_transpose_rot90_ccw",
1621
"torch_rot90_ccw_transpose",
22+
"torch_rot90_cw",
1723
"torch_rot90_cw_transpose",
18-
"torch_rot180",
19-
"torch_rot270",
20-
"torch_fliplr",
21-
"torch_flipud",
2224
"torch_transpose",
2325
"torch_transpose2",
2426
"torch_transpose_",
25-
"pad_image_tensor",
27+
"torch_transpose_rot90_ccw",
28+
"torch_transpose_rot90_cw",
2629
"unpad_image_tensor",
2730
"unpad_xyxy_bboxes",
28-
"geometric_mean",
29-
"harmonic_mean",
3031
]
3132

3233

@@ -240,3 +241,23 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
240241
x = torch.mean(x, dim=dim)
241242
x = torch.reciprocal(x.clamp_min(eps))
242243
return x
244+
245+
246+
def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
247+
"""
248+
Compute log-odd mean along given dimension.
249+
logodd = log(p / (1 - p))
250+
251+
This implementation assume values are in range [0, 1] (Probabilities)
252+
Args:
253+
x: Input tensor of arbitrary shape
254+
dim: Dimension to reduce
255+
256+
Returns:
257+
Tensor
258+
"""
259+
x = x.clamp(min=eps, max=1.0 - eps)
260+
x = torch.log(x / (1 - x))
261+
x = torch.mean(x, dim=dim)
262+
x = torch.exp(x) / (1 + torch.exp(x))
263+
return x

pytorch_toolbelt/inference/tta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
7979
x = F.geometric_mean(x, dim=0)
8080
elif reduction in {"hmean", "harmonic_mean"}:
8181
x = F.harmonic_mean(x, dim=0)
82+
elif reduction == "logodd":
83+
x = F.logodd_mean(x, dim=0)
8284
elif callable(reduction):
8385
x = reduction(x, dim=0)
8486
elif reduction in {None, "None", "none"}:

pytorch_toolbelt/optimization/functional.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def get_optimizable_parameters(model: nn.Module) -> Iterator[nn.Parameter]:
4646
return filter(lambda x: x.requires_grad, model.parameters())
4747

4848

49-
def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True):
49+
def freeze_model(
50+
module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True
51+
) -> nn.Module:
5052
"""
5153
Change 'requires_grad' value for module and it's child modules and
5254
optionally freeze batchnorm modules.
@@ -70,3 +72,5 @@ def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, fr
7072
for m in module.modules():
7173
if isinstance(m, bn_types):
7274
module.track_running_stats = not freeze_bn
75+
76+
return module

pytorch_toolbelt/utils/fs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"find_images_in_dir",
1717
"find_in_dir",
1818
"find_in_dir_glob",
19+
"find_subdirectories_in_dir",
1920
"has_ext",
2021
"has_image_ext",
2122
"id_from_fname",
@@ -45,6 +46,19 @@ def find_in_dir(dirname: str) -> List[str]:
4546
return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname))]
4647

4748

49+
def find_subdirectories_in_dir(dirname: str) -> List[str]:
50+
"""
51+
Retrieve list of subdirectories (non-recursive) in the given directory.
52+
Args:
53+
dirname: Target directory name
54+
55+
Returns:
56+
Sorted list of absolute paths to directories
57+
"""
58+
all_entries = find_in_dir(dirname)
59+
return [entry for entry in all_entries if os.path.isdir(entry)]
60+
61+
4862
def find_in_dir_with_ext(dirname: str, extensions: Union[str, List[str]]) -> List[str]:
4963
return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname)) if has_ext(fname, extensions)]
5064

pytorch_toolbelt/utils/visualization.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,31 @@ def plot_confusion_matrix(
2929
title: str = "Confusion matrix",
3030
cmap=None,
3131
fname=None,
32+
show_scores: bool = True,
3233
noshow: bool = False,
3334
backend: str = "Agg",
3435
format_string: Optional[str] = None,
3536
):
36-
"""Render the confusion matrix and return matplotlib's figure with it.
37+
"""
38+
Render the confusion matrix and return matplotlib's figure with it.
3739
Normalization can be applied by setting `normalize=True`.
40+
41+
Args:
42+
cm: Numpy array of (N,N) shape - confusion matrix array
43+
class_names: List of [N] names of the classes
44+
figsize:
45+
fontsize:
46+
normalize: Whether to apply normalization for each row of CM
47+
title: Title of the confusion matrix
48+
cmap:
49+
fname: Filename of the rendered confusion matrix
50+
show_scores: Show scores in each cell
51+
noshow:
52+
backend:
53+
format_string:
54+
55+
Returns:
56+
Matplotlib's figure
3857
"""
3958
import matplotlib
4059

@@ -64,26 +83,12 @@ def plot_confusion_matrix(
6483
if format_string is None:
6584
format_string = ".3f" if normalize else "d"
6685

67-
thresh = (cm.max() + cm.min()) / 2.0
68-
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
69-
if np.isfinite(cm[i, j]):
70-
plt.text(
71-
j,
72-
i,
73-
format(cm[i, j], format_string),
74-
horizontalalignment="center",
75-
fontsize=fontsize,
76-
color="white" if cm[i, j] > thresh else "black",
77-
)
78-
else:
79-
plt.text(
80-
j,
81-
i,
82-
"N/A",
83-
horizontalalignment="center",
84-
fontsize=fontsize,
85-
color="black",
86-
)
86+
if show_scores:
87+
thresh = (cm.max() + cm.min()) / 2.0
88+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
89+
text = format(cm[i, j], format_string) if np.isfinite(cm[i, j]) else "N/A"
90+
color = "white" if cm[i, j] > thresh else "black"
91+
plt.text(j, i, text, horizontalalignment="center", fontsize=fontsize, color=color)
8792

8893
plt.ylabel("True label")
8994

0 commit comments

Comments
 (0)