Skip to content

Commit 0f26891

Browse files
committed
Fix loss, comments
1 parent 3417710 commit 0f26891

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

pytorch_toolbelt/modules/backbone/wider_resnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self,
1818
groups=1,
1919
norm_act=ABN,
2020
dropout=None):
21-
"""Configurable identity-mapping residual block
21+
"""Identity-mapping residual block
2222
Parameters
2323
----------
2424
in_channels : int
@@ -97,6 +97,7 @@ def __init__(self,
9797
norm_act=ABN,
9898
classes=0):
9999
"""Wider ResNet with pre-activation (identity mapping) blocks
100+
100101
Parameters
101102
----------
102103
structure : list of int
@@ -168,8 +169,9 @@ def __init__(self,
168169
norm_act=ABN,
169170
classes=0,
170171
dilation=False):
171-
"""Wider ResNet with pre-activation (identity mapping) blocks
172+
"""Wider ResNet with pre-activation (identity mapping) blocks.
172173
This variant uses down-sampling by max-pooling in the first two blocks and by strided convolution in the others.
174+
173175
Parameters
174176
----------
175177
structure : list of int

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[flake8]
22
max-line-length = 179
33
exclude =.git,__pycache__,docs/source/conf.py,build,dist
4-
ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,E127
4+
ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,D413,W504,E127

tests/test_losses.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,13 @@ def test_sigmoid_focal_loss():
1515

1616

1717
def test_reduced_focal_loss():
18-
input_bad = torch.Tensor([-1]).float()
19-
input_good = torch.Tensor([10]).float()
20-
target = torch.Tensor([1])
21-
22-
focal_loss_val = F.sigmoid_focal_loss(input_good, target)
23-
reduced_loss_val = F.reduced_focal_loss(input_good, target)
24-
assert reduced_loss_val < focal_loss_val
18+
input_good = torch.Tensor([10, -10, 10]).float()
19+
input_bad = torch.Tensor([-1, 2, 0]).float()
20+
target = torch.Tensor([1, 0, 1])
2521

26-
focal_loss_val = F.sigmoid_focal_loss(input_bad, target)
27-
reduced_loss_val = F.reduced_focal_loss(input_bad, target)
28-
assert reduced_loss_val == focal_loss_val
22+
loss_good = F.reduced_focal_loss(input_good, target)
23+
loss_bad = F.reduced_focal_loss(input_bad, target)
24+
assert loss_good < loss_bad
2925

3026

3127
def test_soft_jaccard_score():

0 commit comments

Comments
 (0)