Skip to content

Commit eb17dad

Browse files
committed
fix DERelative for manifold meas
1 parent 23c46b9 commit eb17dad

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

ext/IncrInfrDiffEqFactorExt.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt
22

33
@info "IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"
44

5+
import Base: show
6+
57
using DifferentialEquations
68
import DifferentialEquations: solve
79

@@ -15,6 +17,7 @@ using DocStringExtensions
1517

1618
export DERelative
1719

20+
import Manifolds: allocate
1821

1922

2023
getManifold(de::DERelative{T}) where {T} = getManifold(de.domain)
@@ -100,11 +103,11 @@ function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
100103
# happens when more variables (n-ary) must be included in DE solve
101104
for (xid, xtra) in enumerate(Xtra)
102105
# update the data register before ODE solver calls the function
103-
prob.p[xid + 1][:] = xtra[:]
106+
prob.p[xid + 1][:] = xtra[:] # FIXME, unlikely to work with ArrayPartition, maybe use MArray and `.=`
104107
end
105108

106109
# set the initial condition
107-
prob.u0 = u0pts
110+
prob.u0 .= u0pts
108111

109112
sol = DifferentialEquations.solve(prob)
110113

@@ -250,8 +253,10 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
250253
oder = cf.factor
251254

252255
# how many trajectories to propagate?
253-
# @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
254-
meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
256+
#
257+
v2T = getVariableType(cf.fullvariables[2])
258+
meas = [allocate(getPointIdentity(v2T)) for _ = 1:N]
259+
# meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
255260

256261
# pick forward or backward direction
257262
# set boundary condition
@@ -288,7 +293,15 @@ end
288293

289294

290295

296+
function Base.show(io::IO, ::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}) where {T,O}
297+
println(io, " DERelative{")
298+
println(io, " ", T)
299+
println(io, " ", O.name.name)
300+
println(io, " }")
301+
nothing
302+
end
291303

304+
Base.show(io::IO, ::MIME"text/plain", der::DERelative) = show(io, der)
292305

293306
## the function
294307
# ode.problem.f.f

0 commit comments

Comments
 (0)