Skip to content

Commit 046d51c

Browse files
feat: add support for obtaining derivatives of CoSimulation FMU outputs
1 parent 7561632 commit 046d51c

File tree

1 file changed

+82
-5
lines changed

1 file changed

+82
-5
lines changed

ext/MTKFMIExt.jl

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,21 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
249249
FMI3CSFunctor(state_value_references, output_value_references)
250250
end
251251
@parameters (functor::(typeof(_functor)))(..)[1:(length(__mtk_internal_u) + length(__mtk_internal_o))] = _functor
252-
# for co-simulation, we need to ensure the output buffer is solved for
253-
# during initialization
252+
253+
diffeqs = Equation[]
254254
for (i, x) in enumerate(collect(__mtk_internal_o))
255+
# for co-simulation, we need to ensure the output buffer is solved for
256+
# during initialization
255257
push!(initialization_eqs,
256258
x ~ functor(
257-
wrapper, __mtk_internal_u, __mtk_internal_x, __mtk_internal_p, t)[i])
258-
end
259+
wrapper, (__mtk_internal_u), __mtk_internal_x, __mtk_internal_p, t)[i])
259260

260-
diffeqs = Equation[]
261+
# also add equations for output derivatives
262+
push!(diffeqs,
263+
D(x) ~ term(
264+
getOutputDerivative, functor, wrapper, i, 1, collect(__mtk_internal_u),
265+
__mtk_internal_x, __mtk_internal_p, t; type = Real))
266+
end
261267

262268
# use `ImperativeAffect` for instance management here
263269
cb_observed = (; inputs = __mtk_internal_x, params = copy(params),
@@ -739,6 +745,15 @@ struct FMI2CSFunctor
739745
The value references of output variables in the FMU.
740746
"""
741747
output_value_references::Vector{FMI.fmi2ValueReference}
748+
"""
749+
Simply a buffer to store the order of output derivative required from
750+
`getRealOutputderivatives` and avoid unnecessary allocations.
751+
"""
752+
output_derivative_order_buffer::Vector{FMI.fmi2Integer}
753+
end
754+
755+
function FMI2CSFunctor(svref, ovref)
756+
FMI2CSFunctor(svref, ovref, FMI.fmi2Integer[1])
742757
end
743758

744759
function (fn::FMI2CSFunctor)(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
@@ -764,6 +779,41 @@ end
764779
ndims = 1
765780
end
766781

782+
"""
783+
$(TYPEDSIGNATURES)
784+
785+
Calculate the `order` order derivative of the `var`th output of the FMU.
786+
"""
787+
function getOutputDerivative(fn::FMI2CSFunctor, wrapper::FMI2InstanceWrapper, var::Int,
788+
order::FMI.fmi2Integer, states, inputs, params, t)
789+
states = states isa SubArray ? copy(states) : states
790+
inputs = inputs isa SubArray ? copy(inputs) : inputs
791+
params = params isa SubArray ? copy(params) : params
792+
instance = get_instance_CS!(wrapper, states, inputs, params, t)
793+
fn.output_derivative_order_buffer[] = order
794+
return FMI.fmi2GetRealOutputDerivatives(
795+
instance, fn.output_value_references[var], fn.output_derivative_order_buffer)
796+
end
797+
798+
# @register_symbolic getOutputDerivative(fn::FMI2CSFunctor, w::FMI2InstanceWrapper, var::Int, order::FMI.fmi2Integer, states::Vector{<:Real}, inputs::Vector{<:Real}, params::Vector{<:Real}, t::Real)
799+
800+
# HACK-ish for allowing higher order output derivatives
801+
# The first `D(output)` will result in a `getOutputDerivatives` expression.
802+
# Subsequent differentiations of this expression will expand to
803+
# `Σ_{i = 1:8} Differential(args[i])(getOutputDerivative(args...)) * D(args[i])`
804+
# using the chain rule. `i = 1:4` are not time-dependent (or real). We define
805+
# the derivatives for `i = 5:7` to be zero, and the derivative for `i = 8` (w.r.t `t`)
806+
# to be the same `getOutputDerivative` call but with the order increased. This results
807+
# in `D(output) = getOutputDerivative(fn, w, var, order + 1, states, inputs, params, t) * 1`
808+
# which is exactly what we want.
809+
for i in 1:7
810+
@eval Symbolics.derivative(::typeof(getOutputDerivative), args::NTuple{8, Any}, ::Val{$i}) = 0
811+
end
812+
function Symbolics.derivative(::typeof(getOutputDerivative), args::NTuple{8, Any}, ::Val{8})
813+
term(getOutputDerivative, args[1], args[2], args[3],
814+
args[4] + 1, args[5], args[6], args[7], args[8])
815+
end
816+
767817
"""
768818
$(TYPEDSIGNATURES)
769819
@@ -848,6 +898,15 @@ struct FMI3CSFunctor
848898
The value references of output variables in the FMU.
849899
"""
850900
output_value_references::Vector{FMI.fmi3ValueReference}
901+
"""
902+
Simply a buffer to store the order of output derivative required from
903+
`getRealOutputderivatives` and avoid unnecessary allocations.
904+
"""
905+
output_derivative_order_buffer::Vector{FMI.fmi3Int32}
906+
end
907+
908+
function FMI3CSFunctor(svref, ovref)
909+
FMI3CSFunctor(svref, ovref, FMI.fmi3Int32[1])
851910
end
852911

853912
function (fn::FMI3CSFunctor)(wrapper::FMI3InstanceWrapper, states, inputs, params, t)
@@ -871,6 +930,24 @@ end
871930
ndims = 1
872931
end
873932

933+
"""
934+
$(TYPEDSIGNATURES)
935+
936+
Calculate the `order` order derivative of the `var`th output of the FMU.
937+
"""
938+
function getOutputDerivative(fn::FMI3CSFunctor, wrapper::FMI3InstanceWrapper, var::Int,
939+
order::FMI.fmi3Int32, states, inputs, params, t)
940+
states = states isa SubArray ? copy(states) : states
941+
inputs = inputs isa SubArray ? copy(inputs) : inputs
942+
params = params isa SubArray ? copy(params) : params
943+
instance = get_instance_CS!(wrapper, states, inputs, params, t)
944+
fn.output_derivative_order_buffer[] = order
945+
return FMI.fmi3GetOutputDerivatives(
946+
instance, fn.output_value_references[var], fn.output_derivative_order_buffer)
947+
end
948+
949+
# @register_symbolic getOutputDerivative(fn::FMI3CSFunctor, w::FMI3InstanceWrapper, var::Int, order::FMI.fmi3Int32, states::Vector{<:Real}, inputs::Vector{<:Real}, params::Vector{<:Real}, t::Real) false
950+
874951
"""
875952
$(TYPEDSIGNATURES)
876953
"""

0 commit comments

Comments
 (0)