Commit 6b9f4ec
Merge #1427, authored by bors[bot] with sambitdash and gxyd
2 parents a7e055b + 014af2f

1427: (Complete) Implementation of label smoothing with crossentropy
r=CarloLucibello a=gxyd

Trying to complete the PR #1025 from @sambitdash (thanks, Sambit). Closes #1016.

A few changes compared to the original code:
- Throwing an error when `label_smoothing` isn't between 0 and 1.
- Label smoothing is applied as a dispatch.

Co-authored-by: Sambit Kumar Dash <sambitdash@gmail.com>
Co-authored-by: Gaurav Dhingra <gauravdhingra.gxyd@gmail.com>
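
For orientation, a minimal sketch of the behavior described above: smoothing is a standalone call applied to the labels rather than a keyword of the losses, and out-of-range smoothing factors are rejected. The values follow from the formulas in the diff below; variable choices are illustrative, not from the commit.

    using Flux.Losses: label_smoothing

    label_smoothing([0.0, 1.0, 0.0], 0.3)        # 3 classes: entries 0.1, 0.8, 0.1
    label_smoothing(1.0, 0.2; dims=0)            # 0.9: a plain number is a binary label
    label_smoothing([0.0, 1.0, 0.0], 1.2)        # throws ArgumentError (α outside (0, 1))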

File tree: 4 files changed (+111 -10 lines)

docs/src/models/losses.md — 1 addition, 0 deletions

@@ -28,6 +28,7 @@ Flux.Losses.mae
 Flux.Losses.mse
 Flux.Losses.msle
 Flux.Losses.huber_loss
+Flux.Losses.label_smoothing
 Flux.Losses.crossentropy
 Flux.Losses.logitcrossentropy
 Flux.Losses.binarycrossentropy

src/losses/Losses.jl — 1 addition, 0 deletions

@@ -9,6 +9,7 @@ using NNlib: logsoftmax, logσ
 import Base.Broadcast: broadcasted
 
 export mse, mae, msle,
+       label_smoothing,
        crossentropy, logitcrossentropy,
        # binarycrossentropy, logitbinarycrossentropy # export only after end deprecation
        kldivergence,

src/losses/functions.jl — 70 additions, 9 deletions

@@ -46,6 +46,53 @@ function huber_loss(ŷ, y; agg=mean, δ=ofeltype(ŷ, 1))
     agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
 end
 
+"""
+    label_smoothing(y::Union{Number, AbstractArray}, α; dims::Int=1)
+
+Returns smoothed labels, meaning the confidence on label values is relaxed.
+
+When `y` is given as a one-hot vector or a batch of one-hot vectors, it is calculated as
+
+    y .* (1 - α) .+ α / size(y, dims)
+
+When `y` is given as a number or a batch of numbers for binary classification,
+it is calculated as
+
+    y .* (1 - α) .+ α / 2
+
+in which case the labels are squeezed towards `0.5`.
+
+α is a number in the interval (0, 1) called the smoothing factor. The higher
+the value of α, the larger the smoothing of `y`.
+
+`dims` denotes the one-hot dimension, unless `dims=0`, which denotes the application
+of label smoothing to binary distributions encoded in a single number.
+
+Usage example:
+
+    sf = 0.1
+    y = onehotbatch([1, 1, 1, 0, 0], 0:1)
+    y_smoothed = label_smoothing(y, 2sf)
+    y_sim = y .* (1-2sf) .+ sf
+    y_dis = copy(y_sim)
+    y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
+    @assert crossentropy(y_sim, y) < crossentropy(y_sim, y_smoothed)
+    @assert crossentropy(y_dis, y) > crossentropy(y_dis, y_smoothed)
+"""
+function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int=1)
+    if !(0 < α < 1)
+        throw(ArgumentError("α must be between 0 and 1"))
+    end
+    if dims == 0
+        y_smoothed = y .* (1 - α) .+ α * 1//2
+    elseif dims == 1
+        y_smoothed = y .* (1 - α) .+ α * 1 // size(y, 1)
+    else
+        throw(ArgumentError("`dims` should be either 0 or 1"))
+    end
+    return y_smoothed
+end
+
 
 """
     crossentropy(ŷ, y; dims=1, ϵ=eps(ŷ), agg=mean)
 
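
To make the one-hot formula above concrete, a minimal sketch against the API this hunk adds (the printed values follow from α = 0.2 and two classes; variable names are illustrative):

    using Flux
    using Flux.Losses: label_smoothing

    # One-hot targets for classes 0:1, stored as 2×4 columns.
    y = Flux.onehotbatch([1, 1, 0, 0], 0:1)

    # Each entry becomes 1*(1 - 0.2) + 0.2/2 = 0.9 or 0*(1 - 0.2) + 0.2/2 = 0.1.
    label_smoothing(y, 0.2)            # 2×4 matrix of 0.9s and 0.1s

    # dims=0 treats each number as a binary label, squeezing it towards 0.5.
    label_smoothing(1.0, 0.2; dims=0)  # 0.9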
@@ -54,16 +101,20 @@ calculated as
 
     agg(-sum(y .* log.(ŷ .+ ϵ); dims=dims))
 
-Cross entropy is tipically used as a loss in multi-class classification,
+Cross entropy is typically used as a loss in multi-class classification,
 in which case the labels `y` are given in a one-hot format.
 `dims` specifies the dimension (or the dimensions) containing the class probabilities.
 The prediction `ŷ` is supposed to sum to one across `dims`,
 as would be the case with the output of a [`softmax`](@ref) operation.
 
+Use [`label_smoothing`](@ref) to smooth the true labels as preprocessing before
+computing the loss.
+
 Use of [`logitcrossentropy`](@ref) is recommended over `crossentropy` for
 numerical stability.
 
-See also: [`Flux.logitcrossentropy`](@ref), [`Flux.binarycrossentropy`](@ref), [`Flux.logitbinarycrossentropy`](@ref)
+See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref),
+[`label_smoothing`](@ref)
 """
 function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
     agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
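
A short sketch of the preprocessing pattern the updated docstring recommends (array shapes and names are illustrative, not from the commit):

    using Flux
    using Flux.Losses: crossentropy, label_smoothing

    ŷ = softmax(randn(Float32, 2, 4))         # predictions sum to one along dims=1
    y = Flux.onehotbatch([1, 1, 0, 0], 0:1)   # hard one-hot targets

    # Smooth the true labels first, then compute the loss as usual.
    crossentropy(ŷ, label_smoothing(y, 0.1))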
@@ -72,15 +123,19 @@ end
 """
     logitcrossentropy(ŷ, y; dims=1, agg=mean)
 
-Return the crossentropy computed after a [`Flux.logsoftmax`](@ref) operation;
+Return the crossentropy computed after a [`logsoftmax`](@ref) operation;
 calculated as
 
     agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
 
+Use [`label_smoothing`](@ref) to smooth the true labels as preprocessing before
+computing the loss.
+
 `logitcrossentropy(ŷ, y)` is mathematically equivalent to
-[`Flux.Losses.crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
+[`crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
+
 
-See also: [`Flux.Losses.crossentropy`](@ref), [`Flux.Losses.binarycrossentropy`](@ref), [`Flux.Losses.logitbinarycrossentropy`](@ref)
+See also: [`crossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`label_smoothing`](@ref)
 """
 function logitcrossentropy(ŷ, y; dims=1, agg=mean)
     agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
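
A sketch of the equivalence the docstring states, with smoothed labels in place of hard ones (both sides agree up to floating point; names are illustrative):

    using Flux
    using Flux.Losses: logitcrossentropy, crossentropy, label_smoothing

    logits = randn(Float32, 2, 4)   # raw scores, no softmax applied
    ys = label_smoothing(Flux.onehotbatch([1, 1, 0, 0], 0:1), 0.1)

    # Same value, but the logit form avoids an explicit softmax and is more stable.
    logitcrossentropy(logits, ys) ≈ crossentropy(softmax(logits), ys)   # true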
@@ -97,9 +152,13 @@ The `ϵ` term provides numerical stability.
 
 Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
 
+Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
+computing the loss.
+
 Use of `logitbinarycrossentropy` is recommended over `binarycrossentropy` for numerical stability.
 
-See also: [`Flux.Losses.crossentropy`](@ref), [`Flux.Losses.logitcrossentropy`](@ref), [`Flux.Losses.logitbinarycrossentropy`](@ref)
+See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref), [`logitbinarycrossentropy`](@ref),
+[`label_smoothing`](@ref)
 """
 function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
     agg(@.(-xlogy(y, ŷ+ϵ) - xlogy(1-y, 1-ŷ+ϵ)))
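
A sketch of the binary case, where `dims=0` smooths scalar labels towards 0.5 before the loss is taken (shapes and names are illustrative):

    using Flux
    using Flux.Losses: binarycrossentropy, label_smoothing

    ŷ = σ.(randn(Float32, 4))   # sigmoid outputs in (0, 1)
    y = Float32[1, 1, 0, 0]     # binary labels encoded as single numbers

    # With α = 0.2 the targets become 0.9 and 0.1.
    binarycrossentropy(ŷ, label_smoothing(y, 0.2; dims=0))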
@@ -111,10 +170,12 @@ end
     logitbinarycrossentropy(ŷ, y; agg=mean)
 
 Mathematically equivalent to
-[`Flux.binarycrossentropy(σ(ŷ), y)`](@ref) but is more numerically stable.
+[`binarycrossentropy(σ(ŷ), y)`](@ref) but is more numerically stable.
+
+Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
+computing the loss.
 
-See also: [`Flux.Losses.crossentropy`](@ref), [`Flux.Losses.logitcrossentropy`](@ref), [`Flux.Losses.binarycrossentropy`](@ref)
-```
+See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`label_smoothing`](@ref)
 """
 function logitbinarycrossentropy(ŷ, y; agg=mean)
     agg(@.((1-y)*ŷ - logσ(ŷ)))
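
And the logit variant, mirroring the test added below: equal up to floating point when `ϵ=0` (names are illustrative):

    using Flux
    using Flux.Losses: logitbinarycrossentropy, binarycrossentropy, label_smoothing

    logits = randn(Float32, 4)
    ys = label_smoothing(rand(Float32, 4), 0.2; dims=0)

    # Stable form of binarycrossentropy(σ.(logits), ys):
    logitbinarycrossentropy(logits, ys) ≈ binarycrossentropy(σ.(logits), ys; ϵ=0)   # true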

test/losses.jl — 39 additions, 1 deletion

@@ -1,7 +1,7 @@
 using Test
 using Flux: onehotbatch, σ
 
-using Flux.Losses: mse, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
+using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
 using Flux.Losses: xlogx, xlogy
 
 # group here all losses, used in tests
@@ -56,31 +56,69 @@ end
 
 # Now onehot y's
 y = onehotbatch([1, 1, 0, 0], 0:1)
+y_smoothed = label_smoothing(y, 0.1)
 ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
 v = log(.1 / .9)
 logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
 lossvalue = 1.203972804325936
+lossvalue_smoothed = 1.2039728043259348
+yl = onehotbatch([1], 0:1)
+sf = 0.1
+yls = [sf (1-sf)]' # Effective y after label smoothing
+ylp = [0.9 0.1]'
+logylp = [0.0 v]'
+
+# Construct `sim`ilar and `dis`similar versions of the dataset so we can test the effect of smoothing:
+# smoothing should decrease the loss on dissimilar data and increase the loss on similar data,
+# compared to the loss without smoothing.
+ya = onehotbatch([1, 1, 1, 0, 0], 0:1)
+ya_smoothed = label_smoothing(ya, 2sf)
+y_same = Float32.(ya)
+y_sim = y_same .* (1-2*sf) .+ sf
+y_dis = copy(y_sim)
+y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
 
 @testset "crossentropy" begin
   @test crossentropy([0.1,0.0,0.9], [0.1,0.0,0.9]) ≈ crossentropy([0.1,0.9], [0.1,0.9])
   @test crossentropy(ŷ, y) ≈ lossvalue
+  @test crossentropy(ŷ, y_smoothed) ≈ lossvalue_smoothed
+  @test crossentropy(ylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*log.(ylp))
+  @test crossentropy(ylp, yl) ≈ -sum(yl.*log.(ylp))
+  @test iszero(crossentropy(y_same, ya, ϵ=0))
+  @test iszero(crossentropy(ya, ya, ϵ=0))
+  @test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed)
+  @test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed)
 end
 
 @testset "logitcrossentropy" begin
   @test logitcrossentropy(logŷ, y) ≈ lossvalue
+  @test logitcrossentropy(logylp, yl) ≈ -sum(yl.*logsoftmax(logylp))
+  @test logitcrossentropy(logylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*logsoftmax(logylp))
 end
 
 logŷ, y = randn(3), rand(3)
+yls = y.*(1-2sf).+sf
 
 @testset "binarycrossentropy" begin
+  @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
   @test binarycrossentropy(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
   @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))))
 end
 
 @testset "logitbinarycrossentropy" begin
+  @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); ϵ=0)
   @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0)
 end
 
+y = onehotbatch([1], 0:1)
+yls = [0.1 0.9]'
+@testset "label_smoothing" begin
+  @test label_smoothing(y, 0.2) == yls
+  @test label_smoothing(y, 0.2; dims=0) == label_smoothing.(y, 0.2; dims=0)
+  @test_throws ArgumentError label_smoothing([0., 0., 1., 0.], 1.2)
+  @test_throws ArgumentError label_smoothing([0., 0., 1., 0.], 0.)
+end
+
 y = [1 2 3]
 ŷ = [4.0 5.0 6.0]
