88
99
1010def focal_loss_with_logits (
11- input : torch .Tensor ,
11+ output : torch .Tensor ,
1212 target : torch .Tensor ,
1313 gamma : float = 2.0 ,
1414 alpha : Optional [float ] = 0.25 ,
@@ -22,7 +22,7 @@ def focal_loss_with_logits(
2222 See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
2323
2424 Args:
25- input : Tensor of arbitrary shape
25+ output : Tensor of arbitrary shape (predictions of the model)
2626 target: Tensor of the same shape as input
2727 gamma: Focal loss power factor
2828 alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range,
@@ -40,9 +40,9 @@ def focal_loss_with_logits(
4040 References:
4141 https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
4242 """
43- target = target .type (input .type ())
43+ target = target .type (output .type ())
4444
45- logpt = F .binary_cross_entropy_with_logits (input , target , reduction = "none" )
45+ logpt = F .binary_cross_entropy_with_logits (output , target , reduction = "none" )
4646 pt = torch .exp (- logpt )
4747
4848 # compute the loss
@@ -76,19 +76,22 @@ def focal_loss_with_logits(
7676
7777
7878# TODO: Mark as deprecated and emit warning
79- def reduced_focal_loss (input : torch .Tensor , target : torch .Tensor , threshold = 0.5 , gamma = 2.0 , reduction = "mean" ):
79+ def reduced_focal_loss (output : torch .Tensor , target : torch .Tensor , threshold = 0.5 , gamma = 2.0 , reduction = "mean" ):
8080 return focal_loss_with_logits (
81- input , target , alpha = None , gamma = gamma , reduction = reduction , reduced_threshold = threshold
81+ output , target , alpha = None , gamma = gamma , reduction = reduction , reduced_threshold = threshold
8282 )
8383
8484
85- def soft_jaccard_score (y_pred : torch .Tensor , y_true : torch .Tensor , smooth = 0.0 , eps = 1e-7 , dims = None ) -> torch .Tensor :
85+ def soft_jaccard_score (
86+ output : torch .Tensor , target : torch .Tensor , smooth : float = 0.0 , eps : float = 1e-7 , dims = None
87+ ) -> torch .Tensor :
8688 """
8789
88- :param y_pred :
89- :param y_true :
90+ :param output :
91+ :param target :
9092 :param smooth:
9193 :param eps:
94+ :param dims:
9295 :return:
9396
9497 Shape:
@@ -98,25 +101,27 @@ def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, e
98101 - Output: scalar.
99102
100103 """
101- assert y_pred .size () == y_true .size ()
104+ assert output .size () == target .size ()
102105
103106 if dims is not None :
104- intersection = torch .sum (y_pred * y_true , dim = dims )
105- cardinality = torch .sum (y_pred + y_true , dim = dims )
107+ intersection = torch .sum (output * target , dim = dims )
108+ cardinality = torch .sum (output + target , dim = dims )
106109 else :
107- intersection = torch .sum (y_pred * y_true )
108- cardinality = torch .sum (y_pred + y_true )
110+ intersection = torch .sum (output * target )
111+ cardinality = torch .sum (output + target )
109112
110113 union = cardinality - intersection
111- jaccard_score = (intersection + smooth ) / (union . clamp_min ( eps ) + smooth )
114+ jaccard_score = (intersection + smooth ) / (union + smooth ). clamp_min ( eps )
112115 return jaccard_score
113116
114117
115- def soft_dice_score (y_pred : torch .Tensor , y_true : torch .Tensor , smooth = 0 , eps = 1e-7 , dims = None ) -> torch .Tensor :
118+ def soft_dice_score (
119+ output : torch .Tensor , target : torch .Tensor , smooth : float = 0.0 , eps : float = 1e-7 , dims = None
120+ ) -> torch .Tensor :
116121 """
117122
118- :param y_pred :
119- :param y_true :
123+ :param output :
124+ :param target :
120125 :param smooth:
121126 :param eps:
122127 :return:
@@ -128,28 +133,28 @@ def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e
128133 - Output: scalar.
129134
130135 """
131- assert y_pred .size () == y_true .size ()
136+ assert output .size () == target .size ()
132137 if dims is not None :
133- intersection = torch .sum (y_pred * y_true , dim = dims )
134- cardinality = torch .sum (y_pred + y_true , dim = dims )
138+ intersection = torch .sum (output * target , dim = dims )
139+ cardinality = torch .sum (output + target , dim = dims )
135140 else :
136- intersection = torch .sum (y_pred * y_true )
137- cardinality = torch .sum (y_pred + y_true )
138- dice_score = (2.0 * intersection + smooth ) / (cardinality . clamp_min ( eps ) + smooth )
141+ intersection = torch .sum (output * target )
142+ cardinality = torch .sum (output + target )
143+ dice_score = (2.0 * intersection + smooth ) / (cardinality + smooth ). clamp_min ( eps )
139144 return dice_score
140145
141146
142- def wing_loss (prediction : torch .Tensor , target : torch .Tensor , width = 5 , curvature = 0.5 , reduction = "mean" ):
147+ def wing_loss (output : torch .Tensor , target : torch .Tensor , width = 5 , curvature = 0.5 , reduction = "mean" ):
143148 """
144149 https://arxiv.org/pdf/1711.06753.pdf
145- :param prediction :
150+ :param output :
146151 :param target:
147152 :param width:
148153 :param curvature:
149154 :param reduction:
150155 :return:
151156 """
152- diff_abs = (target - prediction ).abs ()
157+ diff_abs = (target - output ).abs ()
153158 loss = diff_abs .clone ()
154159
155160 idx_smaller = diff_abs < width
@@ -180,7 +185,7 @@ def label_smoothed_nll_loss(
180185 :param target:
181186 :param epsilon:
182187 :param ignore_index:
183- :param reduce :
188+ :param reduction :
184189 :return:
185190 """
186191 if target .dim () == lprobs .dim () - 1 :
0 commit comments