Commit 90b8a40

energies: adding energy loss functions

This commit adds energy loss functions: loss functions analogous to those derived from the PDEs or the boundary conditions, but whose energy integrand is given explicitly in symbolic form.
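For intuition: an energy loss minimizes a discretized energy functional E(θ) = ∫_Ω e(u_θ, u_θ′) dx rather than a pointwise PDE residual, as in deep-Ritz-style training. The sketch below illustrates the idea on -u'' = f with a one-parameter trial function; all names (`u`, `du`, `energy_loss`) are illustrative, not part of this package's API.

```julia
using Statistics

# Hypothetical one-parameter trial function and its derivative.
u(x, θ)  = θ[1] * sin(pi * x)
du(x, θ) = θ[1] * pi * cos(pi * x)

# Energy integrand e = u'^2/2 - f*u for -u'' = f on (0, 1) with u(0) = u(1) = 0.
f(x) = pi^2 * sin(pi * x)

# Monte Carlo discretization of the energy integral over sample points `xs`.
energy_loss(θ, xs) = mean(du(x, θ)^2 / 2 - f(x) * u(x, θ) for x in xs)

xs = rand(100)
energy_loss([1.0], xs)  # minimized in expectation at θ[1] = 1, i.e. u = sin(pi*x)
```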
1 parent: af4d4b7

7 files changed: +420 −138

src/adaptive_losses.jl: +46 −18
@@ -15,6 +15,7 @@ end
 """
 ```julia
 NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                     energy_loss_weights = 1,
                      bc_loss_weights = 1,
                      additional_loss_weights = 1)
 ```
@@ -25,30 +26,33 @@ change during optimization
 mutable struct NonAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     pde_loss_weights::Vector{T}
+    energy_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function NonAdaptiveLoss{T}(; pde_loss_weights = 1,
+                                                        energy_loss_weights = 1,
                                                         bc_loss_weights = 1,
                                                         additional_loss_weights = 1) where {T <: Real}
-        new(vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+        new(vectorify(pde_loss_weights, T), vectorify(energy_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end

 # default to Float64
-SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1, bc_loss_weights = 1,
+SciMLBase.@add_kwonly function NonAdaptiveLoss(; pde_loss_weights = 1,
+                                                 energy_loss_weights = 1,
+                                                 bc_loss_weights = 1,
                                                  additional_loss_weights = 1)
     NonAdaptiveLoss{Float64}(; pde_loss_weights = pde_loss_weights,
+                               energy_loss_weights = energy_loss_weights,
                                bc_loss_weights = bc_loss_weights,
                                additional_loss_weights = additional_loss_weights)
 end

 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::NonAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
-    function null_nonadaptive_loss(θ, pde_losses, bc_losses)
+                                         pde_loss_functions, energy_loss_functions,
+                                         bc_loss_functions)
+    function null_nonadaptive_loss(θ, pde_losses, energy_losses, bc_losses)
         nothing
     end
 end
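With the new keyword, fixed weights for all four loss groups are set once at construction, matching the docstring in the hunk above; the weight values here are arbitrary placeholders:

```julia
# Constant (non-adaptive) weights for each loss group; values are illustrative.
adaloss = NonAdaptiveLoss(; pde_loss_weights = 1.0,
                            energy_loss_weights = 1.0,
                            bc_loss_weights = 10.0,
                            additional_loss_weights = 1.0)
```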
@@ -58,6 +62,7 @@ end
 GradientScaleAdaptiveLoss(reweight_every;
                           weight_change_inertia = 0.9,
                           pde_loss_weights = 1,
+                          energy_loss_weights = 1,
                           bc_loss_weights = 1,
                           additional_loss_weights = 1)
 ```
@@ -90,61 +95,66 @@ mutable struct GradientScaleAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
     reweight_every::Int64
     weight_change_inertia::T
     pde_loss_weights::Vector{T}
+    energy_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
     SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss{T}(reweight_every;
                                                                 weight_change_inertia = 0.9,
                                                                 pde_loss_weights = 1,
+                                                                energy_loss_weights = 1,
                                                                 bc_loss_weights = 1,
                                                                 additional_loss_weights = 1) where {T <: Real}
         new(convert(Int64, reweight_every), convert(T, weight_change_inertia),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, T), vectorify(energy_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end

 # default to Float64
 SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every;
                                                          weight_change_inertia = 0.9,
                                                          pde_loss_weights = 1,
+                                                         energy_loss_weights = 1,
                                                          bc_loss_weights = 1,
                                                          additional_loss_weights = 1)
     GradientScaleAdaptiveLoss{Float64}(reweight_every;
                                        weight_change_inertia = weight_change_inertia,
                                        pde_loss_weights = pde_loss_weights,
+                                       energy_loss_weights = energy_loss_weights,
                                        bc_loss_weights = bc_loss_weights,
                                        additional_loss_weights = additional_loss_weights)
 end

 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::GradientScaleAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+                                         pde_loss_functions, energy_loss_functions,
+                                         bc_loss_functions)
     weight_change_inertia = adaloss.weight_change_inertia
     iteration = pinnrep.iteration
     adaloss_T = eltype(adaloss.pde_loss_weights)

-    function run_loss_gradients_adaptive_loss(θ, pde_losses, bc_losses)
+    function run_loss_gradients_adaptive_loss(θ, pde_losses, energy_losses, bc_losses)
         if iteration[1] % adaloss.reweight_every == 0
             # the paper assumes a single pde loss function, so here we grab the maximum of the maximums of each pde loss function
-            pde_grads_maxes = [maximum(abs.(Zygote.gradient(pde_loss_function, θ)[1]))
-                               for pde_loss_function in pde_loss_functions]
-            pde_grads_max = maximum(pde_grads_maxes)
+            # energy loss functions are treated the same as pde loss functions
+            pde_energy_grads_maxes = [maximum(abs.(Zygote.gradient(loss_function, θ)[1]))
+                                      for loss_function in vcat(pde_loss_functions,
+                                                                energy_loss_functions)]
+            pde_energy_grads_max = maximum(pde_energy_grads_maxes)
             bc_grads_mean = [mean(abs.(Zygote.gradient(bc_loss_function, θ)[1]))
                              for bc_loss_function in bc_loss_functions]

             nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) :
                                   convert(adaloss_T, 1e-7)
-            bc_loss_weights_proposed = pde_grads_max ./
+            bc_loss_weights_proposed = pde_energy_grads_max ./
                                        (bc_grads_mean .+ nonzero_divisor_eps)
             adaloss.bc_loss_weights .= weight_change_inertia .*
                                        adaloss.bc_loss_weights .+
                                        (1 .- weight_change_inertia) .*
                                        bc_loss_weights_proposed
-            logscalar(pinnrep.logger, pde_grads_max, "adaptive_loss/pde_grad_max",
+            logscalar(pinnrep.logger, pde_energy_grads_max, "adaptive_loss/pde_energy_grad_max",
                       iteration[1])
-            logvector(pinnrep.logger, pde_grads_maxes, "adaptive_loss/pde_grad_maxes",
+            logvector(pinnrep.logger, pde_energy_grads_maxes, "adaptive_loss/pde_energy_grad_maxes",
                       iteration[1])
             logvector(pinnrep.logger, bc_grads_mean, "adaptive_loss/bc_grad_mean",
                       iteration[1])
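For reference, the reweighting rule above in isolation: boundary-condition weights are pulled toward the ratio of the largest PDE/energy gradient magnitude to the mean BC gradient magnitude, smoothed by the inertia term. A standalone sketch, with illustrative names, not package code:

```julia
# One gradient-scale update step for the BC weights.
function reweight_bc(bc_weights, pde_energy_grads_max, bc_grads_mean;
                     inertia = 0.9, eps = 1e-11)
    # proposed weight: how much larger the stiffest PDE/energy gradient is
    # than the average BC gradient
    proposed = pde_energy_grads_max ./ (bc_grads_mean .+ eps)
    # exponential smoothing so the weights don't jump between reweightings
    inertia .* bc_weights .+ (1 - inertia) .* proposed
end
```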
@@ -160,8 +170,10 @@ end
 ```julia
 function MiniMaxAdaptiveLoss(reweight_every;
                              pde_max_optimiser = Flux.ADAM(1e-4),
+                             energy_max_optimiser = Flux.ADAM(1e-4),
                              bc_max_optimiser = Flux.ADAM(0.5),
                              pde_loss_weights = 1,
+                             energy_loss_weights = 1,
                              bc_loss_weights = 1,
                              additional_loss_weights = 1)
 ```
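Constructed with the defaults from this docstring, the energy weights get their own maximizing optimiser, parallel to the PDE and BC ones. An illustrative construction (assuming this package's exports are in scope):

```julia
using Flux

# Reweight every 100 iterations; each loss group's weights are ascended
# by a separate ADAM instance.
adaloss = MiniMaxAdaptiveLoss(100;
                              pde_max_optimiser = Flux.ADAM(1e-4),
                              energy_max_optimiser = Flux.ADAM(1e-4),
                              bc_max_optimiser = Flux.ADAM(0.5),
                              pde_loss_weights = 1, energy_loss_weights = 1,
                              bc_loss_weights = 1, additional_loss_weights = 1)
```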
@@ -191,65 +203,81 @@ https://arxiv.org/abs/2009.04544
 """
 mutable struct MiniMaxAdaptiveLoss{T <: Real,
                                    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                   ENERGY_OPT <: Flux.Optimise.AbstractOptimiser,
                                    BC_OPT <: Flux.Optimise.AbstractOptimiser} <:
        AbstractAdaptiveLoss
     reweight_every::Int64
     pde_max_optimiser::PDE_OPT
+    energy_max_optimiser::ENERGY_OPT
     bc_max_optimiser::BC_OPT
     pde_loss_weights::Vector{T}
+    energy_loss_weights::Vector{T}
     bc_loss_weights::Vector{T}
     additional_loss_weights::Vector{T}
-    SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T, PDE_OPT, BC_OPT}(reweight_every;
+    SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T, PDE_OPT, ENERGY_OPT,
+                                                       BC_OPT}(reweight_every;
                 pde_max_optimiser = Flux.ADAM(1e-4),
+                energy_max_optimiser = Flux.ADAM(1e-4),
                 bc_max_optimiser = Flux.ADAM(0.5),
                 pde_loss_weights = 1,
+                energy_loss_weights = 1,
                 bc_loss_weights = 1,
-                additional_loss_weights = 1) where {T <: Real,
-                                                    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
-                                                    BC_OPT <: Flux.Optimise.AbstractOptimiser}
-        new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
+                additional_loss_weights = 1) where {T <: Real,
+                                                    PDE_OPT <: Flux.Optimise.AbstractOptimiser,
+                                                    ENERGY_OPT <: Flux.Optimise.AbstractOptimiser,
+                                                    BC_OPT <: Flux.Optimise.AbstractOptimiser}
+        new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
+            convert(ENERGY_OPT, energy_max_optimiser),
             convert(BC_OPT, bc_max_optimiser),
-            vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T),
-            vectorify(additional_loss_weights, T))
+            vectorify(pde_loss_weights, T), vectorify(energy_loss_weights, T),
+            vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
     end
 end

 # default to Float64, ADAM, ADAM
 SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
                                                    pde_max_optimiser = Flux.ADAM(1e-4),
+                                                   energy_max_optimiser = Flux.ADAM(1e-4),
                                                    bc_max_optimiser = Flux.ADAM(0.5),
                                                    pde_loss_weights = 1,
+                                                   energy_loss_weights = 1,
                                                    bc_loss_weights = 1,
                                                    additional_loss_weights = 1)
-    MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
-                        typeof(bc_max_optimiser)}(reweight_every;
+    MiniMaxAdaptiveLoss{Float64, typeof(pde_max_optimiser),
+                        typeof(energy_max_optimiser),
+                        typeof(bc_max_optimiser)}(reweight_every;
                             pde_max_optimiser = pde_max_optimiser,
+                            energy_max_optimiser = energy_max_optimiser,
                             bc_max_optimiser = bc_max_optimiser,
                             pde_loss_weights = pde_loss_weights,
+                            energy_loss_weights = energy_loss_weights,
                             bc_loss_weights = bc_loss_weights,
                             additional_loss_weights = additional_loss_weights)
 end

 function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
                                          adaloss::MiniMaxAdaptiveLoss,
-                                         pde_loss_functions, bc_loss_functions)
+                                         pde_loss_functions, energy_loss_functions,
+                                         bc_loss_functions)
     pde_max_optimiser = adaloss.pde_max_optimiser
+    energy_max_optimiser = adaloss.energy_max_optimiser
     bc_max_optimiser = adaloss.bc_max_optimiser
     iteration = pinnrep.iteration

-    function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
+    function run_minimax_adaptive_loss(θ, pde_losses, energy_losses, bc_losses)
         if iteration[1] % adaloss.reweight_every == 0
             Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights,
                                   -pde_losses)
+            Flux.Optimise.update!(energy_max_optimiser, adaloss.energy_loss_weights,
+                                  -energy_losses)
             Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
             logvector(pinnrep.logger, adaloss.pde_loss_weights,
                       "adaptive_loss/pde_loss_weights", iteration[1])
+            logvector(pinnrep.logger, adaloss.energy_loss_weights,
+                      "adaptive_loss/energy_loss_weights", iteration[1])
             logvector(pinnrep.logger, adaloss.bc_loss_weights,
                       "adaptive_loss/bc_loss_weights",
                       iteration[1])
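The sign convention is the crux of the minimax scheme: `Flux.Optimise.update!` takes a descent step, so passing the negated losses performs gradient ascent on the weights, meaning a loss term's weight grows in proportion to how badly it is currently violated. In plain-SGD terms (a sketch of the update direction only, not the ADAM step the code actually takes):

```julia
# Equivalent plain-SGD view of one minimax weight update (illustrative names).
weights = [1.0, 1.0]
energy_losses = [0.5, 2.0]
lr = 1e-4
weights .-= lr .* (-energy_losses)  # descent on -loss == ascent on loss
# the larger a loss, the faster its weight grows
```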
