|
207 | 207 | end
|
208 | 208 | return nothing
|
209 | 209 | end
|
| 210 | + |
| 211 | +function initialize!(integrator, cache::ImplicitTaylor2ConstantCache) |
| 212 | + integrator.kshortsize = 2 |
| 213 | + integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) |
| 214 | + integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal |
| 215 | + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) |
| 216 | + |
| 217 | + # Avoid undefined entries if k is an array of arrays |
| 218 | + integrator.fsallast = zero(integrator.fsalfirst) |
| 219 | + integrator.k[1] = integrator.fsalfirst |
| 220 | + integrator.k[2] = integrator.fsallast |
| 221 | +end |
| 222 | + |
| 223 | +@muladd function perform_step!(integrator, cache::ImplicitTaylor2ConstantCache, |
| 224 | + repeat_step = false) |
| 225 | + @unpack t, dt, uprev, u, f, p = integrator |
| 226 | + nlsolver = cache.nlsolver |
| 227 | + alg = unwrap_alg(integrator, true) |
| 228 | + markfirststage!(nlsolver) |
| 229 | + |
| 230 | + # initial guess |
| 231 | + if alg.extrapolant == :linear |
| 232 | + nlsolver.z = dt * integrator.fsalfirst |
| 233 | + else # :constant |
| 234 | + nlsolver.z = zero(u) |
| 235 | + end |
| 236 | + |
| 237 | + nlsolver.tmp = uprev |
| 238 | + nlsolver.γ = 1 |
| 239 | + z = nlsolve!(nlsolver, integrator, cache, repeat_step) |
| 240 | + nlsolvefail(nlsolver) && return |
| 241 | + u = nlsolver.tmp + z |
| 242 | + |
| 243 | + if integrator.opts.adaptive && integrator.success_iter > 0 |
| 244 | + # local truncation error (LTE) bound by dt^2/2*max|y''(t)| |
| 245 | + # use 2nd divided differences (DD) a la SPICE and Shampine |
| 246 | + |
| 247 | + # TODO: check numerical stability |
| 248 | + uprev2 = integrator.uprev2 |
| 249 | + tprev = integrator.tprev |
| 250 | + |
| 251 | + dt1 = dt * (t + dt - tprev) |
| 252 | + dt2 = (t - tprev) * (t + dt - tprev) |
| 253 | + c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) |
| 254 | + r = c * dt^2 # by mean value theorem 2nd DD equals y''(s)/2 for some s |
| 255 | + |
| 256 | + tmp = r * |
| 257 | + integrator.opts.internalnorm.((u - uprev) / dt1 - (uprev - uprev2) / dt2, t) |
| 258 | + atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, |
| 259 | + integrator.opts.reltol, integrator.opts.internalnorm, t) |
| 260 | + integrator.EEst = integrator.opts.internalnorm(atmp, t) |
| 261 | + else |
| 262 | + integrator.EEst = 1 |
| 263 | + end |
| 264 | + |
| 265 | + integrator.fsallast = f(u, p, t + dt) |
| 266 | + |
| 267 | + if integrator.opts.adaptive && integrator.differential_vars !== nothing |
| 268 | + atmp = @. ifelse(!integrator.differential_vars, integrator.fsallast, false) ./ |
| 269 | + integrator.opts.abstol |
| 270 | + integrator.EEst += integrator.opts.internalnorm(atmp, t) |
| 271 | + end |
| 272 | + |
| 273 | + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) |
| 274 | + integrator.k[1] = integrator.fsalfirst |
| 275 | + integrator.k[2] = integrator.fsallast |
| 276 | + integrator.u = u |
| 277 | +end |
0 commit comments