diff --git a/docs/pages.jl b/docs/pages.jl index dd11e26e..a7830e11 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -2,8 +2,8 @@ pages = ["index.md", "Getting Started with BVP solving in Julia" => "tutorials/getting_started.md", - "Tutorials" => Any[ - "tutorials/continuation.md", "tutorials/solve_nlls_bvp.md", "tutorials/extremum.md"], + "Tutorials" => Any["tutorials/continuation.md", + "tutorials/solve_nlls_bvp.md", "tutorials/extremum.md"], "Basics" => Any["basics/bvp_problem.md", "basics/bvp_functions.md", "basics/solve.md", "basics/autodiff.md", "basics/error_control.md"], "Solver Summaries and Recommendations" => Any[ diff --git a/ext/BoundaryValueDiffEqODEInterfaceExt.jl b/ext/BoundaryValueDiffEqODEInterfaceExt.jl index e7e5fd86..b2198dbc 100644 --- a/ext/BoundaryValueDiffEqODEInterfaceExt.jl +++ b/ext/BoundaryValueDiffEqODEInterfaceExt.jl @@ -61,13 +61,17 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, @closure (t, u, du) -> du .= vec(prob.f(reshape(u, u0_size), prob.p, t)) end bvp2m_bc = if SciMLBase.isinplace(prob) - @closure (ya, yb, bca, bcb) -> begin + @closure (ya, + yb, + bca, + bcb) -> begin prob.f.bc[1](reshape(bca, left_bc_size), reshape(ya, u0_size), prob.p) prob.f.bc[2](reshape(bcb, right_bc_size), reshape(yb, u0_size), prob.p) return nothing end else - @closure (ya, yb, bca, bcb) -> begin + @closure ( + ya, yb, bca, bcb) -> begin bca .= vec(prob.f.bc[1](reshape(ya, u0_size), prob.p)) bcb .= vec(prob.f.bc[2](reshape(yb, u0_size), prob.p)) return nothing @@ -140,7 +144,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, end bvpsol_bc = if SciMLBase.isinplace(prob) - @closure (ya, yb, r) -> begin + @closure (ya, yb, + r) -> begin left_bc = reshape(@view(r[1:no_left_bc]), left_bc_size) right_bc = reshape(@view(r[(no_left_bc + 1):end]), right_bc_size) prob.f.bc[1](left_bc, reshape(ya, u0_size), prob.p) @@ -148,7 +153,9 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, return nothing end else - @closure (ya, yb, r) -> begin + @closure (ya, + yb, + r) -> begin r[1:no_left_bc] .= vec(prob.f.bc[1](reshape(ya, u0_size), prob.p)) r[(no_left_bc + 1):end] .= vec(prob.f.bc[2](reshape(yb, u0_size), prob.p)) return nothing diff --git a/lib/BoundaryValueDiffEqAscher/src/adaptivity.jl b/lib/BoundaryValueDiffEqAscher/src/adaptivity.jl index dbdd4381..6108a12e 100644 --- a/lib/BoundaryValueDiffEqAscher/src/adaptivity.jl +++ b/lib/BoundaryValueDiffEqAscher/src/adaptivity.jl @@ -188,15 +188,13 @@ function error_estimate!(cache::AscherCache) # in valstr in case they prove to be needed later for an error estimate. x = mesh[i] + (mesh_dt[i]) * 2.0 / 3.0 @views approx(cache, x, valstr[i][3]) - error[i] .= wgterr .* - abs.(valstr[i][3] .- + error[i] .= wgterr .* abs.(valstr[i][3] .- (isodd(i) ? valstr[Int((i + 1) / 2)][2] : valstr[Int(i / 2)][4])) x = mesh[i] + (mesh_dt[i]) / 3.0 @views approx(cache, x, valstr[i][2]) error[i] .= error[i] .+ - wgterr .* - abs.(valstr[i][2] .- + wgterr .* abs.(valstr[i][2] .- (isodd(i) ? valstr[Int((i + 1) / 2)][1] : valstr[Int(i / 2)][3])) end return maximum(reduce(hcat, error), dims = 2) diff --git a/lib/BoundaryValueDiffEqAscher/src/ascher.jl b/lib/BoundaryValueDiffEqAscher/src/ascher.jl index 9d281206..4c51a956 100644 --- a/lib/BoundaryValueDiffEqAscher/src/ascher.jl +++ b/lib/BoundaryValueDiffEqAscher/src/ascher.jl @@ -102,7 +102,8 @@ function SciMLBase.__init( iip = isinplace(prob) - f, bc = if prob.u0 isa AbstractVector + f, + bc = if prob.u0 isa AbstractVector prob.f, prob.f.bc elseif iip vecf! = @closure (du, u, p, t) -> __vec_f!(du, u, p, t, prob.f, size(u0)) @@ -337,10 +338,11 @@ function __construct_nlproblem(cache::AscherCache{iip, T}) where {iip, T} end jac = if iip - @closure (J, u, p) -> __ascher_mpoint_jacobian!( - J, u, diffmode, jac_cache, loss, lz, cache.p) + @closure (J, u, + p) -> __ascher_mpoint_jacobian!(J, u, diffmode, jac_cache, loss, lz, cache.p) else - @closure (u, p) -> __ascher_mpoint_jacobian( + @closure (u, + p) -> __ascher_mpoint_jacobian( jac_prototype, u, diffmode, jac_cache, loss, cache.p) end diff --git a/lib/BoundaryValueDiffEqAscher/src/collocation.jl b/lib/BoundaryValueDiffEqAscher/src/collocation.jl index 758b8bbd..31780ee7 100644 --- a/lib/BoundaryValueDiffEqAscher/src/collocation.jl +++ b/lib/BoundaryValueDiffEqAscher/src/collocation.jl @@ -1,5 +1,6 @@ function Φ!(cache::AscherCache{iip, T}, z, res, pt::StandardBVProblem) where {iip, T} - (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, gval, delz, dmz, deldmz, g, w, v, ipvtg, ipvtw, TU) = cache + (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, + yval, gval, delz, dmz, deldmz, g, w, v, ipvtg, ipvtw, TU) = cache (; acol, rho) = TU ncy = ncomp + ny n = length(mesh) - 1 @@ -156,7 +157,8 @@ function Φ!(cache::AscherCache{iip, T}, z, res, pt::StandardBVProblem) where {i end function Φ!(cache::AscherCache{iip, T}, z, res, pt::TwoPointBVProblem) where {iip, T} - (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, bcresid_prototype, residual, zval, yval, gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache + (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, bcresid_prototype, residual, + zval, yval, gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache (; acol, rho) = TU ncy = ncomp + ny n = length(mesh) - 1 @@ -319,7 +321,8 @@ end @inline __get_value(z) = isa(z, ForwardDiff.Dual) ? z.value : z function Φ(cache::AscherCache{iip, T}, z, pt::StandardBVProblem) where {iip, T} - (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache + (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, + gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache (; acol, rho) = TU ncy = ncomp + ny n = length(mesh) - 1 @@ -476,7 +479,8 @@ function Φ(cache::AscherCache{iip, T}, z, pt::StandardBVProblem) where {iip, T} end function Φ(cache::AscherCache{iip, T}, z, pt::TwoPointBVProblem) where {iip, T} - (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache + (; f, mesh, mesh_dt, ncomp, ny, bc, k, p, zeta, residual, zval, yval, + gval, delz, dmz, deldmz, g, w, v, dmzo, ipvtg, ipvtw, TU) = cache (; acol, rho) = TU ncy = ncomp + ny n = length(mesh) - 1 diff --git a/lib/BoundaryValueDiffEqCore/src/types.jl b/lib/BoundaryValueDiffEqCore/src/types.jl index a2051f32..2a5c05cb 100644 --- a/lib/BoundaryValueDiffEqCore/src/types.jl +++ b/lib/BoundaryValueDiffEqCore/src/types.jl @@ -6,7 +6,8 @@ end @inline __materialize_jacobian_algorithm(_, alg::BVPJacobianAlgorithm) = alg -@inline __materialize_jacobian_algorithm(_, alg::ADTypes.AbstractADType) = BVPJacobianAlgorithm(alg) +@inline __materialize_jacobian_algorithm( + _, alg::ADTypes.AbstractADType) = BVPJacobianAlgorithm(alg) @inline __materialize_jacobian_algorithm(::Nothing, ::Nothing) = BVPJacobianAlgorithm() @inline function __materialize_jacobian_algorithm(nlsolve::N, ::Nothing) where {N} ad = hasfield(N, :jacobian_ad) ? nlsolve.jacobian_ad : missing diff --git a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl index d20b7c37..824df0bb 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl @@ -100,6 +100,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2) end end @@ -112,6 +113,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2) end end @@ -124,6 +126,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2) end end @@ -136,6 +139,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2) end end @@ -194,6 +198,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) end end @@ -210,6 +215,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) end end @@ -226,6 +232,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) end end @@ -242,6 +249,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) end end @@ -258,6 +266,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) end end diff --git a/lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl b/lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl index 11870451..15cb25bc 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl @@ -226,8 +226,8 @@ function s_constraints(M, h) row_start = (i - 1) * M + 1 for k in 0:(M - 1) for j in 1:6 - A[row_start + k, j + k * 6] = j == 1.0 ? 0.0 : - (j - 1) * t[i + k * 6]^(j - 2) + A[row_start + k, + j + k * 6] = j == 1.0 ? 0.0 : (j - 1) * t[i + k * 6]^(j - 2) end end end diff --git a/lib/BoundaryValueDiffEqFIRK/src/algorithms.jl b/lib/BoundaryValueDiffEqFIRK/src/algorithms.jl index b3327ef6..c8156a86 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/algorithms.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/algorithms.jl @@ -95,8 +95,14 @@ for stage in (1, 2, 3, 5, 7) defect_threshold::T = 0.1 max_num_subintervals::Int = 3000 end - $(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){ - N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, + $(alg)(nlsolve::N, + jac_alg::J; + nested = false, + nested_nlsolve_kwargs::NamedTuple = (;), + defect_threshold::T = 0.1, + max_num_subintervals::Int = 3000) where {N, + J, + T} = $(alg){N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, defect_threshold, max_num_subintervals) end end @@ -194,8 +200,14 @@ for stage in (2, 3, 4, 5) defect_threshold::T = 0.1 max_num_subintervals::Int = 3000 end - $(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){ - N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, + $(alg)(nlsolve::N, + jac_alg::J; + nested = false, + nested_nlsolve_kwargs::NamedTuple = (;), + defect_threshold::T = 0.1, + max_num_subintervals::Int = 3000) where {N, + J, + T} = $(alg){N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, defect_threshold, max_num_subintervals) end end @@ -295,8 +307,14 @@ for stage in (2, 3, 4, 5) defect_threshold::T = 0.1 max_num_subintervals::Int = 3000 end - $(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){ - N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, + $(alg)(nlsolve::N, + jac_alg::J; + nested = false, + nested_nlsolve_kwargs::NamedTuple = (;), + defect_threshold::T = 0.1, + max_num_subintervals::Int = 3000) where {N, + J, + T} = $(alg){N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, defect_threshold, max_num_subintervals) end end @@ -396,8 +414,14 @@ for stage in (2, 3, 4, 5) defect_threshold::T = 0.1 max_num_subintervals::Int = 3000 end - $(alg)(nlsolve::N, jac_alg::J; nested = false, nested_nlsolve_kwargs::NamedTuple = (;), defect_threshold::T = 0.1, max_num_subintervals::Int = 3000) where {N, J, T} = $(alg){ - N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, + $(alg)(nlsolve::N, + jac_alg::J; + nested = false, + nested_nlsolve_kwargs::NamedTuple = (;), + defect_threshold::T = 0.1, + max_num_subintervals::Int = 3000) where {N, + J, + T} = $(alg){N, J, T}(nlsolve, jac_alg, nested, nested_nlsolve_kwargs, defect_threshold, max_num_subintervals) end end diff --git a/lib/BoundaryValueDiffEqFIRK/src/firk.jl b/lib/BoundaryValueDiffEqFIRK/src/firk.jl index 4f312cc7..aa153020 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/firk.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/firk.jl @@ -141,7 +141,8 @@ function init_nested( # Transform the functions to handle non-vector inputs bcresid_prototype = __vec(bcresid_prototype) - f, bc = if X isa AbstractVector + f, + bc = if X isa AbstractVector prob.f, prob.f.bc elseif iip vecf! = @closure (du, u, p, t) -> __vec_f!(du, u, p, t, prob.f, size(X)) @@ -149,10 +150,10 @@ function init_nested( @closure (r, u, p, t) -> __vec_bc!(r, u, p, t, prob.f.bc, resid₁_size, size(X)) else ( - @closure((r, u, p)->__vec_bc!( - r, u, p, first(prob.f.bc), resid₁_size[1], size(X))), - @closure((r, u, p)->__vec_bc!( - r, u, p, last(prob.f.bc), resid₁_size[2], size(X)))) + @closure((r, u, + p)->__vec_bc!(r, u, p, first(prob.f.bc), resid₁_size[1], size(X))), + @closure(( + r, u, p)->__vec_bc!(r, u, p, last(prob.f.bc), resid₁_size[2], size(X)))) end vecf!, vecbc! else @@ -237,7 +238,8 @@ function init_expanded( # Transform the functions to handle non-vector inputs bcresid_prototype = __vec(bcresid_prototype) - f, bc = if X isa AbstractVector + f, + bc = if X isa AbstractVector prob.f, prob.f.bc elseif iip vecf! = @closure (du, u, p, t) -> __vec_f!(du, u, p, t, prob.f, size(X)) @@ -245,10 +247,10 @@ function init_expanded( @closure (r, u, p, t) -> __vec_bc!(r, u, p, t, prob.f.bc, resid₁_size, size(X)) else ( - @closure((r, u, p)->__vec_bc!( - r, u, p, first(prob.f.bc)[1], resid₁_size[1], size(X))), - @closure ((r, u, p) -> __vec_bc!( - r, u, p, last(prob.f.bc)[2], resid₁_size[2], size(X)))) + @closure((r, u, + p)->__vec_bc!(r, u, p, first(prob.f.bc)[1], resid₁_size[1], size(X))), + @closure ((r, u, + p) -> __vec_bc!(r, u, p, last(prob.f.bc)[2], resid₁_size[2], size(X)))) end vecf!, vecbc! else @@ -307,8 +309,8 @@ function SciMLBase.solve!(cache::FIRKCacheExpand{iip, T}) where {iip, T} if adaptive while SciMLBase.successful_retcode(info) && defect_norm > abstol - sol_nlprob, info, defect_norm = __perform_firk_iteration( - cache, abstol, adaptive) + sol_nlprob, info, + defect_norm = __perform_firk_iteration(cache, abstol, adaptive) end end @@ -332,8 +334,8 @@ function SciMLBase.solve!(cache::FIRKCacheNested{iip, T}) where {iip, T} if adaptive while SciMLBase.successful_retcode(info) && defect_norm > abstol - sol_nlprob, info, defect_norm = __perform_firk_iteration( - cache, abstol, adaptive) + sol_nlprob, info, + defect_norm = __perform_firk_iteration(cache, abstol, adaptive) end end @@ -405,26 +407,33 @@ function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpan trait = __cache_trait(jac_alg) loss_bc = if iip - @closure (du, u, p) -> __firk_loss_bc!( - du, u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) + @closure (du, + u, + p) -> __firk_loss_bc!(du, u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) else - @closure (u, p) -> __firk_loss_bc( - u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) + @closure ( + u, p) -> __firk_loss_bc(u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) end loss_collocation = if iip - @closure (du, u, p) -> __firk_loss_collocation!( + @closure (du, + u, + p) -> __firk_loss_collocation!( du, u, p, cache.y, cache.mesh, cache.residual, cache, trait) else - @closure (u, p) -> __firk_loss_collocation( + @closure (u, + p) -> __firk_loss_collocation( u, p, cache.y, cache.mesh, cache.residual, cache, trait) end loss = if iip - @closure (du, u, p) -> __firk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, + @closure (du, + u, + p) -> __firk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol, trait) else - @closure (u, p) -> __firk_loss( + @closure (u, + p) -> __firk_loss( u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol, trait) end @@ -499,11 +508,14 @@ function __construct_nlproblem( end jac = if iip - @closure (J, u, p) -> __firk_mpoint_jacobian!( + @closure (J, + u, + p) -> __firk_mpoint_jacobian!( J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) else - @closure (u, p) -> __firk_mpoint_jacobian( + @closure (u, + p) -> __firk_mpoint_jacobian( jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, L, cache.p) end @@ -560,10 +572,13 @@ function __construct_nlproblem( end jac = if iip - @closure (J, u, p) -> __firk_2point_jacobian!( + @closure (J, + u, + p) -> __firk_2point_jacobian!( J, u, jac_alg.diffmode, diffcache, loss, resid, cache.p) else - @closure (u, p) -> __firk_2point_jacobian( + @closure (u, + p) -> __firk_2point_jacobian( u, jac_prototype, jac_alg.diffmode, diffcache, loss, cache.p) end @@ -635,11 +650,14 @@ function __construct_nlproblem( end jac = if iip - @closure (J, u, p) -> __firk_mpoint_jacobian!( + @closure (J, + u, + p) -> __firk_mpoint_jacobian!( J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) else - @closure (u, p) -> __firk_mpoint_jacobian( + @closure (u, + p) -> __firk_mpoint_jacobian( jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, L, cache.p) end @@ -687,10 +705,13 @@ function __construct_nlproblem( end jac = if iip - @closure (J, u, p) -> __firk_2point_jacobian!( + @closure (J, + u, + p) -> __firk_2point_jacobian!( J, u, jac_alg.diffmode, diffcache, loss, resid, cache.p) else - @closure (u, p) -> __firk_2point_jacobian( + @closure (u, + p) -> __firk_2point_jacobian( u, jac_prototype, jac_alg.diffmode, diffcache, loss, cache.p) end diff --git a/lib/BoundaryValueDiffEqFIRK/src/interpolation.jl b/lib/BoundaryValueDiffEqFIRK/src/interpolation.jl index b4eccdb3..faf52c0e 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/interpolation.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/interpolation.jl @@ -295,10 +295,10 @@ end dS_interpolate!(dz, τ, S_coeffs) end -@inline __build_interpolation(cache::FIRKCacheExpand, u::AbstractVector) = FIRKExpandInterpolation( - cache.mesh, u, cache) -@inline __build_interpolation(cache::FIRKCacheNested, u::AbstractVector) = FIRKNestedInterpolation( - cache.mesh, u, cache) +@inline __build_interpolation(cache::FIRKCacheExpand, + u::AbstractVector) = FIRKExpandInterpolation(cache.mesh, u, cache) +@inline __build_interpolation(cache::FIRKCacheNested, + u::AbstractVector) = FIRKNestedInterpolation(cache.mesh, u, cache) # Intermidiate solution for evaluating boundry conditions # basically simplified version of the interpolation for FIRK @@ -385,8 +385,8 @@ function s_constraints_interp(M, h) row_start = (i - 1) * M + 1 for k in 0:(M - 1) for j in 1:6 - A[row_start + k, j + k * 6] = j == 1.0 ? 0.0 : - (j - 1) * t[i + k * 6]^(j - 2) + A[row_start + k, + j + k * 6] = j == 1.0 ? 0.0 : (j - 1) * t[i + k * 6]^(j - 2) end end end diff --git a/lib/BoundaryValueDiffEqFIRK/test/expanded/ensemble_tests.jl b/lib/BoundaryValueDiffEqFIRK/test/expanded/ensemble_tests.jl index d9f5623f..47d4656d 100644 --- a/lib/BoundaryValueDiffEqFIRK/test/expanded/ensemble_tests.jl +++ b/lib/BoundaryValueDiffEqFIRK/test/expanded/ensemble_tests.jl @@ -33,8 +33,8 @@ end end - @testset "$(solver)" for solver in ( - LobattoIIIa2, LobattoIIIa3, LobattoIIIa4, LobattoIIIa5) + @testset "$(solver)" for solver in + (LobattoIIIa2, LobattoIIIa3, LobattoIIIa4, LobattoIIIa5) jac_algs = [BVPJacobianAlgorithm(), BVPJacobianAlgorithm( AutoSparse(AutoFiniteDiff()); bc_diffmode = AutoFiniteDiff(), diff --git a/lib/BoundaryValueDiffEqFIRK/test/expanded/firk_basic_tests.jl b/lib/BoundaryValueDiffEqFIRK/test/expanded/firk_basic_tests.jl index 05da6586..e319141d 100644 --- a/lib/BoundaryValueDiffEqFIRK/test/expanded/firk_basic_tests.jl +++ b/lib/BoundaryValueDiffEqFIRK/test/expanded/firk_basic_tests.jl @@ -67,11 +67,11 @@ odef1! = ODEFunction(f1!, analytic = (u0, p, t) -> [5 - t, -1]) odef1 = ODEFunction(f1, analytic = (u0, p, t) -> [5 - t, -1]) odef2! = ODEFunction(f2!, - analytic = (u0, p, t) -> [ - 5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) + analytic = ( + u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) odef2 = ODEFunction(f2, - analytic = (u0, p, t) -> [ - 5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) + analytic = ( + u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1)) @@ -340,8 +340,8 @@ end @test sol(0.001; idxs = 2)≈-1.312035941 atol=testTol end - @testset "Derivtive Interpolation tests for RadauIIa$stage" for stage in ( - 2, 3, 5, 7) + @testset "Derivtive Interpolation tests for RadauIIa$stage" for stage in + (2, 3, 5, 7) @time sol = solve(prob_bvp_linear, radau_solver(Val(stage)); dt = 0.001) sol_analytic = prob_bvp_linear_analytic(nothing, λ, 0.04) dsol_analytic = prob_bvp_linear_analytic_derivative(nothing, λ, 0.04) @@ -356,8 +356,8 @@ end for (id, lobatto_solver) in zip(("a", "b", "c"), (lobattoIIIa_solver, lobattoIIIb_solver, lobattoIIIc_solver)) begin - @testset "Interpolation tests for LobattoIII$(id)$stage" for stage in ( - 3, 4, 5) + @testset "Interpolation tests for LobattoIII$(id)$stage" for stage in + (3, 4, 5) @time sol = solve( prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001) @test sol(0.001)≈[0.998687464, -1.312035941] atol=testTol @@ -366,7 +366,8 @@ end @test sol(0.001; idxs = 2)≈-1.312035941 atol=testTol end - @testset "Derivative Interpolation tests for lobatto$(id)$stage" for stage in ( + @testset "Derivative Interpolation tests for lobatto$(id)$stage" for stage in + ( 3, 4, 5) @time sol = solve( prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001) diff --git a/lib/BoundaryValueDiffEqFIRK/test/expanded/nlls_tests.jl b/lib/BoundaryValueDiffEqFIRK/test/expanded/nlls_tests.jl index 7485895c..732f1f94 100644 --- a/lib/BoundaryValueDiffEqFIRK/test/expanded/nlls_tests.jl +++ b/lib/BoundaryValueDiffEqFIRK/test/expanded/nlls_tests.jl @@ -5,7 +5,8 @@ using BoundaryValueDiffEqFIRK, LinearAlgebra SOLVERS = [firk() for firk in (RadauIIa5, LobattoIIIa4, LobattoIIIb4, LobattoIIIc4)] SOLVERS_NAMES = ["$solver" - for solver in ["RadauIIa5", "LobattoIIIa4", "LobattoIIIb4", "LobattoIIIc4"]] + for solver in + ["RadauIIa5", "LobattoIIIa4", "LobattoIIIb4", "LobattoIIIc4"]] ### Overconstrained BVP ### diff --git a/lib/BoundaryValueDiffEqFIRK/test/nested/ensemble_tests.jl b/lib/BoundaryValueDiffEqFIRK/test/nested/ensemble_tests.jl index b0459673..65ff128d 100644 --- a/lib/BoundaryValueDiffEqFIRK/test/nested/ensemble_tests.jl +++ b/lib/BoundaryValueDiffEqFIRK/test/nested/ensemble_tests.jl @@ -33,8 +33,8 @@ end end - @testset "$(solver)" for solver in ( - LobattoIIIa2, LobattoIIIa3, LobattoIIIa4, LobattoIIIa5) + @testset "$(solver)" for solver in + (LobattoIIIa2, LobattoIIIa3, LobattoIIIa4, LobattoIIIa5) jac_algs = [BVPJacobianAlgorithm(), BVPJacobianAlgorithm( AutoSparse(AutoFiniteDiff()); bc_diffmode = AutoFiniteDiff(), diff --git a/lib/BoundaryValueDiffEqFIRK/test/nested/firk_basic_tests.jl b/lib/BoundaryValueDiffEqFIRK/test/nested/firk_basic_tests.jl index 045097f0..a0904b04 100644 --- a/lib/BoundaryValueDiffEqFIRK/test/nested/firk_basic_tests.jl +++ b/lib/BoundaryValueDiffEqFIRK/test/nested/firk_basic_tests.jl @@ -66,11 +66,11 @@ odef1! = ODEFunction(f1!, analytic = (u0, p, t) -> [5 - t, -1]) odef1 = ODEFunction(f1, analytic = (u0, p, t) -> [5 - t, -1]) odef2! = ODEFunction(f2!, - analytic = (u0, p, t) -> [ - 5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) + analytic = ( + u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) odef2 = ODEFunction(f2, - analytic = (u0, p, t) -> [ - 5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) + analytic = ( + u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1)) @@ -361,8 +361,8 @@ end @test sol(0.001; idxs = 2)≈-1.312035941 atol=testTol end - @testset "Derivtive Interpolation tests for RadauIIa$stage" for stage in ( - 2, 3, 5, 7) + @testset "Derivtive Interpolation tests for RadauIIa$stage" for stage in + (2, 3, 5, 7) @time sol = solve(prob_bvp_linear, radau_solver(Val(stage)); dt = 0.001) sol_analytic = prob_bvp_linear_analytic(nothing, λ, 0.04) dsol_analytic = prob_bvp_linear_analytic_derivative(nothing, λ, 0.04) @@ -377,8 +377,8 @@ end for (id, lobatto_solver) in zip(("a", "b", "c"), (lobattoIIIa_solver, lobattoIIIb_solver, lobattoIIIc_solver)) begin - @testset "Interpolation tests for LobattoIII$(id)$stage" for stage in ( - 3, 4, 5) + @testset "Interpolation tests for LobattoIII$(id)$stage" for stage in + (3, 4, 5) @time sol = solve( prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001) @test sol(0.001)≈[0.998687464, -1.312035941] atol=testTol @@ -387,7 +387,8 @@ end @test sol(0.001; idxs = 2)≈-1.312035941 atol=testTol end - @testset "Derivative Interpolation tests for lobatto$(id)$stage" for stage in ( + @testset "Derivative Interpolation tests for lobatto$(id)$stage" for stage in + ( 3, 4, 5) @time sol = solve( prob_bvp_linear, lobatto_solver(Val(stage)); dt = 0.001) diff --git a/lib/BoundaryValueDiffEqFIRK/test/nested/nlls_tests.jl b/lib/BoundaryValueDiffEqFIRK/test/nested/nlls_tests.jl index e654ebba..37fc8833 100644 --- a/lib/BoundaryValueDiffEqFIRK/test/nested/nlls_tests.jl +++ b/lib/BoundaryValueDiffEqFIRK/test/nested/nlls_tests.jl @@ -7,7 +7,8 @@ SOLVERS = [firk(; nlsolve, nested_nlsolve = true) nlsolve in (NewtonRaphson(), GaussNewton(), TrustRegion())] SOLVERS_NAMES = ["$solver with $nlsolve" - for solver in ["RadauIIa5", "LobattoIIIa4", "LobattoIIIb4", "LobattoIIIc4"], + for solver in + ["RadauIIa5", "LobattoIIIa4", "LobattoIIIb4", "LobattoIIIc4"], nlsolve in ["NewtonRaphson", "GaussNewton", "TrustRegion"]] ### Overconstrained BVP ### diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl index 4297e0e6..9d6bb6d7 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl @@ -98,6 +98,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2) end end @@ -154,6 +155,7 @@ include("sparse_jacobians.jl") @compile_workload begin @sync for prob in probs, alg in algs + Threads.@spawn solve(prob, alg; dt = 0.2, abstol = 1e-2) end end diff --git a/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl b/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl index 86f130cb..eff1d50b 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl @@ -456,12 +456,13 @@ end @views function error_estimate!( cache::MIRKCache{iip, T}, controller::SequentialErrorControl, errors, sol, nlsolve_alg, abstol) where {iip, T} - defect_norm, info = error_estimate!( + defect_norm, + info = error_estimate!( cache::MIRKCache{iip, T}, controller.defect, errors, sol, nlsolve_alg, abstol) error_norm = defect_norm if defect_norm <= abstol - global_error_norm, info = error_estimate!( - cache::MIRKCache{iip, T}, controller.global_error, + global_error_norm, + info = error_estimate!(cache::MIRKCache{iip, T}, controller.global_error, controller.global_error.method, errors, sol, nlsolve_alg, abstol) error_norm = global_error_norm return error_norm, info @@ -475,10 +476,11 @@ function error_estimate!(cache::MIRKCache{iip, T}, controller::HybridErrorContro L = length(cache.mesh) - 1 defect = errors[:, 1:L] global_error = errors[:, (L + 1):end] - defect_norm, _ = error_estimate!( + defect_norm, + _ = error_estimate!( cache::MIRKCache{iip, T}, controller.defect, defect, sol, nlsolve_alg, abstol) - global_error_norm, _ = error_estimate!( - cache, controller.global_error, controller.global_error.method, + global_error_norm, + _ = error_estimate!(cache, controller.global_error, controller.global_error.method, global_error, sol, nlsolve_alg, abstol) error_norm = controller.DE * defect_norm + controller.GE * global_error_norm @@ -572,8 +574,8 @@ Here, the ki_interp is the stages in one subinterval. if iip f(k_interp.u[i][:, r], new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) else - k_interp.u[i][:, r] .= f( - new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) + k_interp.u[i][ + :, r] .= f(new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) end end end @@ -602,8 +604,8 @@ end if iip f(k_interp.u[i][:, r], new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) else - k_interp.u[i][:, r] .= f( - new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) + k_interp.u[i][ + :, r] .= f(new_stages.u[i], p, mesh[i] + c_star[r] * mesh_dt[i]) end end end diff --git a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl index b17db952..c69e4e07 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl @@ -135,8 +135,8 @@ end return nothing end -@inline __build_interpolation(cache::MIRKCache, u::AbstractVector) = MIRKInterpolation( - cache.mesh, u, cache) +@inline __build_interpolation( + cache::MIRKCache, u::AbstractVector) = MIRKInterpolation(cache.mesh, u, cache) # Intermidiate solution for evaluating boundry conditions # basically simplified version of the interpolation for MIRK diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl index 964ab542..6398b119 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -80,7 +80,8 @@ function SciMLBase.__init( # Transform the functions to handle non-vector inputs bcresid_prototype = __vec(bcresid_prototype) - f, bc = if X isa AbstractVector + f, + bc = if X isa AbstractVector prob.f, prob.f.bc elseif iip vecf! = @closure (du, u, p, t) -> __vec_f!(du, u, p, t, prob.f, size(X)) @@ -88,10 +89,10 @@ function SciMLBase.__init( @closure (r, u, p, t) -> __vec_bc!(r, u, p, t, prob.f.bc, resid₁_size, size(X)) else ( - @closure((r, u, p)->__vec_bc!( - r, u, p, first(prob.f.bc), resid₁_size[1], size(X))), - @closure((r, u, p)->__vec_bc!( - r, u, p, last(prob.f.bc), resid₁_size[2], size(X)))) + @closure((r, u, + p)->__vec_bc!(r, u, p, first(prob.f.bc), resid₁_size[1], size(X))), + @closure(( + r, u, p)->__vec_bc!(r, u, p, last(prob.f.bc), resid₁_size[2], size(X)))) end vecf!, vecbc! else @@ -138,13 +139,13 @@ function SciMLBase.solve!(cache::MIRKCache) # We do the first iteration outside the loop to preserve type-stability of the # `original` field of the solution - sol_nlprob, info, error_norm = __perform_mirk_iteration( - cache, abstol, adaptive, controller) + sol_nlprob, info, + error_norm = __perform_mirk_iteration(cache, abstol, adaptive, controller) if adaptive while SciMLBase.successful_retcode(info) && error_norm > abstol - sol_nlprob, info, error_norm = __perform_mirk_iteration( - cache, abstol, adaptive, controller) + sol_nlprob, info, + error_norm = __perform_mirk_iteration(cache, abstol, adaptive, controller) end end @@ -172,7 +173,8 @@ function __perform_mirk_iteration( info::ReturnCode.T = sol_nlprob.retcode if info == ReturnCode.Success # Nonlinear Solve was successful - error_norm, info = error_estimate!( + error_norm, + info = error_estimate!( cache, controller, cache.errors, sol_nlprob, nlsolve_alg, abstol) end @@ -216,26 +218,33 @@ function __construct_nlproblem( trait = __cache_trait(jac_alg) loss_bc = if iip - @closure (du, u, p) -> __mirk_loss_bc!( - du, u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) + @closure (du, + u, + p) -> __mirk_loss_bc!(du, u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) else - @closure (u, p) -> __mirk_loss_bc( - u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) + @closure ( + u, p) -> __mirk_loss_bc(u, p, pt, cache.bc, cache.y, cache.mesh, cache, trait) end loss_collocation = if iip - @closure (du, u, p) -> __mirk_loss_collocation!( + @closure (du, + u, + p) -> __mirk_loss_collocation!( du, u, p, cache.y, cache.mesh, cache.residual, cache, trait) else - @closure (u, p) -> __mirk_loss_collocation( + @closure (u, + p) -> __mirk_loss_collocation( u, p, cache.y, cache.mesh, cache.residual, cache, trait) end loss = if iip - @closure (du, u, p) -> __mirk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, + @closure (du, + u, + p) -> __mirk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol, trait) else - @closure (u, p) -> __mirk_loss( + @closure (u, + p) -> __mirk_loss( u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol, trait) end @@ -408,11 +417,14 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo end jac = if iip - @closure (J, u, p) -> __mirk_mpoint_jacobian!( + @closure (J, + u, + p) -> __mirk_mpoint_jacobian!( J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) else - @closure (u, p) -> __mirk_mpoint_jacobian( + @closure (u, + p) -> __mirk_mpoint_jacobian( jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, L, cache.p) end @@ -502,10 +514,12 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collo end jac = if iip - @closure (J, u, p) -> __mirk_2point_jacobian!( - J, u, jac_alg.diffmode, diffcache, loss, resid, p) + @closure (J, + u, + p) -> __mirk_2point_jacobian!(J, u, jac_alg.diffmode, diffcache, loss, resid, p) else - @closure (u, p) -> __mirk_2point_jacobian( + @closure (u, + p) -> __mirk_2point_jacobian( u, jac_prototype, jac_alg.diffmode, diffcache, loss, p) end diff --git a/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl index 349f61da..ff9a420b 100644 --- a/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl +++ b/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl @@ -50,11 +50,11 @@ odef1! = ODEFunction(f1!, analytic = (u0, p, t) -> [5 - t, -1]) odef1 = ODEFunction(f1, analytic = (u0, p, t) -> [5 - t, -1]) odef2! = ODEFunction(f2!, - analytic = (u0, p, t) -> [ - 5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) + analytic = ( + u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) odef2 = ODEFunction(f2, - analytic = (u0, p, t) -> [ - 5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) + analytic = ( + u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)), 5 * (-cos(t) * cot(5) - sin(t))]) bcresid_prototype = (Array{Float64}(undef, 1), Array{Float64}(undef, 1)) diff --git a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl index c1c6b0dc..1ea3a157 100644 --- a/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl +++ b/lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl @@ -58,19 +58,25 @@ function SciMLBase.__init(prob::SecondOrderBVProblem, alg::AbstractMIRKN; dt = 0 end resid_size = size(bcresid_prototype) - f, bc = if X isa AbstractVector + f, + bc = if X isa AbstractVector prob.f, prob.f.bc elseif iip vecf! = @closure (ddu, du, u, p, t) -> __vec_f!(ddu, du, u, p, t, prob.f, size(X)) vecbc! = if !(prob.problem_type isa TwoPointSecondOrderBVProblem) - @closure (r, du, u, p, t) -> __vec_so_bc!( - r, du, u, p, t, prob.f.bc, resid_size, size(X)) + @closure (r, du, u, p, + t) -> __vec_so_bc!(r, du, u, p, t, prob.f.bc, resid_size, size(X)) else ( - @closure((r, du, u, p)->__vec_so_bc!( + @closure((r, + du, + u, + p)->__vec_so_bc!( r, du, u, p, first(prob.f.bc), resid_size[1], size(X))), - @closure((r, du, u, p)->__vec_so_bc!( - r, du, u, p, last(prob.f.bc), resid_size[2], size(X)))) + @closure((r, + du, + u, + p)->__vec_so_bc!(r, du, u, p, last(prob.f.bc), resid_size[2], size(X)))) end vecf!, vecbc! else @@ -125,26 +131,30 @@ function __construct_nlproblem( __restructure_sol(y₀.u[(L + 1):end], cache.in_size), cache.mesh, cache) loss_bc = if iip - @closure (du, u, p) -> __mirkn_loss_bc!( - du, u, p, pt, cache.bc, cache.y, cache.mesh, cache) + @closure (du, u, + p) -> __mirkn_loss_bc!(du, u, p, pt, cache.bc, cache.y, cache.mesh, cache) else @closure (u, p) -> __mirkn_loss_bc(u, p, pt, cache.bc, cache.y, cache.mesh, cache) end loss_collocation = if iip - @closure (du, u, p) -> __mirkn_loss_collocation!( + @closure (du, + u, + p) -> __mirkn_loss_collocation!( du, u, p, cache.y, cache.mesh, cache.residual, cache) else - @closure (u, p) -> __mirkn_loss_collocation( - u, p, cache.y, cache.mesh, cache.residual, cache) + @closure (u, + p) -> __mirkn_loss_collocation(u, p, cache.y, cache.mesh, cache.residual, cache) end loss = if iip - @closure (du, u, p) -> __mirkn_loss!( - du, u, p, cache.y, pt, cache.bc, cache.residual, + @closure (du, + u, + p) -> __mirkn_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol, eval_dsol) else - @closure (u, p) -> __mirkn_loss( + @closure (u, + p) -> __mirkn_loss( u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol, eval_dsol) end @@ -205,11 +215,14 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y, loss_bc::BC, loss_coll jac_prototype = vcat(J_bc, J_c) jac = if iip - @closure (J, u, p) -> __mirkn_mpoint_jacobian!( + @closure (J, + u, + p) -> __mirkn_mpoint_jacobian!( J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) else - @closure (u, p) -> __mirkn_mpoint_jacobian( + @closure (u, + p) -> __mirkn_mpoint_jacobian( jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, loss_bc, loss_collocation, L, cache.p) end @@ -249,10 +262,13 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y, loss_bc::BC, loss_coll end jac = if iip - @closure (J, u, p) -> __mirkn_2point_jacobian!( + @closure (J, + u, + p) -> __mirkn_2point_jacobian!( J, u, jac_alg.diffmode, diffcache, loss, resid, p) else - @closure (u, p) -> __mirkn_2point_jacobian( + @closure (u, + p) -> __mirkn_2point_jacobian( u, jac_prototype, jac_alg.diffmode, diffcache, loss, p) end diff --git a/lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl b/lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl index 3124bcc0..4a65f519 100644 --- a/lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl +++ b/lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl @@ -39,8 +39,8 @@ end function bc_b(du, u, p) return [u[1]] end -analytical_solution = (u0, p, t) -> [ - (exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))] +analytical_solution = (u0, p, + t) -> [(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))] u0 = [1.0] tspan = (0.0, 1.0) testTol = 0.2 diff --git a/lib/BoundaryValueDiffEqShooting/src/algorithms.jl b/lib/BoundaryValueDiffEqShooting/src/algorithms.jl index 80ef8f8d..afcaa9d6 100644 --- a/lib/BoundaryValueDiffEqShooting/src/algorithms.jl +++ b/lib/BoundaryValueDiffEqShooting/src/algorithms.jl @@ -117,7 +117,7 @@ function MultipleShooting(; nshoots::Int, nshoots, grid_coarsening) end @inline MultipleShooting(nshoots::Int; kwargs...) = MultipleShooting(; nshoots, kwargs...) -@inline MultipleShooting(nshoots::Int, ode_alg; kwargs...) = MultipleShooting(; - nshoots, ode_alg, kwargs...) -@inline MultipleShooting(nshoots::Int, ode_alg, nlsolve; kwargs...) = MultipleShooting(; - nshoots, ode_alg, nlsolve, kwargs...) +@inline MultipleShooting( + nshoots::Int, ode_alg; kwargs...) = MultipleShooting(; nshoots, ode_alg, kwargs...) +@inline MultipleShooting(nshoots::Int, ode_alg, nlsolve; + kwargs...) = MultipleShooting(; nshoots, ode_alg, nlsolve, kwargs...) diff --git a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl index 4e27fdf0..a62f636e 100644 --- a/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/multiple_shooting.jl @@ -36,7 +36,12 @@ function SciMLBase.__solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwa internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true) - solve_internal_odes! = @closure (resid_nodes, us, p, cur_nshoot, nodes, odecache) -> __multiple_shooting_solve_internal_odes!( + solve_internal_odes! = @closure (resid_nodes, + us, + p, + cur_nshoot, + nodes, + odecache) -> __multiple_shooting_solve_internal_odes!( resid_nodes, us, cur_nshoot, odecache, nodes, u0_size, N, ensemblealg, tspan) # This gets all the nshoots except the final SingleShooting case @@ -96,7 +101,9 @@ function __solve_nlproblem!( resid_prototype = vcat( bcresid_prototype[1], similar(u_at_nodes, cur_nshoot * N), bcresid_prototype[2]) - loss_fn = @closure (du, u, p) -> __multiple_shooting_2point_loss!( + loss_fn = @closure (du, + u, + p) -> __multiple_shooting_2point_loss!( du, u, p, cur_nshoot, nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb, ode_cache_loss_fn) @@ -118,12 +125,15 @@ function __solve_nlproblem!( ensemblealg, prob, jac_cache, diffmode, alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...) - loss_fnₚ = @closure (du, u) -> __multiple_shooting_2point_loss!( + loss_fnₚ = @closure (du, + u) -> __multiple_shooting_2point_loss!( du, u, prob.p, cur_nshoot, nodes, prob, solve_internal_odes!, resida_len, residb_len, N, bca, bcb, ode_cache_jac_fn) jac_prototype = DI.jacobian(loss_fnₚ, resid_prototype, jac_cache, diffmode, u_at_nodes) - jac_fn = @closure (J, u, p) -> __multiple_shooting_2point_jacobian!( + jac_fn = @closure (J, + u, + p) -> __multiple_shooting_2point_jacobian!( J, u, p, jac_cache, loss_fnₚ, resid_prototype_cached, alg) loss_function! = NonlinearFunction{true}( @@ -148,7 +158,9 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_ resid_nodes = __maybe_allocate_diffcache( __resid_nodes, pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode) - loss_fn = @closure (du, u, p) -> __multiple_shooting_mpoint_loss!( + loss_fn = @closure (du, + u, + p) -> __multiple_shooting_mpoint_loss!( du, u, p, cur_nshoot, nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0, ode_cache_loss_fn) @@ -182,9 +194,10 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_ alg.ode_alg, cur_nshoot, u0; internal_ode_kwargs...) # Define the functions now - ode_fn = @closure (du, u) -> solve_internal_odes!( - du, u, prob.p, cur_nshoot, nodes, ode_cache_ode_jac_fn) - bc_fn = @closure (du, u) -> __multiple_shooting_mpoint_loss_bc!( + ode_fn = @closure (du, + u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes, ode_cache_ode_jac_fn) + bc_fn = @closure (du, + u) -> __multiple_shooting_mpoint_loss_bc!( du, u, prob.p, cur_nshoot, nodes, prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0, ode_cache_bc_jac_fn) @@ -194,7 +207,9 @@ function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_ bc_fn, similar(bcresid_prototype), bc_jac_cache, bc_diffmode, u_at_nodes) jac_prototype = vcat(jac_prototype_ode, jac_prototype_bc) - jac_fn = @closure (J, u, p) -> __multiple_shooting_mpoint_jacobian!( + jac_fn = @closure (J, + u, + p) -> __multiple_shooting_mpoint_jacobian!( J, u, p, similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache, ode_fn, bc_fn, nonbc_diffmode, bc_diffmode, N, M, __cache_trait(alg.jac_alg)) diff --git a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl index 37bc93b2..311f16a2 100644 --- a/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl +++ b/lib/BoundaryValueDiffEqShooting/src/single_shooting.jl @@ -27,10 +27,13 @@ function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (; ode_cache_loss_fn = SciMLBase.__init(internal_prob, alg.ode_alg; ode_kwargs...) loss_fn = if iip - @closure (du, u, p) -> __single_shooting_loss!( + @closure (du, + u, + p) -> __single_shooting_loss!( du, u, p, ode_cache_loss_fn, bc, u0_size, prob.problem_type, resid_size) else - @closure (u, p) -> __single_shooting_loss( + @closure (u, + p) -> __single_shooting_loss( u, p, ode_cache_loss_fn, bc, u0_size, prob.problem_type) end @@ -47,7 +50,8 @@ function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (; diffmode, u0, alg.ode_alg; ode_kwargs...) loss_fnₚ = if iip - @closure (du, u) -> __single_shooting_loss!( + @closure (du, + u) -> __single_shooting_loss!( du, u, prob.p, ode_cache_jac_fn, bc, u0_size, prob.problem_type, resid_size) else @closure (u) -> __single_shooting_loss( @@ -61,10 +65,11 @@ function SciMLBase.__solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (; end jac_fn = if iip - @closure (J, u, p) -> __single_shooting_jacobian!( - J, u, jac_cache, diffmode, loss_fnₚ, y_) + @closure ( + J, u, p) -> __single_shooting_jacobian!(J, u, jac_cache, diffmode, loss_fnₚ, y_) else - @closure (u, p) -> __single_shooting_jacobian( + @closure (u, + p) -> __single_shooting_jacobian( jac_prototype, u, jac_cache, diffmode, loss_fnₚ) end diff --git a/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl b/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl index be33c7b9..95b141d8 100644 --- a/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl +++ b/lib/BoundaryValueDiffEqShooting/test/basic_problems_tests.jl @@ -357,8 +357,8 @@ end alg_default = MultipleShooting( 10, AutoVern7(Rodas4P()); nlsolve = NewtonRaphson(), grid_coarsening = true) - for (prob, alg) in Iterators.product( - (prob_iip, prob_tp_iip), (alg_sp, alg_dense, alg_default)) + for (prob, alg) in + Iterators.product((prob_iip, prob_tp_iip), (alg_sp, alg_dense, alg_default)) sol = solve(prob, alg; abstol = 1e-6, reltol = 1e-6, maxiters = 1000, odesolve_kwargs = (; abstol = 1e-8, reltol = 1e-5)) diff --git a/lib/BoundaryValueDiffEqShooting/test/orbital_tests.jl b/lib/BoundaryValueDiffEqShooting/test/orbital_tests.jl index a806834f..d9f3b054 100644 --- a/lib/BoundaryValueDiffEqShooting/test/orbital_tests.jl +++ b/lib/BoundaryValueDiffEqShooting/test/orbital_tests.jl @@ -51,8 +51,8 @@ cur_bc_2point_b! = (resid, sol, p) -> bc!_generator_2p_b(resid, sol, init_val) bvp = BVProblem(orbital!, cur_bc!, y0, tspan; nlls = Val(false)) - for autodiff in ( - AutoForwardDiff(; chunksize = 6), AutoFiniteDiff(; fdtype = Val(:central)), + for autodiff in + (AutoForwardDiff(; chunksize = 6), AutoFiniteDiff(; fdtype = Val(:central)), AutoSparse(AutoForwardDiff(; chunksize = 6)), AutoFiniteDiff(; fdtype = Val(:forward)), AutoSparse(AutoFiniteDiff())) nlsolve = TrustRegion(; autodiff) diff --git a/test/misc/adaptivity_tests.jl b/test/misc/adaptivity_tests.jl index 33a0e11b..0bede435 100644 --- a/test/misc/adaptivity_tests.jl +++ b/test/misc/adaptivity_tests.jl @@ -19,20 +19,20 @@ end for stage in (2, 3, 4, 5) s = Symbol("LobattoIIIa$(stage)") - @eval lobattoIIIa_solver(::Val{$stage}, args...; kwargs...) = $(s)( - args...; kwargs...) + @eval lobattoIIIa_solver( + ::Val{$stage}, args...; kwargs...) = $(s)(args...; kwargs...) end for stage in (3, 4, 5) s = Symbol("LobattoIIIb$(stage)") - @eval lobattoIIIb_solver(::Val{$stage}, args...; kwargs...) = $(s)( - args...; kwargs...) + @eval lobattoIIIb_solver( + ::Val{$stage}, args...; kwargs...) = $(s)(args...; kwargs...) end for stage in (3, 4, 5) s = Symbol("LobattoIIIc$(stage)") - @eval lobattoIIIc_solver(::Val{$stage}, args...; kwargs...) = $(s)( - args...; kwargs...) + @eval lobattoIIIc_solver( + ::Val{$stage}, args...; kwargs...) = $(s)(args...; kwargs...) end for stage in (2, 3, 5, 7)