From 9df099e1b9f9bf579e0bada1284f4edbcaf70438 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 28 Jan 2025 16:14:42 -0500 Subject: [PATCH] prepare for switching to Linsolve Interface --- .../src/newton.jl | 11 +- lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl | 12 +- .../src/generic_rosenbrock.jl | 14 +-- .../src/rosenbrock_caches.jl | 45 ++------ .../src/rosenbrock_perform_step.jl | 103 ++++-------------- 5 files changed, 39 insertions(+), 146 deletions(-) diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index 994ecea696..99e95035ea 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -220,14 +220,9 @@ end reltol = eps(eltype(dz)) end - if is_always_new(nlsolver) || (iter == 1 && new_W) - linres = dolinsolve(integrator, linsolve; A = W, b = _vec(b), linu = _vec(dz), - reltol = reltol) - else - linres = dolinsolve( - integrator, linsolve; A = nothing, b = _vec(b), linu = _vec(dz), - reltol = reltol) - end + make_new_W = is_always_new(nlsolver) || (iter == 1 && new_W) + linres = dolinsolve(integrator, linsolve; A = make_new_W ? W : nothing, b = _vec(b), + linu = _vec(dz), reltol) if !SciMLBase.successful_retcode(linres.retcode) && linres.retcode != SciMLBase.ReturnCode.Default return convert(eltype(atmp,),Inf) diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 3bed6a453a..bf11eb0d49 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -191,14 +191,10 @@ function build_nlsolver( end jac_config = build_jac_config(alg, nf, uf, du1, uprev, u, ztmp, dz) end - linprob = LinearProblem(W, _vec(k); u0 = _vec(dz)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., - weight, dz) - linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, - assumptions = LinearSolve.OperatorAssumptions(true)) + linprob = LinearProblem(W, _vec(k), (isdae ? du1 : nothing,u,p,t); u0 = _vec(dz)) + linsolve = init(linprob, + wrapprecs(alg.linsolve, W, weight), + (isdae ? du1 : nothing,u,p,t); alias_A = true, alias_b = true) tType = typeof(t) invγdt = inv(oneunit(t) * one(uTolType)) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl b/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl index 1b9ff7a936..26d3c5faca 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl @@ -247,10 +247,8 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab tf = TimeGradientWrapper(f,uprev,p) uf = UJacobianWrapper(f,t,p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W,_vec(linsolve_tmp); u0=_vec(tmp)) - linsolve = init(linprob,alg.linsolve,alias_A=true,alias_b=true, - Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))), - Pr = Diagonal(_vec(weight))) + linprob = LinearProblem(W,_vec(linsolve_tmp), (nothing, u, p, t); u0=_vec(tmp)) + linsolve = init(linprob,alg.linsolve,alias_A=true,alias_b=true) grad_config = build_grad_config(alg,f,tf,du1,t) jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2) $cachename($(valsyms...)) @@ -1036,7 +1034,7 @@ references = """ """, "Rodas3", references = """ -- Sandu, Verwer, Van Loon, Carmichael, Potra, Dabdub, Seinfeld, Benchmarking stiff ode solvers for atmospheric chemistry problems-I. +- Sandu, Verwer, Van Loon, Carmichael, Potra, Dabdub, Seinfeld, Benchmarking stiff ode solvers for atmospheric chemistry problems-I. implicit vs explicit, Atmospheric Environment, 31(19), 3151-3166, 1997. """, with_step_limiter=true) Rodas3 @@ -1096,9 +1094,9 @@ lower if not corrected). """, "Rodas4P", references = """ -- Steinebach, G., Rentrop, P., An adaptive method of lines approach for modelling flow and transport in rivers. +- Steinebach, G., Rentrop, P., An adaptive method of lines approach for modelling flow and transport in rivers. Adaptive method of lines , Wouver, A. Vande, Sauces, Ph., Schiesser, W.E. (ed.),S. 181-205,Chapman & Hall/CRC, 2001, -- Steinebach, G., Oder-reduction of ROW-methods for DAEs and method of lines applications. +- Steinebach, G., Oder-reduction of ROW-methods for DAEs and method of lines applications. Preprint-Nr. 1741, FB Mathematik, TH Darmstadt, 1995. """, with_step_limiter=true) Rodas4P @@ -1111,7 +1109,7 @@ of Roadas4P and in case of inexact Jacobians a second order W method. """, "Rodas4P2", references = """ -- Steinebach G., Improvement of Rosenbrock-Wanner Method RODASP, In: Reis T., Grundel S., Schöps S. (eds) +- Steinebach G., Improvement of Rosenbrock-Wanner Method RODASP, In: Reis T., Grundel S., Schöps S. (eds) Progress in Differential-Algebraic Equations II. Differential-Algebraic Equations Forum. Springer, Cham., 165-184, 2020. """, with_step_limiter=true) Rodas4P2 diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl index 33fc5dcd2b..90de7e143a 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl @@ -150,12 +150,8 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits}, uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing,u,p,t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) @@ -195,13 +191,8 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing,u,p,t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) @@ -344,12 +335,8 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) @@ -430,12 +417,8 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) @@ -623,12 +606,8 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) @@ -667,12 +646,8 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits}, tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2) @@ -771,14 +746,8 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5 tf = TimeGradientWrapper(f, uprev, p) uf = UJacobianWrapper(f, t, p) linsolve_tmp = zero(rate_prototype) - linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp)) - - Pl, Pr = wrapprecs( - alg.precs(W, nothing, u, p, t, nothing, nothing, nothing, - nothing)..., weight, tmp) - + linprob = LinearProblem(W, _vec(linsolve_tmp), (nothing, u, p, t); u0 = _vec(tmp)) linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true, - Pl = Pl, Pr = Pr, assumptions = LinearSolve.OperatorAssumptions(true)) grad_config = build_grad_config(alg, f, tf, du1, t) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl index 1414a93f92..20159c4f61 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl @@ -49,21 +49,11 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck₁ = _vec(k₁) - @.. veck₁ = vecu * neginvdtγ + @.. veck₁ = linres.u * neginvdtγ integrator.stats.nsolve += 1 @.. u = uprev + dto2 * k₁ @@ -80,10 +70,9 @@ end @.. linsolve_tmp = f₁ - tmp linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck₂ = _vec(k₂) - @.. veck₂ = vecu * neginvdtγ + veck₁ + @.. veck₂ = linres.u * neginvdtγ + veck₁ integrator.stats.nsolve += 1 @.. u = uprev + dt * k₂ @@ -105,9 +94,8 @@ end end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck3 = _vec(k₃) - @.. veck3 = vecu * neginvdtγ + @.. veck3 = linres.u * neginvdtγ integrator.stats.nsolve += 1 @@ -161,21 +149,11 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtγ)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck₁ = _vec(k₁) - @.. veck₁ = vecu * neginvdtγ + @.. veck₁ = linres.u * neginvdtγ integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + dto2 * k₁ @@ -192,10 +170,9 @@ end @.. broadcast=false linsolve_tmp=f₁ - tmp linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck₂ = _vec(k₂) - @.. veck₂ = vecu * neginvdtγ + veck₁ + @.. veck₂ = linres.u * neginvdtγ + veck₁ integrator.stats.nsolve += 1 @.. tmp = uprev + dt * k₂ @@ -213,10 +190,9 @@ end end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck3 = _vec(k₃) - @.. veck3 = vecu * neginvdtγ + @.. veck3 = linres.u * neginvdtγ integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃) @@ -521,21 +497,11 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck1 = _vec(k1) - @.. broadcast=false veck1=-vecu + @.. broadcast=false veck1=-linres.u integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + a21 * k1 @@ -552,10 +518,9 @@ end end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu + @.. broadcast=false veck2=-linres.u integrator.stats.nsolve += 1 @@ -573,10 +538,9 @@ end end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) - vecu = _vec(linres.u) veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu + @.. broadcast=false veck3=-linres.u integrator.stats.nsolve += 1 @@ -716,21 +680,10 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end - - vecu = _vec(linres.u) + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) veck1 = _vec(k1) - @.. broadcast=false veck1=-vecu + @.. broadcast=false veck1=-linres.u integrator.stats.nsolve += 1 #= @@ -751,7 +704,7 @@ end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) veck2 = _vec(k2) - @.. broadcast=false veck2=-vecu + @.. broadcast=false veck2=-linres.u integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + a31 * k1 + a32 * k2 @@ -769,7 +722,7 @@ end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) veck3 = _vec(k3) - @.. broadcast=false veck3=-vecu + @.. broadcast=false veck3=-linres.u integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + a41 * k1 + a42 * k2 + a43 * k3 stage_limiter!(u, integrator, p, t + dt) @@ -787,7 +740,7 @@ end linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp)) veck4 = _vec(k4) - @.. broadcast=false veck4=-vecu + @.. broadcast=false veck4=-linres.u integrator.stats.nsolve += 1 @.. broadcast=false u=uprev + b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 @@ -1024,16 +977,7 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) @.. broadcast=false $(_vec(k1))=-linres.u @@ -1339,16 +1283,7 @@ end integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t) - if repeat_step - linres = dolinsolve( - integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - else - linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp), - du = cache.fsalfirst, u = u, p = p, t = t, weight = weight, - solverdata = (; gamma = dtgamma)) - end + linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp)) @.. $(_vec(ks[1])) = -linres.u integrator.stats.nsolve += 1