@@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt
2
2
3
3
@info " IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"
4
4
5
+ import Base: show
6
+
5
7
using DifferentialEquations
6
8
import DifferentialEquations: solve
7
9
@@ -15,6 +17,7 @@ using DocStringExtensions
15
17
16
18
export DERelative
17
19
20
+ import Manifolds: allocate
18
21
19
22
20
23
getManifold (de:: DERelative{T} ) where {T} = getManifold (de. domain)
@@ -100,11 +103,12 @@ function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
100
103
# happens when more variables (n-ary) must be included in DE solve
101
104
for (xid, xtra) in enumerate (Xtra)
102
105
# update the data register before ODE solver calls the function
103
- prob. p[xid + 1 ][:] = xtra[:]
106
+ prob. p[xid + 1 ][:] = xtra[:] # FIXME , unlikely to work with ArrayPartition, maybe use MArray and `.=`
104
107
end
105
108
106
109
# set the initial condition
107
- prob. u0[:] = u0pts[:]
110
+ prob. u0 .= u0pts
111
+
108
112
sol = DifferentialEquations. solve (prob)
109
113
110
114
# extract solution from solved ode
@@ -249,8 +253,10 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
249
253
oder = cf. factor
250
254
251
255
# how many trajectories to propagate?
252
- # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
253
- meas = [zeros (getDimension (cf. fullvariables[2 ])) for _ = 1 : N]
256
+ #
257
+ v2T = getVariableType (cf. fullvariables[2 ])
258
+ meas = [allocate (getPointIdentity (v2T)) for _ = 1 : N]
259
+ # meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
254
260
255
261
# pick forward or backward direction
256
262
# set boundary condition
287
293
288
294
289
295
296
+ function Base. show (io:: IO , :: Union{<:DERelative{T,O},Type{<:DERelative{T,O}}} ) where {T,O}
297
+ println (io, " DERelative{" )
298
+ println (io, " " , T)
299
+ println (io, " " , O. name. name)
300
+ println (io, " }" )
301
+ nothing
302
+ end
290
303
304
+ Base. show (io:: IO , :: MIME"text/plain" , der:: DERelative ) = show (io, der)
291
305
292
306
# # the function
293
307
# ode.problem.f.f
0 commit comments