Skip to content

Commit a7a024c

Browse files
authored
Merge pull request #1803 from JuliaRobotics/23Q4/enh/derelins
cleanup more organized on DERelative
2 parents 05bf861 + cfed191 commit a7a024c

File tree

2 files changed

+65
-57
lines changed

2 files changed

+65
-57
lines changed

ext/IncrInfrDiffEqFactorExt.jl

Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,30 @@ using DocStringExtensions
1717

1818
export DERelative
1919

20-
import Manifolds: allocate
20+
import Manifolds: allocate, compose, hat, Identity, vee, log
2121

2222

2323
getManifold(de::DERelative{T}) where {T} = getManifold(de.domain)
2424

25+
26+
function Base.show(
27+
io::IO,
28+
::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}
29+
) where {T,O}
30+
println(io, " DERelative{")
31+
println(io, " ", T)
32+
println(io, " ", O.name.name)
33+
println(io, " }")
34+
nothing
35+
end
36+
37+
Base.show(
38+
io::IO,
39+
::MIME"text/plain",
40+
der::DERelative
41+
) = show(io, der)
42+
43+
2544
"""
2645
$SIGNATURES
2746
@@ -31,7 +50,9 @@ DevNotes
3150
- TODO does not yet incorporate Xi.nanosecond field.
3251
- TODO does not handle timezone crossing properly yet.
3352
"""
34-
function _calcTimespan(Xi::AbstractVector{<:DFGVariable})
53+
function _calcTimespan(
54+
Xi::AbstractVector{<:DFGVariable}
55+
)
3556
#
3657
tsmps = getTimestamp.(Xi[1:2]) .|> DateTime .|> datetime2unix
3758
# toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3)
@@ -50,10 +71,10 @@ function DERelative(
5071
f::Function,
5172
data = () -> ();
5273
dt::Real = 1,
53-
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
54-
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
74+
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
75+
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
5576
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
56-
problemType = DiscreteProblem,
77+
problemType = ODEProblem, # DiscreteProblem,
5778
)
5879
#
5980
datatuple = if 2 < length(Xi)
@@ -63,11 +84,11 @@ function DERelative(
6384
data
6485
end
6586
# forward time problem
66-
fproblem = problemType(f, state0, tspan, datatuple; dt = dt)
87+
fproblem = problemType(f, state0, tspan, datatuple; dt)
6788
# backward time problem
6889
bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt)
6990
# build the IIF recognizable object
70-
return DERelative(domain, fproblem, bproblem, datatuple, getSample)
91+
return DERelative(domain, fproblem, bproblem, datatuple) #, getSample)
7192
end
7293

7394
function DERelative(
@@ -78,8 +99,8 @@ function DERelative(
7899
data = () -> ();
79100
Xi::AbstractArray{<:DFGVariable} = getVariable.(dfg, labels),
80101
dt::Real = 1,
81-
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
82-
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
102+
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
103+
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
83104
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
84105
problemType = DiscreteProblem,
85106
)
@@ -88,18 +109,23 @@ function DERelative(
88109
domain,
89110
f,
90111
data;
91-
dt = dt,
92-
state0 = state0,
93-
state1 = state1,
94-
tspan = tspan,
95-
problemType = problemType,
112+
dt,
113+
state0,
114+
state1,
115+
tspan,
116+
problemType,
96117
)
97118
end
98119
#
99120
#
100121

101122
# n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
102-
function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
123+
function _solveFactorODE!(
124+
measArr,
125+
prob,
126+
u0pts,
127+
Xtra...
128+
)
103129
# happens when more variables (n-ary) must be included in DE solve
104130
for (xid, xtra) in enumerate(Xtra)
105131
# update the data register before ODE solver calls the function
@@ -159,21 +185,21 @@ end
159185

160186

161187
# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE
162-
function (cf::CalcFactor{<:DERelative})(measurement, X...)
188+
function (cf::CalcFactor{<:DERelative})(
189+
measurement,
190+
X...
191+
)
163192
#
193+
# numerical measurement values
164194
meas1 = measurement[1]
165-
diffOp = measurement[2]
166-
195+
# work on-manifold via sampleFactor piggy back of particular manifold definition
196+
M = measurement[2]
197+
# lazy factor pointer
167198
oderel = cf.factor
168-
169-
# work on-manifold
170-
# diffOp = meas[2]
171-
# if backwardSolve else forward
172-
173199
# check direction
174-
175200
solveforIdx = cf.solvefor
176-
201+
202+
# if backwardSolve else forward
177203
if solveforIdx > 2
178204
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
179205
solveforIdx = 2
@@ -189,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
189215
end
190216

191217
# find the difference between measured and predicted.
192-
## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
193-
## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
194-
# @show cf._sampleIdx, solveforIdx, meas1
195-
196-
#FIXME
197-
res = zeros(size(X[2], 1))
198-
for i = 1:size(X[2], 1)
199-
# diffop( reference?, test? ) <===> ΔX = test \ reference
200-
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
201-
end
218+
# assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
219+
res_ = compose(M, inv(M, X[solveforIdx]), meas1)
220+
res = vee(M, Identity(M), log(M, Identity(M), res_))
221+
202222
return res
203223
end
204224

@@ -260,23 +280,25 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
260280

261281
# pick forward or backward direction
262282
# set boundary condition
263-
u0pts = if cf.solvefor == 1
283+
u0pts, M = if cf.solvefor == 1
264284
# backward direction
265285
prob = oder.backwardProblem
286+
M_ = getManifold(getVariableType(cf.fullvariables[1]))
266287
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
267-
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
288+
convert(Tuple, M_),
268289
)
269290
# getBelief(cf.fullvariables[2]) |> getPoints
270-
cf._legacyParams[2]
291+
cf._legacyParams[2], M_
271292
else
272293
# forward backward
273294
prob = oder.forwardProblem
295+
M_ = getManifold(getVariableType(cf.fullvariables[2]))
274296
# buffer manifold operations for use during factor evaluation
275297
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
276-
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
298+
convert(Tuple, M_),
277299
)
278300
# getBelief(cf.fullvariables[1]) |> getPoints
279-
cf._legacyParams[1]
301+
cf._legacyParams[1], M_
280302
end
281303

282304
# solve likely elements
@@ -287,25 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
287309
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
288310
end
289311

290-
return map(x -> (x, diffOp), meas)
312+
# return meas, M
313+
return map(x -> (x, M), meas)
291314
end
292315
# getDimension(oderel.domain)
293316

294317

295318

296-
function Base.show(io::IO, ::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}) where {T,O}
297-
println(io, " DERelative{")
298-
println(io, " ", T)
299-
println(io, " ", O.name.name)
300-
println(io, " }")
301-
nothing
302-
end
303-
304-
Base.show(io::IO, ::MIME"text/plain", der::DERelative) = show(io, der)
305-
306-
## the function
307-
# ode.problem.f.f
308-
309-
#
310-
311319
end # module

src/entities/ExtFactors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab
2525
backwardProblem::P
2626
""" second element of this data tuple is additional variables that will be passed down as a parameter """
2727
data::D
28-
specialSampler::Function
28+
# specialSampler::Function
2929
end

0 commit comments

Comments
 (0)