@@ -3,7 +3,8 @@ abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
"""
```julia
NNODE(chain, opt=OptimizationPolyalgorithms.PolyOpt(), init_params = nothing;
-      autodiff=false, batch=0, kwargs...)
+      autodiff=false, batch=0, additional_loss=nothing,
+      kwargs...)
```

Algorithm for solving ordinary differential equations using a neural network. This is a specialization
@@ -23,6 +24,19 @@ of the physics-informed neural network which is used as a solver for a standard
which thus uses the random initialization provided by the neural network library.

## Keyword Arguments
+* `additional_loss`: A function `additional_loss(phi, θ)`, where `phi` is the neural network trial solution
+  and `θ` are the weights of the neural network(s).
+
+## Example
+
+```julia
+ts = [t for t in 1:100]
+(u_, t_) = (analytical_func(ts), ts)
+function additional_loss(phi, θ)
+    return sum(abs2, [phi(t, θ) for t in t_] .- u_) / length(u_)
+end
+alg = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)
+```

* `autodiff`: The switch between automatic and numerical differentiation for
  the PDE operators. The reverse mode of the loss function is always
@@ -63,20 +77,23 @@ is an accurate interpolation (up to the neural network training result). In addi
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
"""
-struct NNODE{C, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
+struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function},
+             S <: Union{Nothing, AbstractTrainingStrategy}
+             } <:
       NeuralPDEAlgorithm
    chain::C
    opt::O
    init_params::P
    autodiff::Bool
    batch::B
    strategy::S
+   additional_loss::AL
    kwargs::K
end
function NNODE(chain, opt, init_params = nothing;
               strategy = nothing,
-              autodiff = false, batch = nothing, kwargs...)
-    NNODE(chain, opt, init_params, autodiff, batch, strategy, kwargs)
+              autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
+    NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
end
"""
@@ -236,7 +253,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
end

"""
-Representation of the loss function, paramtric on the training strategy `strategy`
+Representation of the loss function, parametric on the training strategy `strategy`
"""
function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
                       batch)
@@ -250,38 +267,36 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
        sol.u
    end
-    # Default this to ForwardDiff until Integrals.jl autodiff is sorted out
-    OptimizationFunction(loss, Optimization.AutoForwardDiff())
+    return loss
end

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
    ts = tspan[1]:(strategy.dx):tspan[2]

    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
-
    function loss(θ, _)
        if batch
            sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
        else
            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
        end
    end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
end

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
                       batch)
+    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
    function loss(θ, _)
        ts = adapt(parameterless_type(θ),
                   [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
-
        if batch
            sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
        else
            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
        end
    end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
end

function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
@@ -312,7 +327,7 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
        end
    end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
end

function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
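Taken together, these hunks make every `generate_loss` method return the bare closure `loss(θ, _)` instead of an `OptimizationFunction`, so the solver can add extra terms before wrapping the objective once. A sketch of the new contract, where `phi`, `f`, `p`, and the weights `θ0` are assumed to be set up as in `__solve`:

```julia
# The returned object is now a plain two-argument closure, e.g. for grid training:
loss = generate_loss(GridTraining(0.1), phi, f, false, (0.0, 1.0), p, true)
l0 = loss(θ0, nothing)  # scalar physics-informed residual at the weights θ0
# `__solve` is now the only place that wraps a loss into an OptimizationFunction.
```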
@@ -407,7 +422,27 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
        alg.batch
    end

-    optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, p, batch)
+    inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
+    additional_loss = alg.additional_loss
+
+    # Combine the physics-informed loss with the user-supplied additional loss, if any
+    function total_loss(θ, _)
+        L2_loss = inner_f(θ, phi)
+        if !(additional_loss isa Nothing)
+            return additional_loss(phi, θ) + L2_loss
+        end
+        L2_loss
+    end
+
+    # Choose the AD backend for the optimizer based on the training strategy
+    opt_algo = if strategy isa QuadratureTraining
+        Optimization.AutoForwardDiff()
+    else
+        Optimization.AutoZygote()
+    end
+
+    # Create an OptimizationFunction from total_loss
+    optf = OptimizationFunction(total_loss, opt_algo)

    iteration = 0
    callback = function (p, l)
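In effect, the objective handed to Optimization.jl is the physics-informed residual plus the optional user term. A minimal standalone illustration of the composition logic (the names here are hypothetical stand-ins, not NeuralPDE internals):

```julia
# Mirrors total_loss above: add the user term only when one was supplied.
function make_total_loss(inner_f, additional_loss, phi)
    function total_loss(θ, _)
        L2_loss = inner_f(θ, phi)
        additional_loss === nothing ? L2_loss : L2_loss + additional_loss(phi, θ)
    end
    return total_loss
end
```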