Skip to content

Commit 7dc8f3c

Browse files
authored
Updates for DFG factor refactoring (#1870)
* Updates for DFG factor refactoring. * fix tests for removal of GFND * rather use getState -> getFactorState * bump to v0.36 and use copyto * getSolverData -> getVariableState
1 parent 43d7e25 commit 7dc8f3c

36 files changed

+262
-226
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "IncrementalInference"
22
uuid = "904591bb-b899-562f-9e6f-b8df64c7d480"
33
keywords = ["MM-iSAMv2", "Bayes tree", "junction tree", "Bayes network", "variable elimination", "graphical models", "SLAM", "inference", "sum-product", "belief-propagation"]
44
desc = "Implements the Multimodal-iSAMv2 algorithm."
5-
version = "0.35.6"
5+
version = "0.36.0"
66

77
[deps]
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
@@ -73,7 +73,7 @@ Combinatorics = "1.0"
7373
DataStructures = "0.16, 0.17, 0.18"
7474
DelimitedFiles = "1"
7575
DifferentialEquations = "7"
76-
DistributedFactorGraphs = "0.25, 0.26"
76+
DistributedFactorGraphs = "0.26, 0.27"
7777
Distributions = "0.24, 0.25"
7878
DocStringExtensions = "0.8, 0.9"
7979
FileIO = "1"

ext/FluxModelsSerialization.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function _deserializeFluxDataBase64(sdata::AbstractString)
3333
return data
3434
end
3535

36-
function packDistribution(obj::FluxModelsDistribution)
36+
function DFG.packDistribution(obj::FluxModelsDistribution)
3737
#
3838

3939
# and the specialSampler function -- likely to be deprecated
@@ -75,7 +75,7 @@ function packDistribution(obj::FluxModelsDistribution)
7575
#
7676
end
7777

78-
function unpackDistribution(obj::PackedFluxModelsDistribution)
78+
function DFG.unpackDistribution(obj::PackedFluxModelsDistribution)
7979
#
8080
obj.serializeHollow && @warn(
8181
"Deserialization of FluxModelsDistribution.serializationHollow=true is not yet well developed, please open issues at IncrementalInference.jl accordingly."
@@ -104,15 +104,15 @@ function Base.convert(
104104
)
105105
#
106106
# convert to packed type first
107-
return packDistribution(obj)
107+
return DFG.packDistribution(obj)
108108
end
109109

110110
function convert(
111111
::Union{Type{<:SamplableBelief}, Type{<:FluxModelsDistribution}},
112112
obj::PackedFluxModelsDistribution,
113113
)
114114
#
115-
return unpackDistribution(obj)
115+
return DFG.unpackDistribution(obj)
116116
end
117117

118118
#

src/CliqueStateMachine/services/CliqueStateMachine.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ function initUp_StateMachine(csmc::CliqStateMachineContainer)
362362
if init_for_differential
363363
frontal_vars = getVariable.(csmc.cliqSubFg, getCliqFrontalVarIds(csmc.cliq))
364364
filter!(!isInitialized, frontal_vars)
365-
foreach(fvar -> getSolverData(fvar, csmc.solveKey).initialized = true, frontal_vars)
365+
foreach(fvar -> getVariableState(fvar, csmc.solveKey).initialized = true, frontal_vars)
366366
logCSM(
367367
csmc,
368368
"CSM-2b init_for_differential: ";
@@ -404,7 +404,7 @@ function initUp_StateMachine(csmc::CliqStateMachineContainer)
404404
## FIXME init to whatever is in frontals
405405
# set frontals init back to false
406406
if init_for_differential #experimental_sommer_init_to_whatever_is_in_frontals
407-
foreach(fvar -> getSolverData(fvar, csmc.solveKey).initialized = false, frontal_vars)
407+
foreach(fvar -> getVariableState(fvar, csmc.solveKey).initialized = false, frontal_vars)
408408
if someInit
409409
solveStatus = UPSOLVED
410410
end

src/ExportAPI.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ export CSMHistory,
277277
saveDFG,
278278
loadDFG!,
279279
loadDFG,
280-
rebuildFactorMetadata!,
280+
rebuildFactorCache!,
281281
getCliqVarSolveOrderUp,
282282
getFactorsAmongVariablesOnly,
283283
setfreeze!,
@@ -287,7 +287,6 @@ export CSMHistory,
287287

288288
# some utils
289289
compare,
290-
compareAllSpecial,
291290
getMeasurements,
292291
findFactorsBetweenFrom,
293292
addDownVariableFactors!,

src/Factors/Circular.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ Related
5151
5252
[`Circular`](@ref), [`Prior`](@ref), [`PartialPrior`](@ref)
5353
"""
54-
mutable struct PriorCircular{T <: SamplableBelief} <: AbstractPrior
55-
Z::T
56-
end
54+
DFG.@defFactorType PriorCircular AbstractPrior Manifolds.RealCircleGroup()
55+
56+
# mutable struct PriorCircular{T <: SamplableBelief} <: AbstractPrior
57+
# Z::T
58+
# end
5759

5860
PriorCircular(::UniformScaling) = PriorCircular(Normal())
5961

60-
DFG.getManifold(::PriorCircular) = RealCircleGroup()
62+
# DFG.getManifold(::PriorCircular) = RealCircleGroup()
6163

6264
function getSample(cf::CalcFactor{<:PriorCircular})
6365
# FIXME workaround for issue #TBD with manifolds CircularGroup,
@@ -77,21 +79,22 @@ function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{PriorCircula
7779
return Manifolds.RealCircleGroup()
7880
end
7981

80-
"""
81-
$(TYPEDEF)
82-
83-
Serialized object for storing PriorCircular.
84-
"""
85-
Base.@kwdef struct PackedPriorCircular <: AbstractPackedFactor
86-
Z::PackedSamplableBelief
87-
end
88-
function convert(::Type{PackedPriorCircular}, d::PriorCircular)
89-
return PackedPriorCircular(convert(PackedSamplableBelief, d.Z))
90-
end
91-
function convert(::Type{PriorCircular}, d::PackedPriorCircular)
92-
distr = convert(SamplableBelief, d.Z)
93-
return PriorCircular{typeof(distr)}(distr)
94-
end
82+
# """
83+
# $(TYPEDEF)
84+
85+
# Serialized object for storing PriorCircular.
86+
# """
87+
# Base.@kwdef struct PackedPriorCircular <: AbstractPackedFactor
88+
# Z::PackedSamplableBelief
89+
# end
90+
91+
# function convert(::Type{PackedPriorCircular}, d::PriorCircular)
92+
# return PackedPriorCircular(convert(PackedSamplableBelief, d.Z))
93+
# end
94+
# function convert(::Type{PriorCircular}, d::PackedPriorCircular)
95+
# distr = convert(SamplableBelief, d.Z)
96+
# return PriorCircular{typeof(distr)}(distr)
97+
# end
9598

9699
# --------------------------------------------
97100

src/Factors/DefaultPrior.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,11 @@ function convert(::Type{Prior}, d::PackedPrior)
3333
return Prior(convert(SamplableBelief, d.Z))
3434
end
3535

36+
function DFG.pack(d::Prior)
37+
return PackedPrior(DFG.packDistribution(d.Z))
38+
end
39+
function DFG.unpack(d::PackedPrior)
40+
return Prior(DFG.unpackDistribution(d.Z))
41+
end
42+
3643
#

src/Factors/LinearRelative.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,11 @@ end
6868
function convert(::Type{LinearRelative}, d::PackedLinearRelative)
6969
return LinearRelative(convert(SamplableBelief, d.Z))
7070
end
71+
function DFG.pack(d::LinearRelative)
72+
return PackedLinearRelative(DFG.packDistribution(d.Z))
73+
end
74+
function DFG.unpack(d::PackedLinearRelative)
75+
return LinearRelative(DFG.unpackDistribution(d.Z))
76+
end
7177

7278
#

src/Factors/Mixture.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,13 @@ function convert(::Type{<:PackedMixture}, obj::Mixture{N, F, S, T}) where {N, F,
179179
# FIXME ON FIRE, likely to be difficult for non-standard "Samplable" types -- e.g. Flux models in RoME
180180
push!(allcomp, dtr_)
181181
end
182-
pm = DFG.convertPackedType(obj.mechanics)
183-
pm_ = convert(pm, obj.mechanics)
184-
sT = string(typeof(pm_))
182+
if hasmethod(pack, (typeof(obj.mechanics),))
183+
pm = pack(obj.mechanics)
184+
else
185+
@warn("No pack method for mechanics type $(typeof(obj.mechanics)), using deprecated convert instead.")
186+
pm = convert(DFG.convertPackedType(obj.mechanics), obj.mechanics)
187+
end
188+
sT = string(typeof(pm))
185189
dvst = convert(PackedSamplableBelief, obj.diversity)
186190
return PackedMixture(N, sT, string.(collect(S)), allcomp, dvst)
187191
end

src/IncrementalInference.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@ import ApproxManifoldProducts: getBW
6868
import ApproxManifoldProducts: mmd
6969
import ApproxManifoldProducts: isPartial
7070
import ApproxManifoldProducts: _update!
71-
import DistributedFactorGraphs: reconstFactorData
7271
import DistributedFactorGraphs: addVariable!, addFactor!, ls, lsf, isInitialized
73-
import DistributedFactorGraphs: compare, compareAllSpecial
74-
import DistributedFactorGraphs: rebuildFactorMetadata!
72+
import DistributedFactorGraphs: compare
73+
import DistributedFactorGraphs: rebuildFactorCache!
7574
import DistributedFactorGraphs: getDimension, getManifold, getPointType, getPointIdentity
7675
import DistributedFactorGraphs: getPPE, getPPEDict
7776
import DistributedFactorGraphs: getFactorOperationalMemoryType
@@ -80,6 +79,7 @@ import DistributedFactorGraphs: getVariableType
8079
import DistributedFactorGraphs: AbstractPointParametricEst, loadDFG
8180
import DistributedFactorGraphs: getFactorType
8281
import DistributedFactorGraphs: solveGraph!, solveGraphParametric!
82+
import DistributedFactorGraphs: packDistribution, unpackDistribution
8383

8484
# will be deprecated in IIF
8585
import DistributedFactorGraphs: isSolvable

src/Serialization/services/DispatchPackedConversions.jl

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
function convert(
66
::Type{PackedFunctionNodeData{P}},
77
d::FunctionNodeData{T},
8-
) where {P <: AbstractPackedFactor, T <: FactorOperationalMemory}
8+
) where {P <: AbstractPackedFactor, T <: FactorSolverCache}
9+
error("TODO remove. PackedFunctionNodeData is obsolete")
910
return PackedFunctionNodeData(
1011
d.eliminated,
1112
d.potentialused,
@@ -22,12 +23,15 @@ end
2223
## unpack converters------------------------------------------------------------
2324

2425
# see #1424
25-
function reconstFactorData(
26+
#TODO Consolidate: this looks alot like `getDefaultFactorData`
27+
function DFG.reconstFactorData(
2628
dfg::AbstractDFG,
2729
varOrder::AbstractVector{Symbol},
2830
::Type{<:GenericFunctionNodeData{<:CommonConvWrapper{F}}},
2931
packed::GenericFunctionNodeData{<:AbstractPackedFactor},
3032
) where {F <: AbstractFactor}
33+
34+
error("TODO remove. Obsolete: use `DFG.rebuildFactorCache!` and getDefaultFactorData instead.")
3135
#
3236
# TODO store threadmodel=MutliThreaded,SingleThreaded in persistence layer
3337
usrfnc = convert(F, packed.fnc)
@@ -89,7 +93,7 @@ Dev Notes:
8993
- TODO: We should only really do this in-memory if we can by without it (review this).
9094
- TODO: needs testing
9195
"""
92-
function rebuildFactorMetadata!(
96+
function DFG.rebuildFactorCache!(
9397
dfg::AbstractDFG{SolverParams},
9498
factor::DFGFactor,
9599
neighbors = map(vId -> getVariable(dfg, vId), listNeighbors(dfg, factor));
@@ -99,50 +103,54 @@ function rebuildFactorMetadata!(
99103
# Set up the neighbor data
100104

101105
# Rebuilding the CCW
102-
fsd = getSolverData(factor)
103-
fnd_new = getDefaultFactorData(
106+
state = DFG.getFactorState(factor)
107+
state, solvercache = getDefaultFactorData(
104108
dfg,
105109
neighbors,
106-
getFactorType(factor);
107-
multihypo = fsd.multihypo,
108-
nullhypo = fsd.nullhypo,
109-
# special inflation override
110-
inflation = fsd.inflation,
111-
eliminated = fsd.eliminated,
112-
potentialused = fsd.potentialused,
113-
edgeIDs = fsd.edgeIDs,
114-
solveInProgress = fsd.solveInProgress,
110+
DFG.getObservation(factor);
111+
multihypo = state.multihypo,
112+
nullhypo = state.nullhypo,
113+
# special inflation override
114+
inflation = state.inflation,
115+
eliminated = state.eliminated,
116+
potentialused = state.potentialused,
117+
solveInProgress = state.solveInProgress,
115118
_blockRecursion=_blockRecursionGradients
116119
)
117120
#
118-
119-
factor_ = if typeof(fnd_new) != typeof(getSolverData(factor))
120-
# must change the type of factor solver data FND{CCW{...}}
121-
# create a new factor
122-
factor__ = DFGFactor(
123-
getLabel(factor),
124-
getTimestamp(factor),
125-
factor.nstime,
126-
getTags(factor),
127-
fnd_new,
128-
getSolvable(factor),
129-
Tuple(getVariableOrder(factor)),
130-
)
131-
#
132-
133-
# replace old factor in dfg with a new one
134-
deleteFactor!(dfg, factor; suppressGetFactor = true)
135-
addFactor!(dfg, factor__)
136-
137-
factor__
138-
else
139-
setSolverData!(factor, fnd_new)
140-
# We're not updating here because we don't want
141-
# to solve cloud in loop, we want to make sure this flow works:
142-
# Pull big cloud graph into local -> solve local -> push back into cloud.
143-
# updateFactor!(dfg, factor)
144-
factor
145-
end
121+
DFG.setCache!(factor, solvercache)
122+
return factor
123+
124+
# factor_ = if typeof(solvercache) != typeof(DFG.getCache(factor))
125+
# # must change the type of factor solver data FND{CCW{...}}
126+
# # create a new factor
127+
# factor__ = FactorCompute(
128+
# getLabel(factor),
129+
# Tuple(getVariableOrder(factor)),
130+
# DFG.getObservation(factor),
131+
# state,
132+
# solvercache;
133+
# timestamp = getTimestamp(factor),
134+
# nstime = factor.nstime,
135+
# tags = getTags(factor),
136+
# solvable = getSolvable(factor),
137+
# )
138+
# #
139+
140+
# # replace old factor in dfg with a new one
141+
# deleteFactor!(dfg, factor; suppressGetFactor = true)
142+
# addFactor!(dfg, factor__)
143+
144+
# factor__
145+
# else
146+
# setSolverData!(factor, new_solverData)
147+
# DFG.setCache!(factor, solvercache)
148+
# # We're not updating here because we don't want
149+
# # to solve cloud in loop, we want to make sure this flow works:
150+
# # Pull big cloud graph into local -> solve local -> push back into cloud.
151+
# # updateFactor!(dfg, factor)
152+
# factor
153+
# end
146154

147155
#... Copying neighbor data into the factor?
148156
# JT TODO it looks like this is already updated in getDefaultFactorData -> _createCCW
@@ -151,7 +159,7 @@ function rebuildFactorMetadata!(
151159
# ccw_new.fnc.cpt[i].factormetadata.variableuserdata = deepcopy(neighborUserData)
152160
# end
153161

154-
return factor_
162+
# return factor_
155163
end
156164

157165
## =================================================================

src/Serialization/services/SerializingDistributions.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,6 @@ function convert(::Type{<:PackedSamplableBelief}, obj::StringThemSamplableBelief
6969
end
7070
convert(::Type{<:SamplableBelief}, obj::PackedSamplableBelief) = unpackDistribution(obj)
7171

72-
function convert(::Type{<:PackedSamplableBelief}, nt::Union{NamedTuple, JSON3.Object})
73-
distrType = DFG.getTypeFromSerializationModule(nt._type)
74-
return distrType(; nt...)
75-
end
76-
7772
##===================================================================================
7873

7974
# FIXME ON FIRE, must deprecate nested JSON written fields in all serialization

src/entities/BeliefTypes.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ abstract type MessageType end
1616
struct NonparametricMessage <: MessageType end
1717
struct ParametricMessage <: MessageType end
1818

19-
abstract type PackedSamplableBelief end
20-
21-
StructTypes.StructType(::Type{<:PackedSamplableBelief}) = StructTypes.UnorderedStruct()
19+
using DistributedFactorGraphs: PackedSamplableBelief
2220

2321
const SamplableBelief = Union{
2422
<:Distributions.Distribution,
@@ -89,7 +87,7 @@ function TreeBelief(vnd::VariableNodeData{T}, solvDim::Real = 0) where {T}
8987
end
9088

9189
function TreeBelief(vari::DFGVariable, solveKey::Symbol = :default; solvableDim::Real = 0)
92-
return TreeBelief(getSolverData(vari, solveKey), solvableDim)
90+
return TreeBelief(getVariableState(vari, solveKey), solvableDim)
9391
end
9492
#
9593

src/entities/FactorGradients.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Related
5959
[`calcFactorResidualTemporary`](@ref), [`_buildGraphByFactorAndTypes`](@ref)
6060
"""
6161
mutable struct FactorGradientsCached!{F <: AbstractRelative, S, M, P, G, L}
62-
dfgfct::DFGFactor{<:CommonConvWrapper{F}}
62+
dfgfct::DFGFactor{F}
6363
# cached jacobian matrix of gradients
6464
cached_gradients::Matrix{Float64}
6565
# likely <:AbstractVector, while CalcFactor residuals are vectors in Rn but could change to Tangent vectors

src/entities/FactorOperationalMemory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Base.@kwdef struct CommonConvWrapper{
2727
HR <: HypoRecipeCompute,
2828
MT,
2929
G
30-
} <: FactorOperationalMemory
30+
} <: FactorSolverCache
3131
# Basic factor topological info
3232
""" Values consistent across all threads during approx convolution """
3333
usrfnc!::T # user factor / function

0 commit comments

Comments
 (0)