
Commit 439e3d1

Merge pull request #89 from BloodAxe/develop
0.6.2
2 parents b5abf25 + 4fbc91e commit 439e3d1

File tree

15 files changed (+471 lines, -43 lines)


.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@ jobs:
   strategy:
     matrix:
       operating-system: [ubuntu-latest, windows-latest, macos-latest]
-      python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
+      python-version: ['3.7', '3.8', '3.9', '3.10']
       pytorch-toolbelt-version: [tests]
     fail-fast: false
   steps:
@@ -40,7 +40,7 @@ jobs:
   runs-on: ubuntu-latest
   strategy:
     matrix:
-      python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
+      python-version: ['3.7', '3.8', '3.9', '3.10']
   steps:
     - name: Checkout
       uses: actions/checkout@v2
@@ -53,4 +53,4 @@ jobs:
     - name: Install Black
       run: pip install black==22.8.0
     - name: Run Black
-      run: black --config=black.toml --check .
+      run: black --config=pyproject.toml --check .
black.toml → pyproject.toml: file renamed without changes.

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = "0.6.1"
+__version__ = "0.6.2"

pytorch_toolbelt/inference/functional.py

Lines changed: 38 additions & 1 deletion
@@ -9,7 +9,9 @@
 __all__ = [
     "geometric_mean",
     "harmonic_mean",
+    "harmonic1p_mean",
     "logodd_mean",
+    "log1p_mean",
     "pad_image_tensor",
     "torch_fliplr",
     "torch_flipud",
@@ -229,7 +231,7 @@ def geometric_mean(x: Tensor, dim: int) -> Tensor:
 def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     """
     Compute harmonic mean along given dimension.
-    This implementation assume values are in range (0...1) (Probabilities)
+
     Args:
         x: Input tensor of arbitrary shape
         dim: Dimension to reduce
@@ -243,6 +245,23 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     return x


+def harmonic1p_mean(x: Tensor, dim: int) -> Tensor:
+    """
+    Compute harmonic mean of (x + 1) along given dimension, then subtract 1.
+
+    Args:
+        x: Input tensor of arbitrary shape
+        dim: Dimension to reduce
+
+    Returns:
+        Tensor
+    """
+    x = torch.reciprocal(x + 1)
+    x = torch.mean(x, dim=dim)
+    x = torch.reciprocal(x) - 1
+    return x
+
+
 def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     """
     Compute log-odd mean along given dimension.
@@ -261,3 +280,21 @@ def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
     x = torch.mean(x, dim=dim)
     x = torch.exp(x) / (1 + torch.exp(x))
     return x
+
+
+def log1p_mean(x: Tensor, dim: int) -> Tensor:
+    """
+    Compute the average of log(x + 1) along the given dimension, then map back with exp(.) - 1.
+    Requires all inputs to be non-negative.
+
+    Args:
+        x: Input tensor of arbitrary shape
+        dim: Dimension to reduce
+
+    Returns:
+        Tensor
+    """
+    x = torch.log1p(x)
+    x = torch.mean(x, dim=dim)
+    x = torch.exp(x) - 1
+    return x
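
Taken together, the two new helpers are softened alternatives to plain averaging. A minimal usage sketch (the tensor values are illustrative):

import torch
from pytorch_toolbelt.inference.functional import harmonic1p_mean, log1p_mean

# Three predictions for the same two-class sample, stacked along dim 0
preds = torch.tensor([[0.2, 0.9], [0.3, 0.8], [0.1, 0.7]])

# Both means shift inputs by +1 before reducing, so exact zeros are safe;
# plain harmonic_mean relies on its eps term to avoid division by zero.
print(harmonic1p_mean(preds, dim=0))  # shape (2,); at or below the arithmetic mean
print(log1p_mean(preds, dim=0))       # shape (2,); at or below the arithmetic mean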

pytorch_toolbelt/inference/tta.py

Lines changed: 4 additions & 0 deletions
@@ -79,8 +79,12 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
         x = F.geometric_mean(x, dim=0)
     elif reduction in {"hmean", "harmonic_mean"}:
         x = F.harmonic_mean(x, dim=0)
+    elif reduction in {"harmonic1p"}:
+        x = F.harmonic1p_mean(x, dim=0)
     elif reduction == "logodd":
         x = F.logodd_mean(x, dim=0)
+    elif reduction == "log1p":
+        x = F.log1p_mean(x, dim=0)
     elif callable(reduction):
         x = reduction(x, dim=0)
     elif reduction in {None, "None", "none"}:
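
Both new reduction names plug into the same merging helper used by the TTA wrappers; a short sketch exercising the function from this diff directly:

import torch
from pytorch_toolbelt.inference.tta import _deaugment_averaging

# Four augmented predictions (e.g. flips/rotations) stacked along dim 0
stacked = torch.rand(4, 2, 256, 256)

merged = _deaugment_averaging(stacked, reduction="log1p")       # shape (2, 256, 256)
merged = _deaugment_averaging(stacked, reduction="harmonic1p")  # shape (2, 256, 256)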

pytorch_toolbelt/losses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@
 from .soft_ce import *
 from .soft_f1 import *
 from .wing_loss import *
+from .logcosh import *

pytorch_toolbelt/losses/functional.py

Lines changed: 20 additions & 0 deletions
@@ -12,6 +12,7 @@
     "soft_jaccard_score",
     "soft_dice_score",
     "wing_loss",
+    "log_cosh_loss",
 ]


@@ -298,3 +299,22 @@ def label_smoothed_nll_loss(
     eps_i = epsilon / lprobs.size(dim)
     loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
     return loss
+
+
+def log_cosh_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+    """
+    Numerically stable log-cosh implementation.
+    Reference: https://datascience.stackexchange.com/questions/96271/logcoshloss-on-pytorch
+
+    Args:
+        y_pred: Predicted values
+        y_true: Target values
+
+    Returns:
+        Scalar loss tensor (mean log-cosh of the residuals)
+    """
+
+    def _log_cosh(x: torch.Tensor) -> torch.Tensor:
+        return x + torch.nn.functional.softplus(-2.0 * x) - math.log(2.0)
+
+    return torch.mean(_log_cosh(y_pred - y_true))
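
The softplus identity log(cosh(x)) = x + softplus(-2x) - log(2) is what makes this numerically stable: cosh(x) itself is never materialized. A quick sketch of why that matters:

import torch
from pytorch_toolbelt.losses.functional import log_cosh_loss

y_pred = torch.full((8,), 100.0, requires_grad=True)
y_true = torch.zeros(8)

# A naive torch.log(torch.cosh(x)) overflows to inf in float32 once |x|
# exceeds roughly 89; the softplus form stays finite (about 99.3 here).
loss = log_cosh_loss(y_pred, y_true)
loss.backward()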

pytorch_toolbelt/losses/logcosh.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import torch
+from pytorch_toolbelt.losses.functional import log_cosh_loss
+from torch import nn
+
+__all__ = ["LogCoshLoss"]
+
+
+class LogCoshLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        return log_cosh_loss(y_pred, y_true)
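
The class is a thin nn.Module wrapper around log_cosh_loss, so it drops into a training loop like any other criterion; a minimal sketch with a placeholder regression model:

import torch
from torch import nn
from pytorch_toolbelt.losses import LogCoshLoss

model = nn.Linear(10, 1)  # placeholder regression model
criterion = LogCoshLoss()

x, target = torch.randn(4, 10), torch.randn(4, 1)
loss = criterion(model(x), target)
loss.backward()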

pytorch_toolbelt/modules/encoders/timm/common.py

Lines changed: 4 additions & 3 deletions
@@ -11,14 +11,16 @@
 
 
 class GenericTimmEncoder(EncoderModule):
-    def __init__(self, timm_encoder: Union[nn.Module, str], layers: List[int] = None, pretrained=True):
+    def __init__(self, timm_encoder: Union[nn.Module, str], layers: List[int] = None, pretrained=True, **kwargs):
         strides = []
         channels = []
         default_layers = []
         if isinstance(timm_encoder, str):
             import timm.models.factory
 
-            timm_encoder = timm.models.factory.create_model(timm_encoder, features_only=True, pretrained=pretrained)
+            timm_encoder = timm.models.factory.create_model(
+                timm_encoder, features_only=True, pretrained=pretrained, **kwargs
+            )
 
         for i, fi in enumerate(timm_encoder.feature_info):
             strides.append(fi["reduction"])
@@ -61,7 +63,6 @@ def make_n_channel_input_std_conv(conv: nn.Module, in_channels: int, mode="auto"
             dilation=kwargs.get("dilation", conv.dilation),
             groups=kwargs.get("groups", conv.groups),
             bias=kwargs.get("bias", conv.bias is not None),
-            eps=kwargs.get("eps", conv.eps),
         )
 
         w = conv.weight
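
With **kwargs forwarded, any option that timm's create_model accepts can now be passed through the encoder constructor. A hedged sketch: the model name and the drop_path_rate argument are illustrative, and support for a given kwarg depends on the timm version and architecture:

from pytorch_toolbelt.modules.encoders.timm.common import GenericTimmEncoder

# drop_path_rate is forwarded verbatim to timm.models.factory.create_model;
# many, but not all, timm architectures accept it.
encoder = GenericTimmEncoder("resnet34", pretrained=False, drop_path_rate=0.1)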

pytorch_toolbelt/modules/upsample.py

Lines changed: 14 additions & 7 deletions
@@ -1,3 +1,4 @@
+import warnings
 from typing import Optional, List
 
 import torch
@@ -97,20 +98,26 @@ class BilinearAdditiveUpsample2d(nn.Module):
     https://arxiv.org/abs/1707.05847
     """
 
-    def __init__(self, in_channels: int, scale_factor: int = 2, n: int = 4):
+    def __init__(self, in_channels: int, scale_factor: int = 2, n=None):
         super().__init__()
-        if in_channels % n != 0:
-            raise ValueError(f"Number of input channels ({in_channels})must be divisable by n ({n})")
+        if n is not None:
+            warnings.warn(
+                "Argument n is deprecated and will be removed in a future release. "
+                "It is now computed automatically and does not need to be specified explicitly."
+            )
+
+        self.n = 2**scale_factor
+
+        if in_channels % self.n != 0:
+            raise ValueError(f"Number of input channels ({in_channels}) must be divisible by n ({self.n})")
 
         self.in_channels = in_channels
-        self.out_channels = in_channels // n
+        self.out_channels = in_channels // self.n
         self.upsample = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
-        self.n = n
 
     def forward(self, x: Tensor) -> Tensor:  # skipcq: PYL-W0221
         x = self.upsample(x)
         n, c, h, w = x.size()
-        x = x.reshape(n, c // self.n, self.n, h, w).mean(2)
+        x = x.reshape(n, self.out_channels, self.n, h, w).mean(2)
         return x
@@ -135,7 +142,7 @@ def __init__(self, in_channels, scale_factor=2, n=4):
         self.conv = nn.ConvTranspose2d(
             in_channels, in_channels // n, kernel_size=3, padding=1, stride=scale_factor, output_padding=1
         )
-        self.residual = BilinearAdditiveUpsample2d(in_channels, scale_factor=scale_factor, n=n)
+        self.residual = BilinearAdditiveUpsample2d(in_channels, scale_factor=scale_factor)
         self.init_weights()
 
     def forward(self, x: Tensor) -> Tensor:  # skipcq: PYL-W0221
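
Since n is now derived as 2**scale_factor, the module is constructed without it; the reshape in forward then folds each group of n neighboring channels into one output channel:

import torch
from pytorch_toolbelt.modules.upsample import BilinearAdditiveUpsample2d

# scale_factor=2 implies n = 2**2 = 4, so in_channels must be divisible by 4
up = BilinearAdditiveUpsample2d(in_channels=16, scale_factor=2)

x = torch.randn(1, 16, 32, 32)
y = up(x)
print(y.shape)  # torch.Size([1, 4, 64, 64]): spatial size x2, channels /4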
