Skip to content

Commit 9dbb259

Browse files
Merge pull request #921 from SciML/sb/fix_CI
Fix CI
2 parents d8ff418 + d82581c commit 9dbb259

10 files changed

+49
-26
lines changed

src/NeuralPDE.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using Reexport: @reexport
2929
using RuntimeGeneratedFunctions: RuntimeGeneratedFunctions, @RuntimeGeneratedFunction
3030
using SciMLBase: SciMLBase, BatchIntegralFunction, IntegralProblem, NoiseProblem,
3131
OptimizationFunction, OptimizationProblem, ReturnCode, discretize,
32-
isinplace, solve, symbolic_discretize
32+
isinplace, solve, symbolic_discretize, ODEProblem, ODESolution
3333
using Statistics: Statistics, mean
3434
using QuasiMonteCarlo: QuasiMonteCarlo, LatinHypercubeSample
3535
using WeightInitializers: glorot_uniform, zeros32

src/PDE_BPINN.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ)
2020
priorlogpdf(ltd, θ) + L2LossData(ltd, θ)
2121
else
2222
return ltd.full_loglikelihood(setparameters(ltd, θ), ltd.allstd) +
23-
priorlogpdf(ltd, θ) + L2LossData(ltd, θ) + ltd.L2_loss2(setparameters(ltd, θ), ltd.phynewstd)
23+
priorlogpdf(ltd, θ) + L2LossData(ltd, θ) +
24+
ltd.L2_loss2(setparameters(ltd, θ), ltd.phynewstd)
2425
end
2526
end
2627

@@ -57,11 +58,11 @@ function get_lossy(pinnrep, dataset, Dict_differentials)
5758
# each sub vector has dataset's indvar coord's datafree_colloc_loss_function, n_subvectors = n_rows_dataset(or n_indvar_coords_dataset)
5859
# zip each colloc equation with args for each build_loss call per equation vector
5960
data_colloc_loss_functions = [[build_loss_function(pinnrep, eq, pde_indvar)
60-
for (eq, pde_indvar, integration_indvar) in zip(
61-
colloc_equation,
62-
pinnrep.pde_indvars,
63-
pinnrep.pde_integration_vars)]
64-
for colloc_equation in colloc_equations]
61+
for (eq, pde_indvar, integration_indvar) in zip(
62+
colloc_equation,
63+
pinnrep.pde_indvars,
64+
pinnrep.pde_integration_vars)]
65+
for colloc_equation in colloc_equations]
6566

6667
return data_colloc_loss_functions
6768
end

src/ode_solve.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,12 @@ function (f::NNODEInterpolation)(t::Vector, idxs, ::Type{Val{0}}, p, continuity)
281281
return DiffEqArray([out[idxs, i] for i in axes(out, 2)], t)
282282
end
283283

284+
function (sol::ODESolution{T, N, U, U2, D, T2, R, D2, P, A})(
285+
t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Nothing,
286+
continuity) where {T, N, U, U2, D, T2, R, D2, P, A <: NNODE, deriv}
287+
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
288+
end
289+
284290
SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation"
285291
SciMLBase.allowscomplex(::NNODE) = true
286292

