diff --git a/src/CoupledSDEs.jl b/src/CoupledSDEs.jl index fdc63067..2292c82a 100644 --- a/src/CoupledSDEs.jl +++ b/src/CoupledSDEs.jl @@ -6,6 +6,12 @@ using DynamicalSystemsBase: SciMLBase, correct_state, _set_parameter!, + current_state, + rulestring, + isdeterministic, + isdiscretetime, + current_parameters, + current_time, current_state using StochasticDiffEq: SDEProblem @@ -191,12 +197,6 @@ function DynamicalSystemsBase.CoupledODEs( ) end -# Pretty print -function additional_details(ds::CoupledSDEs) - solver, remaining = _decompose_into_solver_and_remaining(ds.diffeq) - return ["SDE solver" => string(nameof(typeof(solver))), "SDE kwargs" => remaining] -end - ########################################################################################### # API - obtaining information from the system ########################################################################################### @@ -228,3 +228,46 @@ function DynamicalSystemsBase.successful_step(integ::SciMLBase.AbstractSDEIntegr rcode = integ.sol.retcode return rcode == SciMLBase.ReturnCode.Default || rcode == SciMLBase.ReturnCode.Success end + +function DynamicalSystemsBase.dynamic_rule(sys::CoupledSDEs) + f = sys.integ.f + while hasfield(typeof(f), :f) + f = f.f + end + return f +end +DynamicalSystemsBase.isdeterministic(ds::CoupledSDEs) = false + +# Pretty print +function additional_details(ds::CoupledSDEs) + solver, remaining = _decompose_into_solver_and_remaining(ds.diffeq) + return [ + "Noise strength" => ds.noise_strength, + "SDE solver" => string(nameof(typeof(solver))), + "SDE kwargs" => remaining, + ] +end + +function Base.show(io::IO, ::MIME"text/plain", ds::CoupledSDEs) + descriptors = [ + "deterministic" => isdeterministic(ds), + "discrete time" => isdiscretetime(ds), + "in-place" => isinplace(ds), + "dynamic rule" => rulestring(dynamic_rule(ds)), + ] + append!(descriptors, additional_details(ds)) + append!( + descriptors, + [ + "parameters" => current_parameters(ds), + "time" => current_time(ds), + "state" => current_state(ds), + ], + ) + padlen = maximum(length(d[1]) for d in descriptors) + 3 + + println(io, summary(ds)) + for (desc, val) in descriptors + println(io, rpad(" $(desc): ", padlen), val) + end +end