
Support for adjoints for non-numeric types #3973

@mx-p9a

I'd like to use non-numeric types as parameters in ModelingToolkit.jl together with adjoint sensitivities. DifferentialEquations.jl supports this through SciMLStructures.jl, as described here, so I assumed it would also work with ModelingToolkit.jl, but so far I have been unsuccessful. Defining a parameter as a non-numeric type matters to me because my type has many fields, and for code clarity I would rather not expose each of them as a separate parameter or structural parameter.
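For context, the SciMLStructures.jl interface also includes `ismutablescimlstructure`, `hasportion` and `replace`, beyond the `isscimlstructure` and `canonicalize` methods used in the MWE. A minimal sketch of the full interface for a `MyType` like the one below, used on its own outside ModelingToolkit, might look like this (assumes SciMLStructures.jl is installed; it illustrates the interface only and does not change how ModelingToolkit packs parameters):

```julia
using SciMLStructures
using SciMLStructures: Tunable

struct MyType{T <: Real}
    a::T
    b::T
end

SciMLStructures.isscimlstructure(::MyType) = true
# MyType is immutable, so it is declared non-mutable and only the
# out-of-place `replace` is defined.
SciMLStructures.ismutablescimlstructure(::MyType) = false
# MyType has a Tunable portion: its fields a and b.
SciMLStructures.hasportion(::Tunable, ::MyType) = true

function SciMLStructures.canonicalize(::Tunable, θ::MyType)
    vals = [θ.a, θ.b]  # fresh Vector, so it does not alias θ
    repack = newvals -> SciMLStructures.replace(Tunable(), θ, newvals)
    return vals, repack, false
end

function SciMLStructures.replace(::Tunable, ::MyType, newvals)
    @assert length(newvals) == 2
    return MyType(newvals[1], newvals[2])
end

vals, repack, _ = SciMLStructures.canonicalize(Tunable(), MyType(2.0, 3.0))
θ_new = repack([4.0, 5.0])  # a new MyType with a = 4.0, b = 5.0
```

Routing `repack` through `replace` keeps a single place that knows how to rebuild the struct from a flat vector.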

MWE:

using ModelingToolkit, OrdinaryDiffEq, SciMLSensitivity, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D
using SciMLStructures
using SymbolicIndexingInterface: parameter_values

# parametric so AD duals work
struct MyType{T <: Real}
    a::T
    b::T
end

# accessor functions, registered for symbolic use
geta(θ::MyType) = θ.a
getb(θ::MyType) = θ.b
@register_symbolic geta(mytype::MyType)
@register_symbolic getb(mytype::MyType)

# tell SciMLStructures how to "flatten" and "repack" MyType
SciMLStructures.isscimlstructure(::MyType) = true
function SciMLStructures.canonicalize(::SciMLStructures.Tunable, θ::MyType{T}) where {T}
    vals = T[θ.a, θ.b]                    # flatten to a Vector
    repack = x -> (@assert length(x) == 2; MyType(x[1], x[2]))
    return vals, repack, false            # false => no aliasing
end

@mtkmodel Toy begin
    @parameters begin
        custom::MyType = MyType(2.0, 3.0), [tunable = true]
        # a = 2.0
        # b = 3.0
        c = -1.0
    end
    @variables begin
        x(t) = 1.0
    end
    @equations begin
        # Use the accessors instead of `custom.a` / `custom.b` directly
        D(x) ~ c * geta(custom) * x + getb(custom) # this doesn't work
        # D(x) ~ c * a * x + b # this works
    end
end

@mtkcompile sys = Toy()
tspan = (0.0, 1.0)
prob = ODEProblem(sys, [], tspan)

x_target = 0.5

# Canonicalize the tunables of the MTK parameter object to get an initial vector [a,b]
ps0 = parameter_values(prob)
x0, repack_ps, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), ps0)  # I want x0 == [2.0, 3.0], but with MyType the fields of custom are missing from x0

function loss(ab_vec)
    # repack the entire MTKParameters object, *including* custom::MyType, from the vector
    new_ps = repack_ps(ab_vec)
    newprob = remake(prob; p = new_ps)
    sol = solve(newprob, Tsit5(); saveat = [last(tspan)], sensealg = GaussAdjoint())
    xT = sol[sys.x][end]
    return (xT - x_target)^2
end

# Gradient w.r.t. [a,b]
∇loss = Zygote.gradient(loss, x0)[1]
println("∂loss/∂a = ", ∇loss[1], ",  ∂loss/∂b = ", ∇loss[2])

When custom::MyType is used in the Toy system, x0 does not contain the fields a and b of custom, which (I believe) leads to this error:

┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7zKgz/src/concrete_solve.jl:68

ERROR: LoadError: BoundsError: attempt to access 1-element Vector{Float64} at index [2]
Stacktrace:
 [1] throw_boundserror(A::Vector{Float64}, I::Tuple{Int64})
   @ Base ./essentials.jl:14
 [2] getindex(A::Vector{Float64}, i::Int64)
   @ Base ./essentials.jl:916
 [3] top-level scope
   @ ~/test_adjoint_mtk.jl:64
 [4] include(fname::String)
   @ Main ./sysimg.jl:38
 [5] top-level scope
   @ REPL[15]:1
in expression starting at /test_adjoint_mtk.jl:64
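To make the expected behaviour concrete, here is a plain-Julia sketch of the flatten/repack round trip the report is after, where the tunable vector descends into `custom`. `ToyParams` and `flatten_tunables` are hypothetical stand-ins for illustration only, not MTKParameters or any ModelingToolkit API; `c` is included because it is also a numeric tunable parameter of Toy.

```julia
struct MyType{T <: Real}
    a::T
    b::T
end

# Hypothetical parameter container for Toy (NOT MTKParameters).
struct ToyParams{T <: Real}
    custom::MyType{T}
    c::T
end

# Flatten every tunable scalar, descending into `custom`, in the order [a, b, c],
# and return a closure that rebuilds a ToyParams from such a vector.
function flatten_tunables(p::ToyParams)
    vals = [p.custom.a, p.custom.b, p.c]
    repack = v -> (@assert length(v) == 3; ToyParams(MyType(v[1], v[2]), v[3]))
    return vals, repack
end

p0 = ToyParams(MyType(2.0, 3.0), -1.0)
x0_expected, repack_expected = flatten_tunables(p0)
p1 = repack_expected([4.0, 5.0, -2.0])  # custom = MyType(4.0, 5.0), c = -2.0
```

With a vector like `x0_expected`, `Zygote.gradient(loss, x0)[1]` would have one entry per element, so indexing `∇loss[1]` and `∇loss[2]` would not go out of bounds.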

It does work when I define a = 2.0 and b = 3.0 directly as parameters (without MyType).
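In the working variant the equation is D(x) ~ c * a * x + b with constant coefficients, so x(T) has a closed form and the gradient that the adjoint should return can be cross-checked without ModelingToolkit. A minimal plain-Julia sketch, assuming the MWE's values (x(0) = 1.0, c = -1.0, T = 1.0, x_target = 0.5); the helper names `xT_closed`, `loss_closed`, `grad_closed` and `fd_grad` are hypothetical:

```julia
# Closed-form solution of D(x) ~ c*a*x + b with x(0) = x0, valid when k = c*a != 0:
#   x(T) = x0*exp(k*T) + (b/k)*(exp(k*T) - 1)
function xT_closed(a, b; c = -1.0, x0 = 1.0, T = 1.0)
    k = c * a
    return x0 * exp(k * T) + (b / k) * (exp(k * T) - 1)
end

# Same loss as in the MWE, with x_target = 0.5.
loss_closed(a, b) = (xT_closed(a, b) - 0.5)^2

# Analytic gradient via the chain rule (k = c*a, so ∂k/∂a = c).
function grad_closed(a, b; c = -1.0, x0 = 1.0, T = 1.0)
    k = c * a
    e = exp(k * T)
    xT = x0 * e + (b / k) * (e - 1)
    dxT_dk = x0 * T * e + b * (T * e / k - (e - 1) / k^2)
    dxT_db = (e - 1) / k
    r = 2 * (xT - 0.5)  # ∂loss/∂xT
    return r * c * dxT_dk, r * dxT_db
end

# Central finite differences in a and b, as an independent check.
function fd_grad(f, a, b; h = 1e-6)
    ga = (f(a + h, b) - f(a - h, b)) / (2h)
    gb = (f(a, b + h) - f(a, b - h)) / (2h)
    return ga, gb
end

ga, gb = grad_closed(2.0, 3.0)
println("reference ∂loss/∂a = ", ga, ",  ∂loss/∂b = ", gb)
println("finite differences: ", fd_grad(loss_closed, 2.0, 3.0))
```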
