Skip to content

Commit 803d264

Browse files
authored
Merge pull request #963 from JuliaRobotics/test/4Q20/betterfluxmix
improve IIF testing
2 parents edfbefc + 5a6ffe5 commit 803d264

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

test/testFluxModelsDistribution.jl

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ end
3232

3333

3434

35-
3635
@testset "FluxModelsDistribution serialization" begin
3736

3837

@@ -63,8 +62,80 @@ ff1.Z.shuffle[] = true
6362

6463
solveTree!(fg_);
6564

65+
# remove the testing file
66+
Base.rm("/tmp/fg_test_flux.tar.gz")
67+
68+
end
69+
70+
71+
72+
73+
@testset "FluxModelsDistribution as Mixture with relative factor" begin
74+
75+
76+
mdls = [Chain(Dense(10,50, relu),Dense(50,20),softmax, Dense(20,1, tanh)) for i in 1:50];
77+
fxd = FluxModelsDistribution((10,),(1,),mdls,rand(10), false, false)
78+
6679

80+
fg = initfg()
81+
82+
addVariable!(fg, :x0, ContinuousEuclid{1})
83+
addVariable!(fg, :x1, ContinuousEuclid{1})
84+
85+
# a prior
86+
pr = Prior(Normal())
87+
addFactor!(fg, [:x0;], pr)
88+
89+
# a relative mixture network
90+
mfx = Mixture(LinearRelative, (naive=Normal(10, 10), nn=fxd), [0.5;0.5])
91+
addFactor!(fg, [:x0;:x1], mfx)
92+
93+
# and test overall serialization before solving
94+
saveDFG("/tmp/fg_test_flux", fg)
95+
96+
# solve existing fg
97+
solveTree!(fg)
98+
99+
# prior should pin x0 pretty well
100+
@test 80 < sum(-3 .< (getBelief(fg, :x0) |> getPoints) .< 3)
101+
# at least some points should land according to the naive model
102+
@test 5 < sum(5 .< (getBelief(fg, :x1) |> getPoints) .< 15)
103+
104+
105+
# will predict from existing fg
106+
f1 = getFactorType(fg, :x0x1f1)
107+
predictions = map(f->f(f1.components.nn.data), f1.components.nn.models)
108+
109+
110+
# unpack into new fg_
111+
fg_ = loadDFG("/tmp/fg_test_flux")
112+
113+
# same predictions with deserialized object
114+
f1_ = getFactorType(fg_, :x0x1f1)
115+
predictions_ = map(f->f(f1_.components.nn.data), f1_.components.nn.models)
116+
117+
# check that all predictions line up
118+
@show norm(predictions - predictions_)
119+
@test norm(predictions - predictions_) < 1e-6
120+
121+
f1_.components.nn.shuffle[] = true
122+
123+
# test solving of the new object
124+
solveTree!(fg_);
125+
126+
127+
# prior should pin x0 pretty well
128+
@test 80 < sum(-3 .< (getBelief(fg_, :x0) |> getPoints) .< 3)
129+
# at least some points should land according to the naive model
130+
@test 5 < sum(5 .< (getBelief(fg_, :x1) |> getPoints) .< 15)
131+
132+
133+
# remove the testing file
134+
Base.rm("/tmp/fg_test_flux.tar.gz")
67135

68136
end
69137

138+
139+
140+
70141
#

0 commit comments

Comments
 (0)