
Commit f932d16

Merge pull request #46 from BloodAxe/feature/new-catalyst
Feature/new catalyst
2 parents 60d6157 + 1fd36d3 commit f932d16

14 files changed: +216 -231 lines changed


pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from __future__ import absolute_import

-__version__ = "0.3.3"
+__version__ = "0.4.0"

pytorch_toolbelt/inference/ensembling.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def forward(self, *input, **kwargs):  # skipcq: PYL-W0221

         if self.average:
             for key in keys:
-                output_0[key].mul_(1. / num_models)
+                output_0[key].mul_(1.0 / num_models)

         return output_0
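The change itself is cosmetic (1. becomes 1.0); the surrounding code averages keyed model outputs in place. A standalone sketch of that averaging pattern, with illustrative names rather than the module's actual implementation:

    from typing import Dict, List

    import torch


    def average_keyed_outputs(outputs: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Accumulate each key across models, then scale in place by 1 / num_models,
        # mirroring the mul_(1.0 / num_models) step above.
        num_models = len(outputs)
        averaged = {key: value.clone() for key, value in outputs[0].items()}
        for other in outputs[1:]:
            for key in averaged:
                averaged[key].add_(other[key])
        for key in averaged:
            averaged[key].mul_(1.0 / num_models)
        return averaged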

pytorch_toolbelt/losses/dice.py

Lines changed: 2 additions & 2 deletions
@@ -82,12 +82,12 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         y_true = y_true.view(bs, num_classes, -1)
         y_pred = y_pred.view(bs, num_classes, -1)

-        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), self.smooth, self.eps, dims=dims)
+        scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

         if self.log_loss:
             loss = -torch.log(scores.clamp_min(self.eps))
         else:
-            loss = 1 - scores
+            loss = 1.0 - scores

         # Dice loss is undefined for non-empty classes
         # So we zero contribution of channel that does not have true pixels
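Switching to keyword arguments matters because this commit also renames and retypes the soft_dice_score parameters in functional.py (see below); keywords keep smooth and eps from being bound to the wrong positions. A hedged usage sketch against the signature shown in this diff, with shapes and dims chosen purely for illustration:

    import torch

    from pytorch_toolbelt.losses.functional import soft_dice_score

    y_pred = torch.rand(4, 3, 64 * 64)                   # probabilities, shaped (bs, num_classes, -1)
    y_true = (torch.rand(4, 3, 64 * 64) > 0.5).float()

    # Keyword arguments keep the call unambiguous against the new parameter order.
    scores = soft_dice_score(y_pred, y_true, smooth=0.0, eps=1e-7, dims=(0, 2))
    loss = 1.0 - scores                                   # one value per class when dims=(0, 2)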

pytorch_toolbelt/losses/functional.py

Lines changed: 33 additions & 28 deletions
@@ -8,7 +8,7 @@


 def focal_loss_with_logits(
-    input: torch.Tensor,
+    output: torch.Tensor,
     target: torch.Tensor,
     gamma: float = 2.0,
     alpha: Optional[float] = 0.25,
@@ -22,7 +22,7 @@ def focal_loss_with_logits(
     See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

     Args:
-        input: Tensor of arbitrary shape
+        output: Tensor of arbitrary shape (predictions of the model)
         target: Tensor of the same shape as input
         gamma: Focal loss power factor
         alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range,
@@ -40,9 +40,9 @@ def focal_loss_with_logits(
     References:
         https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
     """
-    target = target.type(input.type())
+    target = target.type(output.type())

-    logpt = F.binary_cross_entropy_with_logits(input, target, reduction="none")
+    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
     pt = torch.exp(-logpt)

     # compute the loss
@@ -76,19 +76,22 @@ def focal_loss_with_logits(


 # TODO: Mark as deprecated and emit warning
-def reduced_focal_loss(input: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"):
+def reduced_focal_loss(output: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"):
     return focal_loss_with_logits(
-        input, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold
+        output, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold
     )


-def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
+def soft_jaccard_score(
+    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
+) -> torch.Tensor:
     """

-    :param y_pred:
-    :param y_true:
+    :param output:
+    :param target:
     :param smooth:
     :param eps:
+    :param dims:
     :return:

     Shape:
@@ -98,25 +101,27 @@ def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
         - Output: scalar.

     """
-    assert y_pred.size() == y_true.size()
+    assert output.size() == target.size()

     if dims is not None:
-        intersection = torch.sum(y_pred * y_true, dim=dims)
-        cardinality = torch.sum(y_pred + y_true, dim=dims)
+        intersection = torch.sum(output * target, dim=dims)
+        cardinality = torch.sum(output + target, dim=dims)
     else:
-        intersection = torch.sum(y_pred * y_true)
-        cardinality = torch.sum(y_pred + y_true)
+        intersection = torch.sum(output * target)
+        cardinality = torch.sum(output + target)

     union = cardinality - intersection
-    jaccard_score = (intersection + smooth) / (union.clamp_min(eps) + smooth)
+    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
     return jaccard_score


-def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e-7, dims=None) -> torch.Tensor:
+def soft_dice_score(
+    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
+) -> torch.Tensor:
     """

-    :param y_pred:
-    :param y_true:
+    :param output:
+    :param target:
     :param smooth:
     :param eps:
     :return:
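Beyond the renames, the denominator in soft_jaccard_score changed from union.clamp_min(eps) + smooth to (union + smooth).clamp_min(eps), so eps only takes effect when both union and smooth are near zero; the same reordering is applied to soft_dice_score in the next hunk. A small worked example, values chosen purely for illustration:

    import torch

    eps, smooth = 1e-7, 1.0
    intersection = torch.tensor(0.0)   # empty prediction ...
    union = torch.tensor(0.0)          # ... and empty target

    old_score = (intersection + smooth) / (union.clamp_min(eps) + smooth)   # ~0.9999999
    new_score = (intersection + smooth) / (union + smooth).clamp_min(eps)   # exactly 1.0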
@@ -128,28 +133,28 @@ def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e-7, dims=None) -> torch.Tensor:
         - Output: scalar.

     """
-    assert y_pred.size() == y_true.size()
+    assert output.size() == target.size()
     if dims is not None:
-        intersection = torch.sum(y_pred * y_true, dim=dims)
-        cardinality = torch.sum(y_pred + y_true, dim=dims)
+        intersection = torch.sum(output * target, dim=dims)
+        cardinality = torch.sum(output + target, dim=dims)
     else:
-        intersection = torch.sum(y_pred * y_true)
-        cardinality = torch.sum(y_pred + y_true)
-    dice_score = (2.0 * intersection + smooth) / (cardinality.clamp_min(eps) + smooth)
+        intersection = torch.sum(output * target)
+        cardinality = torch.sum(output + target)
+    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
     return dice_score


-def wing_loss(prediction: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
+def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
     """
     https://arxiv.org/pdf/1711.06753.pdf
-    :param prediction:
+    :param output:
     :param target:
     :param width:
     :param curvature:
     :param reduction:
     :return:
     """
-    diff_abs = (target - prediction).abs()
+    diff_abs = (target - output).abs()
     loss = diff_abs.clone()

     idx_smaller = diff_abs < width
@@ -180,7 +185,7 @@ def label_smoothed_nll_loss(
     :param target:
     :param epsilon:
     :param ignore_index:
-    :param reduce:
+    :param reduction:
     :return:
     """
     if target.dim() == lprobs.dim() - 1:
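The recurring change in this file is renaming the first parameter (input, prediction, y_pred) to output. A hedged usage sketch of focal_loss_with_logits based only on the signature visible above, with shapes chosen for illustration:

    import torch

    from pytorch_toolbelt.losses.functional import focal_loss_with_logits

    logits = torch.randn(8, 1, 64, 64)                    # raw model outputs, the renamed `output` argument
    targets = (torch.rand(8, 1, 64, 64) > 0.5).float()    # same shape as the logits

    loss = focal_loss_with_logits(logits, targets, gamma=2.0, alpha=0.25)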

pytorch_toolbelt/losses/jaccard.py

Lines changed: 2 additions & 2 deletions
@@ -82,12 +82,12 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
         y_true = y_true.view(bs, num_classes, -1)
         y_pred = y_pred.view(bs, num_classes, -1)

-        scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims)
+        scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), smooth=self.smooth, eps=self.eps, dims=dims)

         if self.log_loss:
             loss = -torch.log(scores.clamp_min(self.eps))
         else:
-            loss = 1 - scores
+            loss = 1.0 - scores

         # IoU loss is defined for non-empty classes
         # So we zero contribution of channel that does not have true pixels
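For reference, the Jaccard and Dice scores computed by these two losses are related by d = 2j / (1 + j). A quick sanity check against the signatures shown in this commit, with dims chosen for illustration:

    import torch

    from pytorch_toolbelt.losses.functional import soft_dice_score, soft_jaccard_score

    pred = torch.rand(2, 3, 1024)
    true = (torch.rand(2, 3, 1024) > 0.5).float()

    j = soft_jaccard_score(pred, true, smooth=0.0, eps=1e-7, dims=(0, 2))
    d = soft_dice_score(pred, true, smooth=0.0, eps=1e-7, dims=(0, 2))
    # With smooth=0 (and the eps clamp inactive), d == 2 * j / (1 + j).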

pytorch_toolbelt/modules/activations.py

Lines changed: 2 additions & 1 deletion
@@ -129,7 +129,8 @@ def mish_jit_bwd(x, grad_output):


 class MishFunction(torch.autograd.Function):
-    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     A memory efficient, jit scripted variant of Mish
     """
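The docstring reformatted here describes a memory-efficient autograd function for Mish. A minimal sketch of that idea, assuming only the published definition mish(x) = x * tanh(softplus(x)); this is not the toolbox's jit-scripted implementation:

    import torch
    import torch.nn.functional as F


    class MishSketch(torch.autograd.Function):
        """mish(x) = x * tanh(softplus(x)); intermediates are recomputed in backward to save memory."""

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * torch.tanh(F.softplus(x))

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            tanh_sp = torch.tanh(F.softplus(x))
            grad = tanh_sp + x * torch.sigmoid(x) * (1.0 - tanh_sp * tanh_sp)
            return grad_output * grad


    # y = MishSketch.apply(torch.randn(4, requires_grad=True))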

pytorch_toolbelt/utils/catalyst/criterions.py

Lines changed: 19 additions & 21 deletions
@@ -1,8 +1,7 @@
 import math

 import torch
-from catalyst.dl import CriterionCallback, RunnerState
-from catalyst.dl.callbacks.criterion import _add_loss_to_state
+from catalyst.dl import IRunner, CriterionCallback
 from torch import nn
 from torch.nn import functional as F

@@ -67,22 +66,22 @@ def __init__(
         self.p = p
         self.multiplier = None

-    def on_loader_start(self, state: RunnerState):
-        self.is_needed = not self.on_train_only or state.loader_name.startswith("train")
+    def on_loader_start(self, runner: IRunner):
+        self.is_needed = not self.on_train_only or runner.loader_name.startswith("train")
         if self.is_needed:
-            state.metrics.epoch_values[state.loader_name][f"l{self.p}_weight_decay"] = self.multiplier
+            runner.metrics.epoch_values[runner.loader_name][f"l{self.p}_weight_decay"] = self.multiplier

-    def on_epoch_start(self, state: RunnerState):
-        training_progress = float(state.epoch) / float(state.num_epochs)
+    def on_epoch_start(self, runner: IRunner):
+        training_progress = float(runner.epoch) / float(runner.num_epochs)
         self.multiplier = get_multiplier(training_progress, self.schedule, self.start_wd, self.end_wd)

-    def on_batch_end(self, state: RunnerState):
+    def on_batch_end(self, runner: IRunner):
         if not self.is_needed:
             return

         lp_reg = 0

-        for module in state.model.children():
+        for module in runner.model.children():
             if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm1d, nn.BatchNorm3d)):
                 continue

@@ -93,8 +92,7 @@ def on_batch_end(self, state: RunnerState):
             if param.requires_grad:
                 lp_reg = param.norm(self.p) * self.multiplier + lp_reg

-        state.metrics.add_batch_value(metrics_dict={self.prefix: lp_reg.item()})
-        _add_loss_to_state(self.prefix, state, lp_reg)
+        runner.batch_metrics.update(**{self.prefix: lp_reg.item()})


 class TSACriterionCallback(CriterionCallback):
@@ -117,7 +115,7 @@ def __init__(
         prefix: str = "loss",
         criterion_key: str = None,
         multiplier: float = 1.0,
-        unsupervised_label=-100,
+        ignore_index=-100,
     ):
         super().__init__(
             input_key=input_key,
@@ -129,7 +127,7 @@ def __init__(
         self.num_epochs = num_epochs
         self.num_classes = num_classes
         self.tsa_threshold = None
-        self.unsupervised_label = unsupervised_label
+        self.ignore_index = ignore_index

     def get_tsa_threshold(self, current_epoch, schedule, start, end) -> float:
         training_progress = float(current_epoch) / float(self.num_epochs)
@@ -148,16 +146,16 @@ def get_tsa_threshold(self, current_epoch, schedule, start, end) -> float:
             raise KeyError(f"Unsupported schedule name {schedule}")
         return threshold * (end - start) + start

-    def on_epoch_start(self, state: RunnerState):
-        if state.loader_name == "train":
-            self.tsa_threshold = self.get_tsa_threshold(state.epoch, "exp_schedule", 1.0 / self.num_classes, 1.0)
-            state.metrics.epoch_values["train"]["tsa_threshold"] = self.tsa_threshold
+    def on_epoch_start(self, runner: IRunner):
+        if runner.loader_name == "train":
+            self.tsa_threshold = self.get_tsa_threshold(runner.epoch, "exp_schedule", 1.0 / self.num_classes, 1.0)
+            runner.epoch_metrics["train"]["tsa_threshold"] = self.tsa_threshold

-    def _compute_loss(self, state: RunnerState, criterion):
+    def _compute_loss(self, runner: IRunner, criterion):

-        logits = state.output[self.output_key]
-        targets = state.input[self.input_key]
-        supervised_mask = targets != self.unsupervised_label  # Mask indicating labeled samples
+        logits = runner.output[self.output_key]
+        targets = runner.input[self.input_key]
+        supervised_mask = targets != self.ignore_index  # Mask indicating labeled samples

         targets = targets[supervised_mask]
         logits = logits[supervised_mask]
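The pattern in this file is the Catalyst API migration: callbacks now receive an IRunner and report metrics through runner.batch_metrics and runner.epoch_metrics instead of RunnerState. A hypothetical callback following the same hooks (illustrative, not part of this commit; it assumes Callback and CallbackOrder are importable from catalyst.dl, which this diff does not show):

    # Assumption: Callback / CallbackOrder exports and constructor, not shown in this diff.
    from catalyst.dl import Callback, CallbackOrder, IRunner


    class WeightNormLogger(Callback):
        def __init__(self, prefix: str = "weight_norm"):
            super().__init__(order=CallbackOrder.Metric)
            self.prefix = prefix

        def on_loader_start(self, runner: IRunner):
            # Same loader-name check as the weight-decay callback above.
            self.is_needed = runner.loader_name.startswith("train")

        def on_batch_end(self, runner: IRunner):
            if not self.is_needed:
                return
            total = sum(p.norm(2) for p in runner.model.parameters() if p.requires_grad)
            runner.batch_metrics.update(**{self.prefix: float(total)})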

pytorch_toolbelt/utils/catalyst/loss_adapter.py

Lines changed: 7 additions & 51 deletions
@@ -1,21 +1,15 @@
+from typing import Dict
+
 import torch
 from catalyst.dl import (
     CriterionCallback,
-    RunnerState,
-    OptimizerCallback,
-    CheckpointCallback,
-    SchedulerCallback,
-    SupervisedExperiment,
-    Callback,
+    IRunner,
 )
-from catalyst.dl.callbacks import VerboseLogger, ConsoleLogger, TensorboardLogger, RaiseExceptionCallback
 from torch import nn, Tensor
-from typing import Dict

 __all__ = [
     "TrainOnlyCriterionCallback",
     "PassthroughCriterionCallback",
-    "ParallelLossSupervisedExperiment",
     "LossModule",
     "LossWrapper",
 ]
@@ -26,7 +20,7 @@ class TrainOnlyCriterionCallback(CriterionCallback):
     Computes loss only on training stage
     """

-    def _compute_loss_value(self, state: RunnerState, criterion):
+    def _compute_loss_value(self, state: IRunner, criterion):
         predictions = self._get_output(state.output, self.output_key)
         targets = self._get_input(state.input, self.input_key)

@@ -36,7 +30,7 @@ def _compute_loss_value(self, state: RunnerState, criterion):
         loss = criterion(predictions, targets)
         return loss

-    def _compute_loss_key_value(self, state: RunnerState, criterion):
+    def _compute_loss_key_value(self, state: IRunner, criterion):
         output = self._get_output(state.output, self.output_key)
         input = self._get_input(state.input, self.input_key)

@@ -47,44 +41,6 @@ def _compute_loss_key_value(self, state: RunnerState, criterion):
         return loss


-class ParallelLossSupervisedExperiment(SupervisedExperiment):
-    """
-    Custom experiment class. To use in conjunction with LossWrapper.
-    """
-
-    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
-        """
-        Override of ``BaseExperiment.get_callbacks`` method.
-        Will add several of callbacks by default in case they missed.
-
-        Args:
-            stage (str): name of stage. It should start with `infer` if you
-                don't need default callbacks, as they required only for
-                training stages.
-        Returns:
-            List[Callback]: list of callbacks for experiment
-        """
-        callbacks = self._callbacks
-        default_callbacks = []
-        if self._verbose:
-            default_callbacks.append(("verbose", VerboseLogger))
-        if not stage.startswith("infer"):
-            # default_callbacks.append(("_criterion", CriterionCallback)) # Commented
-            default_callbacks.append(("_optimizer", OptimizerCallback))
-            if self._scheduler is not None:
-                default_callbacks.append(("_scheduler", SchedulerCallback))
-            default_callbacks.append(("_saver", CheckpointCallback))
-        default_callbacks.append(("console", ConsoleLogger))
-        default_callbacks.append(("tensorboard", TensorboardLogger))
-        default_callbacks.append(("exception", RaiseExceptionCallback))
-
-        for callback_name, callback_fn in default_callbacks:
-            is_already_present = any(isinstance(x, callback_fn) for x in callbacks.values())
-            if not is_already_present:
-                callbacks[callback_name] = callback_fn()
-        return callbacks
-
-
 class PassthroughCriterionCallback(CriterionCallback):
     """
     Returns one of model's outputs as loss values
@@ -93,11 +49,11 @@ class PassthroughCriterionCallback(CriterionCallback):
     def __init__(self, output_key, multiplier=1.0):
         super().__init__(output_key=output_key, prefix=output_key, multiplier=multiplier)

-    def _compute_loss_value(self, state: RunnerState, criterion):
+    def _compute_loss_value(self, state: IRunner, criterion):
         loss = self._get_output(state.output, self.output_key)
         return loss.mean()

-    def _compute_loss_key_value(self, state: RunnerState, criterion):
+    def _compute_loss_key_value(self, state: IRunner, criterion):
         loss = self._get_output(state.output, self.output_key)
         return loss.mean()
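PassthroughCriterionCallback treats one of the model's outputs as the loss and simply averages it. A minimal usage sketch, with the import path taken from this file's location and the constructor arguments visible in the diff:

    from pytorch_toolbelt.utils.catalyst.loss_adapter import PassthroughCriterionCallback

    # The model is expected to expose its internally computed loss under an output key;
    # the callback then averages that tensor (see _compute_loss_value above).
    callback = PassthroughCriterionCallback(output_key="loss")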
