-
-
Notifications
You must be signed in to change notification settings - Fork 106
Feat: adjoints through observable functions #689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
92ad6a8
22dc7ec
0c2b69d
c61e08c
8600a8d
a69d087
785b052
adee4f0
2197a30
9172014
95cf416
4ce8257
839bd63
2474a8d
f68cb05
9ab29d9
a417cdd
ff9bb2c
032b927
44bfc91
de2d6cd
c63dfbf
940ea78
8e48f1c
d061ce4
f817b52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
using SciMLBase: ODESolution, remake, | ||
getobserved, build_solution, EnsembleSolution, | ||
NonlinearSolution, AbstractTimeseriesSolution | ||
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index | ||
using SymbolicIndexingInterface: symbolic_type, NotSymbolic, variable_index, is_observed, observed, parameter_values | ||
using RecursiveArrayTools | ||
|
||
# This method resolves the ambiguity with the pullback defined in | ||
|
@@ -109,7 +109,15 @@ | |
@adjoint function Base.getindex(VA::ODESolution, sym) | ||
function ODESolution_getindex_pullback(Δ) | ||
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym | ||
if i === nothing | ||
if is_observed(VA, sym) | ||
y, back = Zygote.pullback(VA) do sol | ||
f = observed(sol, sym) | ||
p = parameter_values(sol) | ||
f.(sol.u,Ref(p), sol.t) | ||
end | ||
gs = back(Δ) | ||
(gs[1], nothing) | ||
elseif i === nothing | ||
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) | ||
else | ||
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] | ||
|
@@ -122,6 +130,7 @@ | |
|
||
@adjoint function Base.getindex(VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}) where T | ||
function ODESolution_getindex_pullback(Δ) | ||
@show typeof(Δ) | ||
sym = sym isa Tuple ? collect(sym) : sym | ||
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym) | ||
if i === nothing | ||
|
@@ -182,15 +191,15 @@ | |
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint | ||
end | ||
|
||
@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, | ||
::Val{:u}) | ||
function solu_adjoint(Δ) | ||
zerou = zero(sol.prob.u0) | ||
_Δ = @. ifelse(Δ === nothing, (zerou,), Δ) | ||
(build_solution(sol.prob, sol.alg, sol.t, _Δ),) | ||
end | ||
sol.u, solu_adjoint | ||
end | ||
# @adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, | ||
# ::Val{:u}) | ||
# function solu_adjoint(Δ) | ||
# zerou = zero(sol.prob.u0) | ||
# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) | ||
# (build_solution(sol.prob, sol.alg, sol.t, _Δ),) | ||
# end | ||
# sol.u, solu_adjoint | ||
# end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this removed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was returning the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a unit test in the downstream set which shows this is fine? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to. In fact, that's why I asked if anything was relying on this behavior previously. Could you suggest what kind of test you have in mind? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be the root cause of many of the test failures? So that means it's caught by the tests already. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is what the error is referring to. I am missing a branch https://github.com/DhairyaLGandhi/RecursiveArrayTools.jl/tree/dg/noproj which removes an extra projection rule. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does refer to projecting to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now that we have restored the adjoint, I believe this can be resolved |
||
|
||
@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution, | ||
::Val{:u}) | ||
|
Uh oh!
There was an error while loading. Please reload this page.