Skip to content

Commit dd78b10

Browse files
committed
Experiment with factor gradients in solve
1 parent 2d42ae3 commit dd78b10

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

IncrementalInference/src/parametric/services/ParametricManopt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,17 @@ function CalcFactorResidualAP(
7272
return ArrayPartition{CalcFactorResidual, typeof(parts_tuple)}(parts_tuple)
7373
end
7474

75-
function (cfm::CalcFactorResidual)(p)
75+
function (cfm::CalcFactorResidual)(p::Vector)
7676
meas = cfm.meas
7777
points = map(idx->p[idx], cfm.varOrderIdxs)
7878
return cfm.sqrt_iΣ * cfm(meas, points...)
7979
end
8080

81+
function (cfm::CalcFactorResidual)(p::ArrayPartition)
82+
points = map(idx->p.x[idx], cfm.varOrderIdxs)
83+
return cfm.sqrt_iΣ * cfm(cfm.meas, points...)
84+
end
85+
8186
# cost function f: M->ℝᵈ for Riemannian Levenberg-Marquardt
8287
struct CostFres_cond!{PT, CFT}
8388
points::PT

IncrementalInference/src/services/FactorGradients.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,36 @@ function factorJacobian(
3434
return ManifoldDiff.jacobian(M_dom, M_codom, costf, p0, backend)
3535
end
3636

37+
function factorGradient(
38+
cf::CalcFactorResidual,
39+
M,
40+
p,
41+
backend = ManifoldDiff.TangentDiffBackend(ManifoldDiff.FiniteDiffBackend()),
42+
)
43+
ManifoldDiff.gradient(M, (x) -> 1//2 * norm(cf(x))^2, p, backend)
44+
end
3745

46+
function factorJacobian(
47+
cf::CalcFactorResidual,
48+
M_dom,
49+
p,
50+
backend = ManifoldDiff.TangentDiffBackend(ManifoldDiff.FiniteDiffBackend()),
51+
)
52+
# M_dom = ProductManifold(getManifold.(fg, varlabels)...)
53+
M_codom = Euclidean(manifold_dimension(getManifold(cf)))
54+
55+
return ManifoldDiff.jacobian(M_dom, M_codom, cf, p, backend)
56+
end
57+
58+
#
59+
function factorGradient(
60+
cf::CalcFactorNormSq,
61+
M,
62+
p,
63+
backend = ManifoldDiff.TangentDiffBackend(ManifoldDiff.FiniteDiffBackend()),
64+
)
65+
ManifoldDiff.gradient(M, cf, p, backend)
66+
end
3867

3968
export getCoordSizes
4069
export checkGradientsToleranceMask, calcPerturbationFromVariable

IncrementalInference/src/services/NumericalCalculations.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@ function _solveLambdaNumeric(
7171
return r.minimizer
7272
end
7373

74-
# struct OptimCalcConv end
74+
struct OptimCalcConv end
7575
# CalcFactorNormSq cost function for an input in coordinates as used by Optim.jl
76-
function (hypoCalcFactor::CalcFactorNormSq)(M::AbstractManifold, Xc::AbstractVector)
76+
function (hypoCalcFactor::CalcFactorNormSq)(::Type{OptimCalcConv}, M::AbstractManifold, Xc::AbstractVector)
7777
# hypoCalcFactor.manifold is the factor's manifold, not the variable's manifold that is needed here
7878
ϵ = getPointIdentity(M)
7979
X = get_vector(M, ϵ, SVector(Xc), DefaultOrthogonalBasis())
8080
p = exp(M, ϵ, X)
8181
return hypoCalcFactor(CalcConv, p)
8282
end
83+
(hypoCalcFactor::CalcFactorNormSq)(M::AbstractManifold, p) = hypoCalcFactor(OptimCalcConv, M, p)
8384

8485
struct ManoptCalcConv end
8586

@@ -117,10 +118,19 @@ function _solveLambdaNumeric(
117118
retraction_method = ExponentialRetraction()
118119
)
119120
return r
121+
elseif false
122+
r = gradient_descent(
123+
M,
124+
(M,x)->hypoCalcFactor(x),
125+
(M, x)-> factorGradient(hypoCalcFactor, M, x),
126+
u0;
127+
stepsize=ConstantStepsize(0.1),
128+
)
129+
return r
120130
end
121131

122132
r = Optim.optimize(
123-
x->hypoCalcFactor(M, x),
133+
x->hypoCalcFactor(OptimCalcConv, M, x),
124134
X0c,
125135
alg
126136
)
@@ -394,6 +404,8 @@ function (cf::CalcFactorNormSq)(::Type{CalcConv}, x)
394404
res = isnothing(cf.slack) ? res : res .- cf.slack
395405
return sum(x->x^2, res)
396406
end
407+
#default to conv
408+
(cf::CalcFactorNormSq)(x) = cf(CalcConv, x)
397409

398410
function _buildHypoCalcFactor(ccwl::CommonConvWrapper, smpid::Integer, _slack=nothing)
399411
# build a view to the decision variable memory

0 commit comments

Comments
 (0)