-
-
Notifications
You must be signed in to change notification settings - Fork 104
Accumulation for ODEProblem #1036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
236fb65
d50778b
584f8c3
53cc142
4aaae10
a03bb02
e3b0108
c4178b1
c85dfa2
296f43e
0c249ab
40ba20b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
module SciMLBaseZygoteExt | ||
|
||
using Zygote | ||
using Zygote: @adjoint, pullback | ||
import Zygote: literal_getproperty | ||
using Zygote: @adjoint, pullback, @_adjoint_keepthunks, _project, pair | ||
import Zygote: literal_getproperty, literal_getfield | ||
import ChainRulesCore | ||
using SciMLBase | ||
using SciMLBase: ODESolution, remake, ODEFunction, | ||
getobserved, build_solution, EnsembleSolution, | ||
NonlinearSolution, AbstractTimeseriesSolution | ||
NonlinearSolution, AbstractTimeseriesSolution, | ||
ODEProblem | ||
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, | ||
observed, parameter_values, state_values, current_time | ||
using RecursiveArrayTools | ||
|
@@ -299,4 +300,23 @@ end | |
∇responsible_map(__context__, f, args...) | ||
end | ||
|
||
@_adjoint_keepthunks function Zygote.literal_getfield(x::ODEProblem, ::Val{f}) where f | ||
val = getfield(x, f) | ||
function back(Δ) | ||
# error() | ||
Zygote.accum_param(__context__, val, Δ) === nothing && return | ||
if isimmutable(x) | ||
error() | ||
dx = (; Zygote.nt_nothing(x)..., pair(Val(f), Δ, x)...) | ||
(_project(x, dx), nothing) | ||
else | ||
dx = Zygote.grad_mut(__context__, x) | ||
dx[] = (; dx[]..., pair(Val(f), Zygote.accum(getfield(dx[], f), Δ))...) | ||
return (dx,nothing) | ||
end | ||
end | ||
Zygote.unwrap(val), back | ||
end | ||
Zygote.accum(::Tuple{}, ::NamedTuple{}) = () | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this type piracy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is 😅 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah... FluxML/Zygote.jl#1574 is definitely the safer bet. Why are problem types |
||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -258,50 +258,58 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, | |
initdata.update_initializeprob!(initprob, valp) | ||
end | ||
end | ||
nlsol, success = solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg ) | ||
|
||
if is_trivial_initialization(initdata) | ||
nlsol = initdata | ||
success = true | ||
else | ||
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) | ||
if nlsolve_alg === nothing && state_values(initprob) !== nothing | ||
throw(OverrideInitMissingAlgorithm()) | ||
end | ||
if alg.abstol !== nothing | ||
_abstol = alg.abstol | ||
elseif abstol !== nothing | ||
_abstol = abstol | ||
else | ||
throw(OverrideInitNoTolerance(:abstol)) | ||
end | ||
if alg.reltol !== nothing | ||
_reltol = alg.reltol | ||
elseif reltol !== nothing | ||
_reltol = reltol | ||
else | ||
throw(OverrideInitNoTolerance(:reltol)) | ||
end | ||
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) | ||
|
||
success = if initprob isa NonlinearLeastSquaresProblem | ||
# Do not accept StalledSuccess as a solution | ||
# A good local minima is not a success | ||
resid = nlsol.resid | ||
normresid = norm(resid) | ||
SciMLBase.successful_retcode(nlsol) && normresid <= abstol | ||
else | ||
SciMLBase.successful_retcode(nlsol) | ||
end | ||
end | ||
|
||
nlsol2 = prob.f.initialization_data.initializeprob | ||
if initdata.initializeprobmap !== nothing | ||
u0 = initdata.initializeprobmap(choose_branch(nlsol)) | ||
u02 = initdata.initializeprobmap(nlsol2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't make sense - we solve initialization, but calculate the new There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has contains some of the changes we had made previously to resolve the error. It is possible to remove the changes to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in c85dfa2 |
||
end | ||
if initdata.initializeprobpmap !== nothing | ||
p = initdata.initializeprobpmap(valp, choose_branch(nlsol)) | ||
p2 = initdata.initializeprobpmap(valp, nlsol) | ||
end | ||
|
||
return u0, p, success | ||
u03 = isnothing(initdata.initializeprobmap) ? u0 : u02 | ||
p3 = isnothing(initdata.initializeprobpmap) ? p : p2 | ||
return u03, p3, success | ||
end | ||
|
||
function solve_initialization(initdata::OverrideInitData{<:AbstractNonlinearProblem{Nothing}}, initprob, alg; kwargs...) | ||
nlsol = initprob | ||
success = true | ||
return nlsol, success | ||
end | ||
|
||
function solve_initialization(initdata, initprob, alg; reltol, abstol, nlsolve_alg) | ||
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) | ||
if nlsolve_alg === nothing && state_values(initprob) !== nothing | ||
throw(OverrideInitMissingAlgorithm()) | ||
end | ||
if alg.abstol !== nothing | ||
_abstol = alg.abstol | ||
elseif abstol !== nothing | ||
_abstol = abstol | ||
else | ||
throw(OverrideInitNoTolerance(:abstol)) | ||
end | ||
if alg.reltol !== nothing | ||
_reltol = alg.reltol | ||
elseif reltol !== nothing | ||
_reltol = reltol | ||
else | ||
throw(OverrideInitNoTolerance(:reltol)) | ||
end | ||
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol, kwargs...) | ||
|
||
success = if initprob isa NonlinearLeastSquaresProblem | ||
# Do not accept StalledSuccess as a solution | ||
# A good local minima is not a success | ||
resid = nlsol.resid | ||
normresid = norm(resid) | ||
SciMLBase.successful_retcode(nlsol) && normresid <= abstol | ||
else | ||
SciMLBase.successful_retcode(nlsol) | ||
end | ||
return nlsol, success | ||
end | ||
|
||
""" | ||
|
@@ -314,21 +322,6 @@ function get_initial_values(prob, integrator, f, ::NoInit, iip; kwargs...) | |
return state_values(integrator), parameter_values(integrator), true | ||
end | ||
|
||
is_trivial_initialization(::Nothing) = true | ||
|
||
function is_trivial_initialization(initdata::OverrideInitData) | ||
!(initdata.initializeprob isa NonlinearLeastSquaresProblem) && | ||
state_values(initdata.initializeprob) === nothing | ||
end | ||
|
||
function is_trivial_initialization(f::AbstractSciMLFunction) | ||
has_initialization_data(f) && is_trivial_initialization(f.initialization_data) | ||
end | ||
|
||
function is_trivial_initialization(prob::AbstractSciMLProblem) | ||
is_trivial_initialization(prob.f) | ||
end | ||
|
||
@enum DETERMINED_STATUS OVERDETERMINED FULLY_DETERMINED UNDERDETERMINED | ||
|
||
function initialization_status(prob::AbstractSciMLProblem) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the point of this block if the first line is
error()
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed in a03bb02