Skip to content

Commit 23ee348

Browse files
committed
add array saveTree loadTree
1 parent d7b59fd commit 23ee348

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

src/JunctionTree.jl

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,7 +1505,9 @@ Related
15051505
15061506
IIF.loadTree, DFG.saveDFG, DFG.loadDFG, JLD2.@save, JLD2.@load
15071507
"""
1508-
function saveTree(treel::BayesTree, filepath=joinpath("/tmp","caesar","savetree.jld2"))
1508+
function saveTree(treel::BayesTree,
1509+
filepath=joinpath("/tmp","caesar","savetree.jld2") )
1510+
#
15091511
savetree = deepcopy(treel)
15101512
for i in 1:length(savetree.cliques)
15111513
if savetree.cliques[i].attributes["data"] isa BayesTreeNodeData
@@ -1517,6 +1519,20 @@ function saveTree(treel::BayesTree, filepath=joinpath("/tmp","caesar","savetree.
15171519
return filepath
15181520
end
15191521

1522+
function saveTree(treeArr::Vector{BayesTree},
1523+
filepath=joinpath("/tmp","caesar","savetrees.jld2") )
1524+
#
1525+
savetree = deepcopy(treeArr)
1526+
for savtre in savetree, i in 1:length(savtre.cliques)
1527+
if savtre.cliques[i].attributes["data"] isa BayesTreeNodeData
1528+
savtre.cliques[i].attributes["data"] = convert(PackedBayesTreeNodeData, savtre.cliques[i].attributes["data"])
1529+
end
1530+
end
1531+
1532+
JLD2.@save filepath savetree
1533+
return filepath
1534+
end
1535+
15201536
"""
15211537
$SIGNATURES
15221538
@@ -1534,9 +1550,17 @@ function loadTree(filepath=joinpath("/tmp","caesar","savetree.jld2"))
15341550
data = @load filepath savetree
15351551

15361552
# convert back to a type that which could not be serialized by JLD2
1537-
for i in 1:length(savetree.cliques)
1538-
if savetree.cliques[i].attributes["data"] isa PackedBayesTreeNodeData
1539-
savetree.cliques[i].attributes["data"] = convert(BayesTreeNodeData, savetree.cliques[i].attributes["data"])
1553+
if savetree isa Vector
1554+
for savtre in savetree, i in 1:length(savtre.cliques)
1555+
if savtre.cliques[i].attributes["data"] isa PackedBayesTreeNodeData
1556+
savtre.cliques[i].attributes["data"] = convert(BayesTreeNodeData, savtre.cliques[i].attributes["data"])
1557+
end
1558+
end
1559+
else
1560+
for i in 1:length(savetree.cliques)
1561+
if savetree.cliques[i].attributes["data"] isa PackedBayesTreeNodeData
1562+
savetree.cliques[i].attributes["data"] = convert(BayesTreeNodeData, savetree.cliques[i].attributes["data"])
1563+
end
15401564
end
15411565
end
15421566

test/testTreeSaveLoad.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,28 @@ using IncrementalInference
2626
end
2727

2828
end
29+
30+
31+
@testset "Test loading and saving of Bayes (Junction) tree" begin
32+
33+
fg = loadCanonicalFG_Kaess(graphinit=false)
34+
tree = wipeBuildNewTree!(fg)
35+
36+
# save and load tree as array
37+
filepath = saveTree([tree;deepcopy(tree)])
38+
trees = loadTree(filepath)
39+
40+
# perform a few spot checks to see that the trees are similar
41+
@test length(tree.cliques) == length(trees[1].cliques)
42+
@test getVariableOrder(tree) == getVariableOrder(trees[1])
43+
44+
for (clid,cl) in tree.cliques
45+
fsyms = getFrontals(cl)
46+
cl2 = getCliq(trees[1], fsyms[1])
47+
fsyms2 = getFrontals(cl2)
48+
@test fsyms == fsyms2
49+
@test getCliqSeparatorVarIds(cl) == getCliqSeparatorVarIds(cl2)
50+
@test typeof(cl) == typeof(cl2)
51+
end
52+
53+
end

0 commit comments

Comments
 (0)