Skip to content

Commit e1d3af6

Browse files
committed
DERelative residual more on-manifold
1 parent eb17dad commit e1d3af6

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

ext/IncrInfrDiffEqFactorExt.jl

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using DocStringExtensions
1717

1818
export DERelative
1919

20-
import Manifolds: allocate
20+
import Manifolds: allocate, compose, hat, Identity, vee, log
2121

2222

2323
getManifold(de::DERelative{T}) where {T} = getManifold(de.domain)
@@ -67,7 +67,7 @@ function DERelative(
6767
# backward time problem
6868
bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt)
6969
# build the IIF recognizable object
70-
return DERelative(domain, fproblem, bproblem, datatuple, getSample)
70+
return DERelative(domain, fproblem, bproblem, datatuple) #, getSample)
7171
end
7272

7373
function DERelative(
@@ -88,11 +88,11 @@ function DERelative(
8888
domain,
8989
f,
9090
data;
91-
dt = dt,
92-
state0 = state0,
93-
state1 = state1,
94-
tspan = tspan,
95-
problemType = problemType,
91+
dt,
92+
state0,
93+
state1,
94+
tspan,
95+
problemType,
9696
)
9797
end
9898
#
@@ -162,7 +162,8 @@ end
162162
function (cf::CalcFactor{<:DERelative})(measurement, X...)
163163
#
164164
meas1 = measurement[1]
165-
diffOp = measurement[2]
165+
M = measurement[2]
166+
# diffOp = measurement[2]
166167

167168
oderel = cf.factor
168169

@@ -193,12 +194,15 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
193194
## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
194195
# @show cf._sampleIdx, solveforIdx, meas1
195196

196-
#FIXME
197-
res = zeros(size(X[2], 1))
198-
for i = 1:size(X[2], 1)
199-
# diffop( reference?, test? ) <===> ΔX = test \ reference
200-
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
201-
end
197+
res_ = compose(M, inv(M, X[solveforIdx]), meas1)
198+
res = vee(M, Identity(M), log(M, Identity(M), res_))
199+
200+
# #FIXME 0
201+
# res = zeros(size(X[2], 1))
202+
# for i = 1:size(X[2], 1)
203+
# # diffop( reference?, test? ) <===> ΔX = test \ reference
204+
# res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
205+
# end
202206
return res
203207
end
204208

@@ -260,23 +264,25 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
260264

261265
# pick forward or backward direction
262266
# set boundary condition
263-
u0pts = if cf.solvefor == 1
267+
u0pts, M = if cf.solvefor == 1
264268
# backward direction
265269
prob = oder.backwardProblem
270+
M_ = getManifold(getVariableType(cf.fullvariables[1]))
266271
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
267-
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
272+
convert(Tuple, M_),
268273
)
269274
# getBelief(cf.fullvariables[2]) |> getPoints
270-
cf._legacyParams[2]
275+
cf._legacyParams[2], M_
271276
else
272277
# forward backward
273278
prob = oder.forwardProblem
279+
M_ = getManifold(getVariableType(cf.fullvariables[2]))
274280
# buffer manifold operations for use during factor evaluation
275281
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
276-
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
282+
convert(Tuple, M_),
277283
)
278284
# getBelief(cf.fullvariables[1]) |> getPoints
279-
cf._legacyParams[1]
285+
cf._legacyParams[1], M_
280286
end
281287

282288
# solve likely elements
@@ -287,7 +293,8 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
287293
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
288294
end
289295

290-
return map(x -> (x, diffOp), meas)
296+
# return meas, M
297+
return map(x -> (x, M), meas)
291298
end
292299
# getDimension(oderel.domain)
293300

src/entities/ExtFactors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab
2525
backwardProblem::P
2626
""" second element of this data tuple is additional variables that will be passed down as a parameter """
2727
data::D
28-
specialSampler::Function
28+
# specialSampler::Function
2929
end

0 commit comments

Comments
 (0)