diff --git a/src/callbacks.jl b/src/callbacks.jl index f512de88f..773322f83 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -62,7 +62,7 @@ has_continuous_callback(cb::VectorContinuousCallback) = true has_continuous_callback(cb::CallbackSet) = !isempty(cb.continuous_callbacks) has_continuous_callback(cb::Nothing) = false -isforward(integrator::DEIntegrator) = isone(integrator.tdir) +rightfloat(t, tdir) = isone(tdir) ? nextfloat(t) : prevfloat(t) # Callback handling @@ -178,232 +178,197 @@ end return ex end -@inline function determine_event_occurrence( - integrator, callback::VectorContinuousCallback, - counter - ) - event_occurred = false +@inline function find_callback_time( + integrator, callback::VectorContinuousCallback, + callback_idx +) if callback.interp_points != 0 addsteps!(integrator) end - ts = range(integrator.tprev, stop = integrator.t, length = callback.interp_points) - - #= - # Faster but can be inaccurate - if callback.interp_points > 1 - dt = (integrator.t - integrator.tprev) / (callback.interp_points-1) - else - dt = integrator.dt - end - ts = integrator.tprev:dt:integrator.t - =# - - interp_index = 0 - # Check if the event occurred - previous_condition = @views(integrator.callback_cache.previous_condition[1:(callback.len)]) - - if callback.idxs === nothing - callback.condition( - previous_condition, integrator.uprev, integrator.tprev, - integrator - ) - else - callback.condition( - previous_condition, integrator.uprev[callback.idxs], - integrator.tprev, integrator - ) - end - integrator.sol.stats.ncondition += 1 - - ivec = integrator.vector_event_last_time - prev_sign = @view(integrator.callback_cache.prev_sign[1:(callback.len)]) - next_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) - - if integrator.event_last_time == counter && - minimum( - ODE_DEFAULT_NORM( - ArrayInterface.allowed_getindex( - previous_condition, - ivec - ), integrator.t - ) - ) <= - 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) - - # If there was a previous event, utilize the derivative at the start to - # chose the previous sign. If the derivative is positive at tprev, then - # we treat `prev_sign` as negative, and if the derivative is negative then we - # treat `prev_sign` as positive, regardless of the positivity/negativity - # of the true value due to it being =0 sans floating point issues. + # Compute previous sign + bottom_sign = @view(integrator.callback_cache.prev_sign[1:(callback.len)]) + bottom_t = integrator.tprev + prev_condition = get_condition(integrator, callback, integrator.tprev) + @. bottom_sign = sign(prev_condition) - # Only due this if the discontinuity did not move it far away from an event - # Since near even we use direction instead of location to reset + if integrator.event_last_time == callback_idx + nudged_idx = integrator.vector_event_last_time + # If there was a previous event, nudge tprev on the right + # side of the root (if necessary) to avoid repeat detection if callback.interp_points == 0 addsteps!(integrator) end # Evaluate condition slightly in future - abst = integrator.tprev + integrator.dt * callback.repeat_nudge - tmp_condition = get_condition(integrator, callback, abst) - @. prev_sign = sign(previous_condition) - prev_sign[ivec] = sign(tmp_condition[ivec]) + nudged_t = nudge_tprev(integrator, callback, prev_condition[nudged_idx]) + tmp_condition = get_condition(integrator, callback, nudged_t) + + bottom_sign[nudged_idx] = sign(tmp_condition[nudged_idx]) else - @. prev_sign = sign(previous_condition) + nudged_idx = -1 + nudged_t = bottom_t end - prev_sign_index = 1 - abst = integrator.t - next_condition = get_condition(integrator, callback, abst) - @. next_sign = sign(next_condition) - - event_idx = findall_events!( - next_sign, callback.affect!, callback.affect_neg!, - prev_sign - ) - if sum(event_idx) != 0 - event_occurred = true - interp_index = callback.interp_points - end + # Check if an event occured + event_occurred, event_idx, top_t, top_sign = + check_event_occurence(integrator, callback, bottom_sign) - if callback.interp_points != 0 && !isdiscrete(integrator.alg) && - sum(event_idx) != length(event_idx) # Use the interpolants for safety checking - fallback = true - for i in 2:length(ts) - abst = ts[i] - copyto!(next_sign, get_condition(integrator, callback, abst)) - _event_idx = findall_events!( - next_sign, callback.affect!, callback.affect_neg!, - prev_sign - ) - if sum(_event_idx) != 0 - event_occurred = true - event_idx = _event_idx - interp_index = i - fallback = false - break - else - prev_sign_index = i + # Find callback time if occurence + if !event_occurred + callback_t = integrator.t + min_event_idx = 1 + elseif isdiscrete(integrator.alg) || callback.rootfind == SciMLBase.NoRootFind + callback_t = top_t + min_event_idx = findfirst(isequal(1), event_idx) + else + callback_t = rightfloat(top_t, integrator.tdir) + min_event_idx = -1 + for idx in 1:length(event_idx) + if ArrayInterface.allowed_getindex(event_idx, idx) != 0 + function zero_func(abst, p=nothing) + return ArrayInterface.allowed_getindex( + get_condition( + integrator, + callback, + abst + ), idx + ) + end + if iszero(ArrayInterface.allowed_getindex(top_sign, idx)) + cbi_t = top_t + else + if idx == nudged_idx + cbi_t = find_root(zero_func, (nudged_t, top_t), callback.rootfind) + else + cbi_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) + end + end + if integrator.tdir * cbi_t < integrator.tdir * callback_t + min_event_idx = idx + callback_t = cbi_t + integrator.last_event_error = zero_func(cbi_t) + end end end - if fallback - # If you get here, then you need to reset the event_idx to the - # non-interpolated version - - abst = integrator.t - next_condition = get_condition(integrator, callback, abst) - @. next_sign = sign(next_condition) - event_idx = findall_events!( - next_sign, callback.affect!, callback.affect_neg!, - prev_sign - ) - interp_index = callback.interp_points + if min_event_idx < 0 + error("Callback handling failed. Please file an issue with code to reproduce.") end end - return event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx + return callback_t, ArrayInterface.allowed_getindex(bottom_sign, min_event_idx), + event_occurred::Bool, min_event_idx::Int end -@inline function determine_event_occurrence( - integrator, callback::ContinuousCallback, - counter - ) - event_occurred = false +@inline function find_callback_time( + integrator, callback::ContinuousCallback, + callback_idx +) if callback.interp_points != 0 addsteps!(integrator) end - ts = range(integrator.tprev, stop = integrator.t, length = callback.interp_points) + # Compute previous sign + bottom_t = integrator.tprev + bottom_condition = get_condition(integrator, callback, bottom_t) + if integrator.event_last_time == callback_idx + # If there was a previous event, nudge tprev on the right + # side of the root (if necessary) to avoid repeat detection - #= - # Faster but can be inaccurate - if callback.interp_points > 1 - dt = (integrator.t - integrator.tprev) / (callback.interp_points-1) - else - dt = integrator.dt + if callback.interp_points == 0 + addsteps!(integrator) + end + + bottom_t = nudge_tprev(integrator, callback, bottom_condition) + bottom_condition = get_condition(integrator, callback, bottom_t) end - ts = integrator.tprev:dt:integrator.t - =# + bottom_sign = sign(bottom_condition) - interp_index = 0 + # Check if an event occured + event_occurred, event_idx, top_t, top_sign = + check_event_occurence(integrator, callback, bottom_sign) - # Check if the event occurred - if callback.idxs === nothing - previous_condition = callback.condition( - integrator.uprev, integrator.tprev, - integrator - ) + if !event_occurred + callback_t = integrator.t + elseif isdiscrete(integrator.alg) || callback.rootfind == SciMLBase.NoRootFind || iszero(top_sign) + callback_t = top_t else - @views previous_condition = callback.condition( - integrator.uprev[callback.idxs], - integrator.tprev, integrator - ) + # Find callback time + zero_func(abst, p=nothing) = get_condition(integrator, callback, abst) + callback_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) + integrator.last_event_error = zero_func(callback_t) end - integrator.sol.stats.ncondition += 1 - - prev_sign = 0.0 - next_sign = 0.0 - - if integrator.event_last_time == counter && - minimum(ODE_DEFAULT_NORM(previous_condition, integrator.t)) <= - 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) - # If there was a previous event, utilize the derivative at the start to - # chose the previous sign. If the derivative is positive at tprev, then - # we treat `prev_sign` as negative, and if the derivative is negative then we - # treat `prev_sign` as positive, regardless of the positivity/negativity - # of the true value due to it being =0 sans floating point issues. - - # Only due this if the discontinuity did not move it far away from an event - # Since near even we use direction instead of location to reset - - if callback.interp_points == 0 - addsteps!(integrator) - end + return callback_t, bottom_sign, event_occurred, event_idx +end - # Evaluate condition slightly in future - abst = integrator.tprev + integrator.dt * callback.repeat_nudge - tmp_condition = get_condition(integrator, callback, abst) - prev_sign = sign(tmp_condition) +""" +Return a nudged (if neccessary) value of `integrator.tprev` to avoid repeat event detection +- `integrator` +- `callback`: Last occuring callback +- `condition_tprev`: Condition of last occuring callback evaluated at `integrator.tprev` +""" +function nudge_tprev(integrator, callback, condition_tprev) + # Assume the previous event might affect the condition/root + if abs(condition_tprev - integrator.last_event_error) <= callback.abstol + # We are still close to the root + right_t = integrator.tprev + integrator.dt * callback.repeat_nudge else - prev_sign = sign(previous_condition) + # We are far away from the root, keep the current sign + right_t = integrator.tprev end +end - prev_sign_index = 1 - abst = integrator.t - next_condition = get_condition(integrator, callback, abst) - next_sign = sign(next_condition) +""" +Determine if an event occured in the integration time step +""" +function check_event_occurence(integrator, callback, bottom_sign) + top_t = integrator.t + event_occurred, event_idx, top_sign = + check_event_occurence_upto(integrator, callback, bottom_sign, top_t) - if ( - (prev_sign < 0 && callback.affect! !== nothing) || - (prev_sign > 0 && callback.affect_neg! !== nothing) - ) && prev_sign * next_sign <= 0 - event_occurred = true - interp_index = callback.interp_points - elseif callback.interp_points != 0 && !isdiscrete(integrator.alg) # Use the interpolants for safety checking + if callback.interp_points != 0 && !isdiscrete(integrator.alg) && + any(iszero, event_idx) + # Use the interpolants for safety checking + ts = range(integrator.tprev, stop=integrator.t, length=callback.interp_points) for i in 2:length(ts) - abst = ts[i] - new_sign = get_condition(integrator, callback, abst) - if ( - (prev_sign < 0 && callback.affect! !== nothing) || - (prev_sign > 0 && callback.affect_neg! !== nothing) - ) && - prev_sign * new_sign < 0 - event_occurred = true - interp_index = i + top_t = ts[i] + event_occurred, event_idx, top_sign = + check_event_occurence_upto(integrator, callback, bottom_sign, top_t) + if event_occurred break - else - prev_sign_index = i end end end - event_idx = 1 - return event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx + return event_occurred, event_idx, top_t, top_sign +end + +""" +Determine if an event occured before `top_t`` +""" +function check_event_occurence_upto(integrator, callback::ContinuousCallback, bottom_sign, top_t) + top_sign = sign(get_condition(integrator, callback, top_t)) + event_occurred = is_event_occurence(bottom_sign, top_sign, callback.affect!, callback.affect_neg!) + event_idx = event_occurred ? 1.0 : 0.0 + return event_occurred, event_idx, top_sign end +function check_event_occurence_upto(integrator, callback::VectorContinuousCallback, bottom_sign, top_t) + event_idx = top_condition = @views(integrator.callback_cache.next_condition[1:(callback.len)]) + top_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) + copyto!(top_condition, get_condition(integrator, callback, top_t)) + @. top_sign = sign(top_condition) + + # Determine event occurence + event_occurred = findall_events!( + top_condition, callback.affect!, callback.affect_neg!, + bottom_sign + ) + return event_occurred, event_idx, top_sign +end """ Find either exact or floating point precision root of `f`. @@ -430,7 +395,8 @@ end findall_events!(next_sign,affect!,affect_neg!,prev_sign) Modifies `next_sign` to be an array of booleans for if there is a sign change -in the interval between prev_sign and next_sign +in the interval between prev_sign and next_sign. +Return `true` if any event occured. """ function findall_events!( next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2, @@ -443,7 +409,7 @@ function findall_events!( ) && prev_sign[i] * next_sign[i] <= 0 end - return next_sign + return any(isone, next_sign) end function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) where {F1, F2} @@ -451,151 +417,17 @@ function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) whe hasaffectneg::Bool = affect_neg! !== nothing f = (n, p) -> ((p < 0 && hasaffect) || (p > 0 && hasaffectneg)) && p * n <= 0 A = map!(f, next_sign, next_sign, prev_sign) - return next_sign + return any(isone, next_sign) end -function find_callback_time(integrator, callback::ContinuousCallback, counter) - event_occurred, interp_index, ts, prev_sign, - prev_sign_index, event_idx = determine_event_occurrence( - integrator, - callback, - counter - ) - if event_occurred - if callback.condition === nothing - cb_t = integrator.t - else - if callback.interp_points != 0 - top_t = ts[interp_index] # Top at the smallest - bottom_t = ts[prev_sign_index] - else - top_t = integrator.t - bottom_t = integrator.tprev - end - if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg) - zero_func(abst, p = nothing) = get_condition(integrator, callback, abst) - if zero_func(top_t) == 0 - cb_t = top_t - else - if integrator.event_last_time == counter && - abs(zero_func(bottom_t)) <= 100abs(integrator.last_event_error) && - prev_sign_index == 1 - - # Determined that there is an event by derivative - # But floating point error may make the end point negative - - bottom_t += integrator.dt * callback.repeat_nudge - sign_top = sign(zero_func(top_t)) - sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) && - error("Double callback crossing floating point reducer errored. Report this issue.") - end - cb_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) - integrator.last_event_error = DiffEqBase.value( - ODE_DEFAULT_NORM( - zero_func(cb_t), cb_t - ) - ) - end - elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg) - cb_t = ts[interp_index] - else - # If no solve and no interpolants, just use endpoint - cb_t = integrator.t - end - end - else - cb_t = integrator.t - end - - return cb_t, prev_sign, event_occurred, event_idx -end - -function find_callback_time(integrator, callback::VectorContinuousCallback, counter) - event_occurred, interp_index, ts, prev_sign, - prev_sign_index, event_idx = determine_event_occurrence( - integrator, - callback, - counter - ) - if event_occurred - if callback.condition === nothing - cb_t = integrator.t - min_event_idx = findfirst(isequal(1), event_idx) - else - if callback.interp_points != 0 - top_t = ts[interp_index] # Top at the smallest - bottom_t = ts[prev_sign_index] - else - top_t = integrator.t - bottom_t = integrator.tprev - end - if callback.rootfind != SciMLBase.NoRootFind && !isdiscrete(integrator.alg) - cb_t = isforward(integrator) ? nextfloat(top_t) : prevfloat(top_t) - min_event_idx = -1 - for idx in 1:length(event_idx) - if ArrayInterface.allowed_getindex(event_idx, idx) != 0 - function zero_func(abst, p = nothing) - return ArrayInterface.allowed_getindex( - get_condition( - integrator, - callback, - abst - ), idx - ) - end - if zero_func(top_t) == 0 - cbi_t = top_t - else - if integrator.event_last_time == counter && - integrator.vector_event_last_time == idx && - abs(zero_func(bottom_t)) <= - 100abs(integrator.last_event_error) && - prev_sign_index == 1 - - # Determined that there is an event by derivative - # But floating point error may make the end point negative - - bottom_t += integrator.dt * callback.repeat_nudge - sign_top = sign(zero_func(top_t)) - sign(zero_func(bottom_t)) * sign_top >= zero(sign_top) && - error("Double callback crossing floating point reducer errored. Report this issue.") - end - - cbi_t = find_root(zero_func, (bottom_t, top_t), callback.rootfind) - if integrator.tdir * cbi_t < integrator.tdir * cb_t - integrator.last_event_error = DiffEqBase.value( - ODE_DEFAULT_NORM( - zero_func(cbi_t), cbi_t - ) - ) - end - end - if integrator.tdir * cbi_t < integrator.tdir * cb_t - min_event_idx = idx - cb_t = cbi_t - end - end - end - elseif interp_index != callback.interp_points && !isdiscrete(integrator.alg) - cb_t = ts[interp_index] - min_event_idx = findfirst(isequal(1), event_idx) - else - # If no solve and no interpolants, just use endpoint - cb_t = integrator.t - min_event_idx = findfirst(isequal(1), event_idx) - end - end - else - cb_t = integrator.t - min_event_idx = 1 - end - - if event_occurred && min_event_idx < 0 - error("Callback handling failed. Please file an issue with code to reproduce.") - end - - return cb_t, ArrayInterface.allowed_getindex(prev_sign, min_event_idx), - event_occurred::Bool, min_event_idx::Int +""" +Return `true` if an event occured. +""" +function is_event_occurence(prev_sign::Number, next_sign::Number, affect!::F1, affect_neg!::F2) where {F1, F2} + return ( + (prev_sign < 0 && affect! !== nothing) || + (prev_sign > 0 && affect_neg! !== nothing) + ) && prev_sign * next_sign <= 0 end function apply_callback!( @@ -761,7 +593,7 @@ $(TYPEDEF) """ mutable struct CallbackCache{conditionType, signType} tmp_condition::conditionType - previous_condition::conditionType + next_condition::conditionType next_sign::signType prev_sign::signType end @@ -771,10 +603,10 @@ function CallbackCache( ::Type{signType} ) where {conditionType, signType} tmp_condition = similar(u, conditionType, max_len) - previous_condition = similar(u, conditionType, max_len) + next_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) prev_sign = similar(u, signType, max_len) - return CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign) + return CallbackCache(tmp_condition, next_condition, next_sign, prev_sign) end function CallbackCache( @@ -782,8 +614,8 @@ function CallbackCache( ::Type{signType} ) where {conditionType, signType} tmp_condition = zeros(conditionType, max_len) - previous_condition = zeros(conditionType, max_len) + next_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) prev_sign = zeros(signType, max_len) - return CallbackCache(tmp_condition, previous_condition, next_sign, prev_sign) + return CallbackCache(tmp_condition, next_condition, next_sign, prev_sign) end diff --git a/test/downstream/callback_detection.jl b/test/downstream/callback_detection.jl new file mode 100644 index 000000000..4850592fc --- /dev/null +++ b/test/downstream/callback_detection.jl @@ -0,0 +1,95 @@ + +using OrdinaryDiffEq +# https://github.com/SciML/DiffEqBase.jl/issues/1231 +@testset "Successive different callbacks in same integration step" begin + cb = ContinuousCallback( + (u, t, integrator) -> t - 0.0, + (integrator) -> push!(record, 0); + abstol=0.0 + ) + + vcb = VectorContinuousCallback( + (out, u, t, integrator) -> out .= (t - 1.0e-8, t - 2.0e-8, t - 2.0e-7), + (integrator, event_index) -> push!(record, event_index), + 3; + abstol=0.0 + ) + + f(u, p, t) = 1.0 + u0 = 0.0 + + # Forward propagation with successive events + record = [] + tspan = (-1.0, 1.0) + prob = ODEProblem(f, u0, tspan) + sol = solve(prob, Tsit5(), dt=2.0, callback=CallbackSet(cb, vcb)) + @test record == [0, 1, 2, 3] + + # Backward propagation with successive events + record = [] + tspan = (1.0, -1.0) + prob = ODEProblem(f, u0, tspan) + sol = solve(prob, Tsit5(), dt=2.0, callback=CallbackSet(cb, vcb)) + @test record == [3, 2, 1, 0] +end + +@testset "Successive same event detection" begin + @testset for affect_integrator in [false, true] + @testset for tdir in [1, -1] + poly(t) = (t - 0.1) * (t - 0.4) * (t - 0.8) + function affect!(integrator, index=1) + push!(record, tdir * integrator.t) + if affect_integrator + # nudge t backward to see if integrator avoids repeat detection + integrator.t = integrator.t - tdir * 1.0e-14 + end + end + abstol = affect_integrator ? 1.0e-14 : 0.0 + + f(u, p, t) = 1.0 + u0 = 0.0 + tspan = tdir .* (0.0, 1.0) + prob = ODEProblem(f, u0, tspan; dt=0.25, maxiters=100) + + # Linear roots (can step on exact root) + + cb = ContinuousCallback( + (u, t, integrator) -> poly(tdir * t), + affect!; abstol=abstol + ) + + record = [] + sol = solve(prob, Tsit5(), callback=cb) + @test record == [0.1, 0.4, 0.8] + + vcb = VectorContinuousCallback( + (out, u, t, integrator) -> out .= (poly(tdir * t), poly(tdir * t - 0.1)), + affect!, 2; abstol=abstol + ) + + record = [] + sol = solve(prob, Tsit5(), callback=vcb) + @test record == [0.1, 0.2, 0.4, 0.5, 0.8, 0.9] + + # Quadratic roots (cannot step on exact root) + + cb = ContinuousCallback( + (u, t, integrator) -> poly(t^2), + affect!; abstol=abstol + ) + + record = [] + sol = solve(prob, Tsit5(), callback=cb) + @test record ≈ sqrt.([0.1, 0.4, 0.8]) + + vcb = VectorContinuousCallback( + (out, u, t, integrator) -> out .= (poly(t^2), poly(t^2 - 0.1)), + affect!, 2; abstol=abstol + ) + + record = [] + sol = solve(prob, Tsit5(), callback=vcb) + @test record ≈ sqrt.([0.1, 0.2, 0.4, 0.5, 0.8, 0.9]) + end + end +end \ No newline at end of file diff --git a/test/downstream/community_callback_tests.jl b/test/downstream/community_callback_tests.jl index ba6169225..0a39b744e 100644 --- a/test/downstream/community_callback_tests.jl +++ b/test/downstream/community_callback_tests.jl @@ -247,34 +247,3 @@ sol = solve(prob, DFBDF()) # test that the callback flipping p caused u[2] to get flipped. first_t = findfirst(isequal(0.5), sol.t) @test sol.u[first_t][2] == -sol.u[first_t + 1][2] - -# https://github.com/SciML/DiffEqBase.jl/issues/1231 -@testset "Successive callbacks in same integration step" begin - cb = ContinuousCallback( - (u, t, integrator) -> t - 0.0, - (integrator) -> push!(record, 0) - ) - - vcb = VectorContinuousCallback( - (out, u, t, integrator) -> out .= (t - 1.0e-8, t - 2.0e-8, t - 2.0e-7), - (integrator, event_index) -> push!(record, event_index), - 3 - ) - - f(u, p, t) = 1.0 - u0 = 0.0 - - # Forward propagation with successive events - record = [] - tspan = (-1.0, 1.0) - prob = ODEProblem(f, u0, tspan) - sol = solve(prob, Tsit5(), dt = 2.0, callback = CallbackSet(cb, vcb)) - @test record == [0, 1, 2, 3] - - # Backward propagation with successive events - record = [] - tspan = (1.0, -1.0) - prob = ODEProblem(f, u0, tspan) - sol = solve(prob, Tsit5(), dt = 2.0, callback = CallbackSet(cb, vcb)) - @test record == [3, 2, 1, 0] -end diff --git a/test/downstream/null_de.jl b/test/downstream/null_de.jl index 11401beb0..023ff2f88 100644 --- a/test/downstream/null_de.jl +++ b/test/downstream/null_de.jl @@ -156,14 +156,6 @@ end @test sol_no_cb.retcode == SciMLBase.ReturnCode.Success @test sol_no_cb.t == [0.0, 1.0] - # Test 2: has_callbacks detection - null problem WITH callbacks should NOT take fast path - # This will error because OrdinaryDiffEq can't handle null u0 with callbacks yet, - # but the error proves we're not silently skipping callbacks - callback_called = Ref(false) - cb = DiscreteCallback((u, t, integrator) -> t >= 0.5, integrator -> callback_called[] = true) - prob_with_cb = ODEProblem(Returns(nothing), nothing, (0.0, 1.0)) - @test_throws Exception solve(prob_with_cb, Tsit5(); callback = cb) - # Test 3: ODE with state + DiscreteCallback - callbacks should trigger # Using raw ODE (not MTK) to avoid API changes callback_triggered = Ref(false) @@ -176,7 +168,4 @@ end @test sol_with_state.retcode == SciMLBase.ReturnCode.Success @test callback_triggered[] - - # Test 4: init with null problem + callback should also not take fast path - @test_throws Exception init(prob_with_cb, Tsit5(); callback = cb) end diff --git a/test/runtests.jl b/test/runtests.jl index 455ab97ff..a44c55424 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ end @time @safetestset "Table Inference Tests" include("downstream/tables.jl") @time @safetestset "Default linsolve with structure" include("downstream/default_linsolve_structure.jl") @time @safetestset "Callback Merging Tests" include("downstream/callback_merging.jl") + @time @safetestset "Callback Detection Tests" include("downstream/callback_detection.jl") @time @safetestset "LabelledArrays Tests" include("downstream/labelledarrays.jl") @time @safetestset "GTPSA Tests" include("downstream/gtpsa.jl") @time @safetestset "SubArray Support" include("downstream/subarray_support.jl")