Commit 46f189f

Merge pull request #76 from BloodAxe/develop
0.5.2 Release
2 parents: ee72463 + f1114e8

File tree: 4 files changed, +59 −13 lines


pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = "0.5.1"
+__version__ = "0.5.2"

pytorch_toolbelt/inference/ensembling.py

Lines changed: 10 additions & 6 deletions
@@ -13,18 +13,22 @@ class ApplySoftmaxTo(nn.Module):
     dim: int

     def __init__(
-        self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", dim: int = 1, temperature: float = 1
+        self,
+        model: nn.Module,
+        output_key: Union[str, int, Iterable[str]] = "logits",
+        dim: int = 1,
+        temperature: float = 1,
     ):
         """
         Apply softmax activation on given output(s) of the model
         :param model: Model to wrap
-        :param output_key: string or list of strings, indicating to what outputs softmax activation should be applied.
+        :param output_key: string, index or list of strings, indicating to what outputs softmax activation should be applied.
         :param dim: Tensor dimension for softmax activation
         :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
         """
         super().__init__()
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
-        output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
+        output_key = (output_key,) if isinstance(output_key, (str, int)) else tuple(set(output_key))
         self.output_keys = output_key
         self.model = model
         self.dim = dim
@@ -41,16 +45,16 @@ class ApplySigmoidTo(nn.Module):
     output_keys: Tuple
     temperature: float

-    def __init__(self, model: nn.Module, output_key: Union[str, Iterable[str]] = "logits", temperature=1):
+    def __init__(self, model: nn.Module, output_key: Union[str, int, Iterable[str]] = "logits", temperature=1):
         """
         Apply sigmoid activation on given output(s) of the model
         :param model: Model to wrap
-        :param output_key: string or list of strings, indicating to what outputs sigmoid activation should be applied.
+        :param output_key: string, index or list of strings, indicating to what outputs sigmoid activation should be applied.
         :param temperature: Temperature scaling coefficient. Values > 1 will make logits sharper.
         """
         super().__init__()
         # By converting to set, we prevent double-activation by passing output_key=["logits", "logits"]
-        output_key = tuple(set(output_key)) if isinstance(output_key, Iterable) else tuple([output_key])
+        output_key = (output_key,) if isinstance(output_key, (str, int)) else tuple(set(output_key))
         self.output_keys = output_key
         self.model = model
         self.temperature = temperature
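
For illustration, a minimal sketch of the wrapper after this change, assuming a toy model (the model, shapes, and names below are illustrative, not part of the commit) that returns a dict of raw logits. The widened output_key annotation (Union[str, int, Iterable[str]]) also admits an integer key for models whose outputs are addressed by position.

import torch
from torch import nn
from pytorch_toolbelt.inference.ensembling import ApplySoftmaxTo


class ToyModel(nn.Module):
    """Illustrative model returning a dict with raw class logits."""

    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(3, 4, kernel_size=1)

    def forward(self, x):
        return {"logits": self.head(x)}


# Wrap the model so the "logits" entry comes back as softmax probabilities.
wrapped = ApplySoftmaxTo(ToyModel(), output_key="logits", dim=1, temperature=2)
with torch.no_grad():
    out = wrapped(torch.randn(2, 3, 8, 8))
print(out["logits"].sum(dim=1))  # per-pixel sums are ~1.0 after softmax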

pytorch_toolbelt/utils/fs.py

Lines changed: 18 additions & 4 deletions
@@ -14,8 +14,10 @@
     "auto_file",
     "change_extension",
     "find_images_in_dir",
+    "find_images_in_dir_recursive",
     "find_in_dir",
     "find_in_dir_glob",
+    "find_in_dir_with_ext",
     "find_subdirectories_in_dir",
     "has_ext",
     "has_image_ext",
@@ -64,7 +66,15 @@ def find_in_dir_with_ext(dirname: str, extensions: Union[str, List[str]]) -> List[str]:


 def find_images_in_dir(dirname: str) -> List[str]:
-    return [fname for fname in find_in_dir(dirname) if has_image_ext(fname)]
+    return [fname for fname in find_in_dir(dirname) if has_image_ext(fname) and os.path.isfile(fname)]
+
+
+def find_images_in_dir_recursive(dirname: str) -> List[str]:
+    return [
+        fname
+        for fname in glob.glob(os.path.join(dirname, "**"), recursive=True)
+        if has_image_ext(fname) and os.path.isfile(fname)
+    ]


 def find_in_dir_glob(dirname: str, recursive=False):
@@ -76,13 +86,17 @@ def id_from_fname(fname: str) -> str:
     return os.path.splitext(os.path.basename(fname))[0]


-def change_extension(fname: Union[str, Path], new_ext: str) -> str:
-    if type(fname) == str:
+def change_extension(fname: Union[str, Path], new_ext: str) -> Union[str, Path]:
+    if isinstance(fname, str):
         return os.path.splitext(fname)[0] + new_ext
-    else:
+    elif isinstance(fname, Path):
         if new_ext[0] != ".":
             new_ext = "." + new_ext
         return fname.with_suffix(new_ext)
+    else:
+        raise RuntimeError(
+            f"Received input argument `fname` for unsupported type {type(fname)}. Argument must be string or Path."
+        )


 def auto_file(filename: str, where: str = ".") -> str:
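
A short sketch of the new and updated fs helpers, assuming an illustrative directory layout (the data/train path and file names are made up):

from pathlib import Path

from pytorch_toolbelt.utils.fs import change_extension, find_images_in_dir_recursive

# Recursively collect image files under a dataset root; directories are
# skipped thanks to the os.path.isfile check added in this commit.
images = find_images_in_dir_recursive("data/train")

# change_extension now preserves the input type and rejects anything else.
# Note: the str branch expects the extension with a leading dot.
as_str = change_extension("data/train/0001.jpg", ".png")        # 'data/train/0001.png'
as_path = change_extension(Path("data/train/0001.jpg"), "png")   # Path('data/train/0001.png')
# change_extension(42, ".png") would raise RuntimeError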

pytorch_toolbelt/utils/torch_utils.py

Lines changed: 30 additions & 2 deletions
@@ -34,6 +34,7 @@
     "to_tensor",
     "transfer_weights",
     "move_to_device_non_blocking",
+    "describe_outputs",
 ]


@@ -102,8 +103,11 @@ def count_parameters(
     parameters = {"total": total, "trainable": trainable}

     for key in keys:
-        if hasattr(model, key) and model.__getattr__(key) is not None:
-            parameters[key] = int(sum(p.numel() for p in model.__getattr__(key).parameters()))
+        try:
+            if hasattr(model, key) and model.__getattr__(key) is not None:
+                parameters[key] = int(sum(p.numel() for p in model.__getattr__(key).parameters()))
+        except AttributeError:
+            pass

     if human_friendly:
         for key in parameters.keys():
@@ -289,3 +293,27 @@ def move_to_device_non_blocking(x: Tensor, device: torch.device) -> Tensor:


 resize_as = resize_like
+
+
+def describe_outputs(outputs: Union[Tensor, Dict[str, Tensor], Iterable[Tensor]]) -> Union[List[Dict], Dict[str, Any]]:
+    """
+    Describe outputs and return shape, mean & std for each tensor in list or dict (Supports nested tensors)
+
+    Args:
+        outputs: Input (Usually model outputs)
+    Returns:
+        Same structure but each item represents tensor shape, mean & std
+    """
+    if torch.is_tensor(outputs):
+        desc = dict(size=tuple(outputs.size()), mean=outputs.mean().item(), std=outputs.std().item())
+    elif isinstance(outputs, collections.Mapping):
+        desc = {}
+        for key, value in outputs.items():
+            desc[key] = describe_outputs(value)
+    elif isinstance(outputs, collections.Iterable):
+        desc = []
+        for index, output in enumerate(outputs):
+            desc.append(describe_outputs(output))
+    else:
+        raise NotImplementedError
+    return desc
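
A small sketch of the new describe_outputs helper on nested model outputs (the tensors and keys below are illustrative):

import torch

from pytorch_toolbelt.utils.torch_utils import describe_outputs

outputs = {
    "logits": torch.randn(4, 10),
    "features": [torch.randn(4, 64, 32, 32), torch.randn(4, 128, 16, 16)],
}

# Returns the same nesting, with each tensor replaced by its size/mean/std summary, e.g.
# {'logits': {'size': (4, 10), 'mean': ..., 'std': ...},
#  'features': [{'size': (4, 64, 32, 32), ...}, {'size': (4, 128, 16, 16), ...}]}
print(describe_outputs(outputs))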