src/pino_ode_solve.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ p,t = rand(1, 50, 10), rand(1, 50, 10)
260260
interp(p, t)
261261
```
262262
"""
263-
function (f::PINOODEInterpolation)(p::AbstractArray, t::AbstractArray)
263+
(f::PINOODEInterpolation)(p, t) = f(t, nothing, Val{0}, p, nothing)
264+
265+
function (f::PINOODEInterpolation)(
266+
t::AbstractArray, ::Nothing, ::Type{Val{0}}, p::AbstractArray, continuity)
264267
if f.phi.model isa DeepONet
265268
f.phi((p, t), f.θ)
266269
elseif f.phi.model isa Chain
@@ -273,7 +276,8 @@ function (f::PINOODEInterpolation)(p::AbstractArray, t::AbstractArray)
273276
end
274277
end
275278

276-
function (f::PINOODEInterpolation)(p::AbstractArray, t::Number)
279+
function (f::PINOODEInterpolation)(
280+
t::Number, ::Nothing, ::Type{Val{0}}, p::AbstractArray, continuity)
277281
if f.phi.model isa DeepONet
278282
t_ = [t]
279283
f.phi((p, t_), f.θ)
@@ -288,8 +292,16 @@ end
288292
SciMLBase.interp_summary(::PINOODEInterpolation) = "Trained neural network interpolation"
289293
SciMLBase.allowscomplex(::PINOODE) = true
290294

291-
function (sol::SciMLBase.AbstractODESolution)(t::Union{Number, AbstractArray})
292-
sol.interp(sol.prob.p, t)
295+
function (sol::ODESolution{T, N, U, U2, D, T2, R, D2, P, A})(
296+
t::AbstractArray, ::Type{deriv}, idxs::Nothing,
297+
continuity) where {T, N, U, U2, D, T2, R, D2, P, A <: PINOODE, deriv}
298+
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
299+
end
300+
301+
function (sol::ODESolution{T, N, U, U2, D, T2, R, D2, P, A})(
302+
t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Nothing,
303+
continuity) where {T, N, U, U2, D, T2, R, D2, P, A <: PINOODE, deriv}
304+
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
293305
end
294306

295307
function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,

src/training_strategies.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,11 @@ function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation,
9797
return pde_loss_functions, bc_loss_functions
9898
end
9999

100-
function get_points_loss_functions(loss_function, train_set, eltypeθ, strategy::GridTraining;
100+
function get_points_loss_functions(
101+
loss_function, train_set, eltypeθ, strategy::GridTraining;
101102
τ = nothing)
102-
# loss_function length is number of all points loss is being evaluated upon
103-
# train sets rows are for each indvar, cols are coordinates (row_1,row_2,..row_n) at which loss evaluated
103+
# loss_function length is number of all points loss is being evaluated upon
104+
# train sets rows are for each indvar, cols are coordinates (row_1,row_2,..row_n) at which loss evaluated
104105
function loss(θ, std)
105106
logpdf(
106107
MvNormal(loss_function(train_set, θ)[1, :],

test/BPINN_PDE_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,4 +526,4 @@ end
526526
α = 1
527527
@test abs(param_new - α) < 0.2 * α
528528
@test abs(param_new - α) < abs(param_old - α)
529-
end
529+
end

test/BPINN_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,4 +426,4 @@ end
426426
mean(abs, u[2, :] .- pmean(sol_pestim2.ensemblesol[2]))
427427

428428
@test Loss_1 > Loss_2
429-
end
429+
end

test/NNODE_tests.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ end
142142
end
143143

144144
@testitem "Training Strategy: Others" tags=[:nnode] begin
145-
using OrdinaryDiffEq, Random, Lux, Optimisers
145+
using OrdinaryDiffEq, Random, Lux, Optimisers, Integrals
146146

147147
Random.seed!(100)
148148

@@ -165,13 +165,14 @@ end
165165
@testset "$(nameof(typeof(strategy)))" for strategy in [
166166
GridTraining(0.01),
167167
StochasticTraining(1000),
168-
QuadratureTraining(reltol = 1e-3, abstol = 1e-6, maxiters = 50, batch = 100)
168+
QuadratureTraining(reltol = 1e-3, abstol = 1e-6, maxiters = 50,
169+
batch = 100, quadrature_alg = QuadGKJL())
169170
]
170171
alg = NNODE(luxchain, opt; additional_loss, strategy)
171172
@test begin
172173
sol = solve(prob, alg; verbose = false, maxiters = 500, abstol = 1e-6)
173174
sol.errors[:l2] < 0.5
174-
end broken=(strategy isa QuadratureTraining)
175+
end
175176
end
176177
end
177178

test/NNPDE_tests.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,10 @@ end
187187
phi = discretization.phi
188188

189189
xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
190-
analytic_sol_func(x, y) = (sinpi(x) * sinpi(y)) / (2pi^2)
190+
analytic = (x, y) -> (sinpi(x) * sinpi(y)) / (2pi^2)
191191

192192
u_predict = [first(phi([x, y], res.u)) for x in xs for y in ys]
193-
u_real = [analytic_sol_func(x, y) for x in xs for y in ys]
193+
u_real = [analytic(x, y) for x in xs for y in ys]
194194

195195
@test u_predictu_real atol=2.0
196196
end
@@ -394,6 +394,8 @@ end
394394
@testitem "PDE VI: PDE with mixed derivative" tags=[:nnpde1] setup=[NNPDE1TestSetup] begin
395395
using Lux, Random, Optimisers, DomainSets, Cubature, QuasiMonteCarlo, Integrals
396396
import ModelingToolkit: Interval, infimum, supremum
397+
using OptimizationOptimJL: BFGS
398+
using LineSearches: BackTracking
397399

398400
@parameters x y
399401
@variables u(..)
@@ -414,15 +416,15 @@ end
414416
# Space and time domains
415417
domains = [x Interval(0.0, 1.0), y Interval(0.0, 1.0)]
416418

417-
strategy = StochasticTraining(1024)
418-
inner = 20
419-
chain = Chain(Dense(2, inner, tanh), Dense(inner, inner, tanh), Dense(inner, 1))
419+
strategy = StochasticTraining(2048)
420+
inner = 32
421+
chain = Chain(Dense(2, inner, sigmoid), Dense(inner, inner, sigmoid), Dense(inner, 1))
420422

421423
discretization = PhysicsInformedNN(chain, strategy)
422424
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
423425

424426
prob = discretize(pde_system, discretization)
425-
res = solve(prob, Adam(0.01); maxiters = 5000, callback)
427+
res = solve(prob, BFGS(); maxiters = 500, callback)
426428
phi = discretization.phi
427429

428430
analytic_sol_func(x, y) = x + x * y + y^2 / 2

test/PINO_ode_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ end
231231
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
232232
Dense(10 => 10, Lux.tanh_fast)))
233233

234-
u = rand(2, 50)
234+
u = rand(3, 50)
235235
v = rand(1, 40, 1)
236236
θ, st = Lux.setup(Random.default_rng(), deeponet)
237237
c = deeponet((u, v), θ, st)[1]

0 commit comments

Comments
 (0)