
Commit 5e506c7

Merge pull request #666 from AstitvaAggarwal/develop
added additional loss against data for NNODE
2 parents: 0c932e2 + 9495d96
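In brief: this commit lets `NNODE` be trained against data as well as the ODE residual. `NNODE` gains an `additional_loss` type parameter, field, and keyword argument; the `generate_loss` methods are refactored to return their raw loss closures; and `__solve` now composes the physics-informed loss with the optional `additional_loss(phi, θ)` term before picking the AD backend (ForwardDiff for `QuadratureTraining`, Zygote otherwise). New tests exercise the keyword across training strategies and both Flux and Lux chains.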

File tree: 2 files changed, +143 −13 lines

src/ode_solve.jl

Lines changed: 48 additions & 13 deletions
````diff
@@ -3,7 +3,8 @@ abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
 """
 ```julia
 NNODE(chain, opt=OptimizationPolyalgorithms.PolyOpt(), init_params = nothing;
-      autodiff=false, batch=0, kwargs...)
+      autodiff=false, batch=0,additional_loss=nothing,
+      kwargs...)
 ```

 Algorithm for solving ordinary differential equations using a neural network. This is a specialization
````
````diff
@@ -23,6 +24,19 @@ of the physics-informed neural network which is used as a solver for a standard
 which thus uses the random initialization provided by the neural network library.

 ## Keyword Arguments
+* `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions,
+  θ are the weights of the neural network(s).
+
+## Example
+
+```julia
+ts=[t for t in 1:100]
+(u_, t_) = (analytical_func(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+alg = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)
+```

 * `autodiff`: The switch between automatic and numerical differentiation for
   the PDE operators. The reverse mode of the loss function is always
````
```diff
@@ -63,20 +77,23 @@ is an accurate interpolation (up to the neural network training result). In addi
 Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
 ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
 """
-struct NNODE{C, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
+struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function},
+             S <: Union{Nothing, AbstractTrainingStrategy}
+             } <:
        NeuralPDEAlgorithm
     chain::C
     opt::O
     init_params::P
     autodiff::Bool
     batch::B
     strategy::S
+    additional_loss::AL
     kwargs::K
 end
 function NNODE(chain, opt, init_params = nothing;
                strategy = nothing,
-               autodiff = false, batch = nothing, kwargs...)
-    NNODE(chain, opt, init_params, autodiff, batch, strategy, kwargs)
+               autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
+    NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
 end

 """
```
```diff
@@ -236,7 +253,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
 end

 """
-Representation of the loss function, paramtric on the training strategy `strategy`
+Representation of the loss function, parametric on the training strategy `strategy`
 """
 function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
                        batch)
```
```diff
@@ -250,38 +267,36 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
         sol.u
     end

-    # Default this to ForwardDiff until Integrals.jl autodiff is sorted out
-    OptimizationFunction(loss, Optimization.AutoForwardDiff())
+    return loss
 end

 function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
     ts = tspan[1]:(strategy.dx):tspan[2]

     # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
-
     function loss(θ, _)
         if batch
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
         else
             sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
         end
     end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
 end

 function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
                        batch)
+    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
     function loss(θ, _)
         ts = adapt(parameterless_type(θ),
                    [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
-
         if batch
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
         else
             sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
         end
     end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
 end

 function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
```
```diff
@@ -312,7 +327,7 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
             sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
         end
     end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
 end

 function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
```
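With this hunk and the ones above it, every `generate_loss` method now returns its bare loss closure. Construction of the single `OptimizationFunction`, along with the per-strategy choice of AD backend, moves into `__solve` in the next hunk.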
```diff
@@ -407,7 +422,27 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
         alg.batch
     end

-    optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, p, batch)
+    inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
+    additional_loss = alg.additional_loss
+
+    # Compose total_loss: the strategy's L2 loss plus the optional additional_loss term
+    function total_loss(θ, _)
+        L2_loss = inner_f(θ, phi)
+        if !(additional_loss isa Nothing)
+            return additional_loss(phi, θ) + L2_loss
+        end
+        L2_loss
+    end
+
+    # Choice of Optimization Algo for Training Strategies
+    opt_algo = if strategy isa QuadratureTraining
+        Optimization.AutoForwardDiff()
+    else
+        Optimization.AutoZygote()
+    end
+
+    # Creates OptimizationFunction Object from total_loss
+    optf = OptimizationFunction(total_loss, opt_algo)

     iteration = 0
     callback = function (p, l)
```
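To make the control flow above concrete, here is a self-contained sketch of handing a composed loss to Optimization.jl the way `total_loss` is. The two losses are hypothetical stand-ins, not NeuralPDE internals; the ForwardDiff branch in the hunk exists because, per the comment deleted in the QuadratureTraining hunk, Integrals.jl autodiff is not yet sorted out for reverse mode.

```julia
using Optimization, OptimizationOptimisers, Zygote

# Hypothetical stand-ins: `inner_f` plays the strategy's residual loss,
# `additional_loss` plays the user's data-fitting term.
inner_f(θ, _) = sum(abs2, θ)
additional_loss(phi, θ) = sum(abs2, θ .- 1)

# Same shape as `total_loss` above: physics loss plus optional data loss.
function total_loss(θ, p)
    L2_loss = inner_f(θ, p)
    additional_loss === nothing ? L2_loss : L2_loss + additional_loss(nothing, θ)
end

optf = OptimizationFunction(total_loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, rand(3))
res = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.1); maxiters = 200)
# res.u holds the optimized parameters
```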

test/NNODE_tests.jl

Lines changed: 95 additions & 0 deletions
```diff
@@ -207,6 +207,7 @@ sol = solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
             abstol = 1.0f-8, dt = 1 / 5.0f0)
 @test sol.errors[:l2] < 0.5

+# WeightedIntervalTraining(Lux Chain)
 function f(u, p, t)
     [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
 end
```
```diff
@@ -228,3 +229,97 @@ alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
 sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)

 @test abs(mean(sol) - mean(true_sol)) < 0.2
+
+# Checking if additional_loss feature works for NNODE
+linear = (u, p, t) -> cos(2pi * t)
+linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
+tspan = (0.0f0, 1.0f0)
+dt = (tspan[2] - tspan[1]) / 99
+ts = collect(tspan[1]:dt:tspan[2])
+prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
+opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
+
+# Analytical solution
+u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)
+
+# GridTraining (Flux Chain)
+chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(chain, opt, strategy = GridTraining(0.01),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
+
+# GridTraining (Lux Chain)
+luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(luxchain, opt, strategy = GridTraining(0.01),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
+
+# QuadratureTraining (Flux Chain)
+chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-10, maxiters = 200)
+@test sol1.errors[:l2] < 0.5
+
+# QuadratureTraining (Lux Chain)
+luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(luxchain, opt, additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-10, maxiters = 200)
+@test sol1.errors[:l2] < 0.5
+
+# StochasticTraining (Flux Chain)
+chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(chain, opt, strategy = StochasticTraining(1000),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
+
+# StochasticTraining (Lux Chain)
+luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(luxchain, opt, strategy = StochasticTraining(1000),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
```
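Taken together, the new tests exercise `additional_loss` under GridTraining, StochasticTraining, and QuadratureTraining (the cases constructed without an explicit `strategy`), each with both a Flux chain and a Lux chain.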
