Skip to content

Commit ba5c851

Browse files
committed
general upgrades and more testing, wip
1 parent be2ab1f commit ba5c851

File tree

8 files changed

+204
-21
lines changed

8 files changed

+204
-21
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ version = "0.16.0"
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
99
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1010
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
11+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1112
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1213
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1314
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -40,6 +41,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4041
[compat]
4142
ApproxManifoldProducts = "0.1, 0.2"
4243
Combinatorics = "1.0"
44+
DataStructures = "0.16, 0.17, 0.18"
4345
DistributedFactorGraphs = "0.10.6"
4446
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24"
4547
DocStringExtensions = "0.8, 0.9, 0.10, 1"

src/ApproxConv.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ Future work:
1111
- improve handling of n and particleidx, especially considering future multithreading support
1212
1313
"""
14-
function approxConvOnElements!( ccwl::CommonConvWrapper{T},
15-
elements::Union{Vector{Int}, UnitRange{Int}}, ::Type{MultiThreaded} ) where {T <: AbstractRelative}
14+
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},
15+
CommonConvWrapper{Mixture{N_,F,S,T}}},
16+
elements::Union{Vector{Int}, UnitRange{Int}}, ::Type{MultiThreaded} ) where {N_,F<:AbstractRelative,S,T}
1617
#
1718
Threads.@threads for n in elements
1819
# ccwl.thrid_ = Threads.threadid()
@@ -36,7 +37,8 @@ Future work:
3637
- improve handling of n and particleidx, especially considering future multithreading support
3738
3839
"""
39-
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
40+
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},
41+
CommonConvWrapper{Mixture{N_,F,S,T}}},
4042
elements::Union{Vector{Int}, UnitRange{Int}}, ::Type{SingleThreaded}) where {N_,F<:AbstractRelative,S,T}
4143
#
4244
for n in elements
@@ -60,7 +62,8 @@ Future work:
6062
- improve handling of n and particleidx, especially considering future multithreading support
6163
6264
"""
63-
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}}, #CommonConvWrapper{T},
65+
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},
66+
CommonConvWrapper{Mixture{N_,F,S,T}}},
6467
elements::Union{Vector{Int}, UnitRange{Int}} ) where {N_,F<:AbstractRelative,S,T}
6568
#
6669
approxConvOnElements!(ccwl, elements, ccwl.threadmodel)
@@ -114,7 +117,8 @@ function prepareCommonConvWrapper!( F_::Type{<:AbstractRelative},
114117
end
115118

116119

117-
function prepareCommonConvWrapper!( ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
120+
function prepareCommonConvWrapper!( ccwl::Union{CommonConvWrapper{F},
121+
CommonConvWrapper{Mixture{N_,F,S,T}}},
118122
Xi::Vector{DFGVariable},
119123
solvefor::Symbol,
120124
N::Int;
@@ -204,7 +208,8 @@ end
204208
205209
Common function to compute across a single user defined multi-hypothesis ambiguity per factor. This function dispatches both `AbstractRelativeFactor` and `AbstractRelativeFactorMinimize` factors.
206210
"""
207-
function computeAcrossHypothesis!(ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
211+
function computeAcrossHypothesis!(ccwl::Union{CommonConvWrapper{F},
212+
CommonConvWrapper{Mixture{N_,F,S,T}}},
208213
allelements,
209214
activehypo,
210215
certainidx::Vector{Int},

src/Factors/Mixture.jl

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,41 @@ end
2222
Mixture(f::Type{F},
2323
z::NamedTuple{S,T},
2424
c::Distributions.DiscreteNonParametric ) where {F<:FunctorInferenceType, S, T} = Mixture{length(z),F,S,T}(f(LinearAlgebra.I), z, c, size( rand(z[1],1), 1), zeros(Int, 0))
25-
#
2625
Mixture(f::F,
2726
z::NamedTuple{S,T},
2827
c::Distributions.DiscreteNonParametric ) where {F<:FunctorInferenceType, S, T} = Mixture{length(z),F,S,T}(f, z, c, size( rand(z[1],1), 1), zeros(Int, 0))
29-
#
30-
Mixture(f::Union{F,Type{F}},z::NamedTuple{S,T}, c::AbstractVector{<:Real}) where {F<:FunctorInferenceType,S,T} = Mixture(f, z, Categorical([c...]) )
31-
Mixture(f::Union{F,Type{F}},z::NamedTuple{S,T}, c::NTuple{N,<:Real}) where {N,F<:FunctorInferenceType,S,T} = Mixture(f, z, [c...] )
32-
Mixture(f::Union{F,Type{F}},z::AbstractVector{<:SamplableBelief}, c::Union{<:Distributions.DiscreteNonParametric, <:AbstractVector{<:Real}, <:NTuple{N,<:Real}} ) where {F <: FunctorInferenceType, N} = Mixture(f,NamedTuple{_defaultNamesMixtures(length(z))}((z...,)), c)
33-
Mixture(f::Union{F,Type{F}},z::Tuple, c::Union{<:Distributions.DiscreteNonParametric, <:AbstractVector{<:Real}, <:NTuple{N,<:Real}} ) where {F<:FunctorInferenceType, N} = Mixture(f,NamedTuple{_defaultNamesMixtures(length(z))}(z), c)
28+
Mixture(f::Union{F,Type{F}},z::NamedTuple{S,T},
29+
c::AbstractVector{<:Real}) where {F<:FunctorInferenceType,S,T} = Mixture(f, z, Categorical([c...]) )
30+
Mixture(f::Union{F,Type{F}},
31+
z::NamedTuple{S,T},
32+
c::NTuple{N,<:Real}) where {N,F<:FunctorInferenceType,S,T} = Mixture(f, z, [c...] )
33+
Mixture(f::Union{F,Type{F}},
34+
z::Tuple,
35+
c::Union{<:Distributions.DiscreteNonParametric, <:AbstractVector{<:Real}, <:NTuple{N,<:Real}} ) where {F<:FunctorInferenceType, N} = Mixture(f,NamedTuple{_defaultNamesMixtures(length(z))}(z), c )
36+
Mixture(f::Union{F,Type{F}},
37+
z::AbstractVector{<:SamplableBelief},
38+
c::Union{<:Distributions.DiscreteNonParametric, <:AbstractVector{<:Real}, <:NTuple{N,<:Real}} ) where {F <: FunctorInferenceType, N} = Mixture(f,(z...,), c )
3439

3540

3641
function Base.resize!(mp::Mixture, s::Int)
3742
resize!(mp.labels, s)
3843
end
3944

4045
# TODO make in-place memory version
41-
function getSample(s::Mixture, N::Int=1)
46+
function getSample( s::Mixture{N_,F,S,T},
47+
N::Int=1,
48+
special...;
49+
kw... ) where {N_,F<:FunctorInferenceType,S,T}
50+
#
51+
# TODO consolidate #927, case if mechanics has a special sampler
52+
# TODO slight bit of waste in computation, but easiest way to ensure special tricks in s.mechanics::F are included
53+
## example case is old FluxModelsPose2Pose2 requiring velocity
54+
smplLambda = hasfield(typeof(s.mechanics), :specialSampler) ? ()->s.specialSampler(s.mechanics, N, special...; kw...)[1] : ()->getSample(s.mechanics, N)[1]
55+
smpls = smplLambda()
56+
# smpls = Array{Float64,2}(undef,s.dims,N)
4257
#out memory should be right size first
4358
(length(s.labels) != N) && resize!(s, N)
4459
s.labels .= rand(s.diversity, N)
45-
smpls = Array{Float64,2}(undef,s.dims,N)
4660
for i in 1:N
4761
mixComponent = s.components[s.labels[i]]
4862
smpls[:,i] = rand(mixComponent,1)

src/Factors/Sphere1D.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,18 @@ PriorSphere1( MvNormal([10; 10; pi/6.0], Matrix(Diagonal([0.1;0.1;0.05].^2))) )
4040
"""
4141
mutable struct PriorSphere1{T<: SamplableBelief} <: AbstractPrior
4242
Z::T
43-
PriorSphere1{T}() where T = new{T}()
44-
PriorSphere1{T}(x::T) where {T <: IncrementalInference.SamplableBelief} = new{T}(x)
43+
# PriorSphere1{T}() where T = new{T}()
44+
# PriorSphere1{T}(x::T) where {T <: IncrementalInference.SamplableBelief} = new{T}(x)
4545
end
46-
PriorSphere1(x::T) where {T <: IncrementalInference.SamplableBelief} = PriorSphere1{T}(x)
46+
47+
48+
# PriorSphere1(x::T) where {T <: IncrementalInference.SamplableBelief} = PriorSphere1{T}(x)
49+
PriorSphere1(::UniformScaling) = PriorSphere1(Normal())
4750
function PriorSphere1(mu::Array{Float64}, cov::Array{Float64,2}, W::Vector{Float64})
4851
@warn "PriorSphere1(mu,cov,W) is deprecated in favor of PriorSphere1(T(...)) -- use for example PriorSphere1(MvNormal(mu, cov))"
4952
PriorSphere1(MvNormal(mu[:], cov))
5053
end
54+
5155
function getSample(p2::PriorSphere1, N::Int=1)
5256
return (reshape(rand(p2.Z,N),:,N), )
5357
end

src/Flux/FluxModelsDistribution.jl

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33
@info "IncrementalInference is adding Flux related functionality."
44

55

6-
76
# the factor definitions
87
export FluxModelsDistribution
9-
# some utilities
8+
export MixtureFluxModels
109

1110

1211
# Required packages
1312
using .Flux
14-
13+
using DataStructures: OrderedDict
1514
using Random, Statistics
16-
# using DistributedFactorGraphs
15+
1716

1817
# import Base: convert
1918
import Random: rand
2019

20+
const _IIFListTypes = Union{<:AbstractVector, <:Tuple, <:NTuple, <:NamedTuple}
2121

2222

2323
function rand(nfb::FluxModelsDistribution,
@@ -55,6 +55,83 @@ FluxModelsDistribution( inDim::NTuple{ID,Int},
5555
serializeHollow::Bool=false ) where {ID,OD,P,D<:AbstractArray} = FluxModelsDistribution{ID,OD,P,D}(inDim, outDim, models, data, Ref(shuffle), Ref(serializeHollow) )
5656
#
5757

58+
# @deprecate
59+
60+
61+
62+
63+
64+
65+
"""
66+
$SIGNATURES
67+
68+
Helper function to construct `MixtureFluxModels` containing a `NamedTuple`, resulting in a
69+
`::Mixture` such that `(fluxnn=FluxNNModels, c1=>MvNormal, c2=>Uniform...)` and order sensitive
70+
`diversity=[0.7;0.2;0.1]`. The result is the mixture heavily favors `.fluxnn` and names
71+
`c1` and `c2` for two other components were auto generated.
72+
73+
Notes
74+
- The user can specify own component names if desired (see example).
75+
- `shuffle` is passed through to internal `FluxModelsDistribution` to command shuffling of NN models.
76+
- `shuffle` does not influence selection of components in the mixture.
77+
78+
Example:
79+
80+
```julia
81+
# some made up data
82+
data = randn(10)
83+
# Flux models
84+
models = [Flux.Chain(softmax, Dense(10,5,σ), Dense(5,1, tanh)) for i in 1:20]
85+
# mixture with user defined names (optional) -- could also just pass Vector or Tuple of components
86+
mix = MixtureFluxModels(PriorSphere1, models, (10,), data, (1,),
87+
(naiveNorm=Normal(),naiveUnif=Uniform()),
88+
[0.7; 0.2; 0.1],
89+
shuffle=false )
90+
#
91+
92+
# test by add to simple graph
93+
fg = initfg()
94+
addVariable!(fg, :testmix, Sphere1)
95+
addFactor!(fg, [:testmix;], mix)
96+
97+
# look at proposal distribution from the only factor on :testmix
98+
_,pts,__, = localProduct(fg, :testmix)
99+
```
100+
101+
Related
102+
103+
Mixture, FluxModelsDistribution
104+
"""
105+
function MixtureFluxModels( F_::FunctorInferenceType,
106+
nnModels::Vector{P},
107+
inDim::NTuple{ID,Int},
108+
data::D,
109+
outDim::NTuple{OD,Int},
110+
otherComp::_IIFListTypes,
111+
diversity::Union{<:AbstractVector, <:NTuple, <:DiscreteNonParametric};
112+
shuffle::Bool=true,
113+
serializeHollow::Bool=false ) where {P,ID,D,OD}
114+
#
115+
# must preserve order
116+
allComp = OrderedDict{Symbol, Any}()
117+
118+
# always add the Flux model first
119+
allComp[:fluxnn] = FluxModelsDistribution(inDim,outDim,nnModels,data,shuffle,serializeHollow)
120+
isNT = otherComp isa NamedTuple
121+
for idx in 1:length(otherComp)
122+
nm = isNT ? keys(otherComp)[idx] : Symbol("c$(idx+1)")
123+
allComp[nm] = otherComp[idx]
124+
end
125+
# convert to named tuple
126+
ntup = (;allComp...)
127+
128+
# construct all the internal objects
129+
return Mixture(F_, ntup, diversity)
130+
end
131+
132+
MixtureFluxModels(::Type{F}, w...;kw...) where F <: FunctorInferenceType = MixtureFluxModels(F(LinearAlgebra.I), w...;kw...)
133+
134+
58135

59136

60137

src/SerializingDistributions.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,28 @@
11

2+
3+
export PackedUniform
4+
5+
6+
mutable struct PackedUniform <: PackedSamplableBelief
7+
a::Float64
8+
b::Float64
9+
PackedSamplableTypeJSON::String
10+
end
11+
12+
function convert(::Union{Type{<:PackedSamplableBelief},Type{<:PackedUniform}},
13+
obj::Distributions.Uniform)
14+
#
15+
packed = PackedUniform(obj.a, obj.b,
16+
"Distributions.PackedUniform")
17+
#
18+
return JSON2.write(packed)
19+
end
20+
21+
22+
convert(::Type{<:SamplableBelief}, obj::PackedUniform) = return Uniform(obj.a, obj.b)
23+
24+
25+
226
# TODO stop-gap string storage of Distrubtion types, should be upgraded to more efficient storage
327
function normalfromstring(str::AbstractString)
428
meanstr = match(r"μ=[+-]?([0-9]*[.])?[0-9]+", str).match

test/testFluxModelsDistribution.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,46 @@ end
138138

139139

140140

141+
@testset "MixtureFluxModels testing" begin
142+
143+
# some made up data
144+
data = randn(10)
145+
# Flux models
146+
models = [Flux.Chain(softmax, Dense(10,5,σ), Dense(5,1, tanh)) for i in 1:20]
147+
# mixture with user defined names (optional) -- could also just pass Vector or Tuple of components
148+
mix = MixtureFluxModels(PriorSphere1, models, (10,), data, (1,),
149+
(naiveNorm=Normal(),naiveUnif=Uniform()),
150+
[0.7; 0.2; 0.1],
151+
shuffle=false )
152+
#
153+
154+
# test by add to simple graph
155+
fg = initfg()
156+
addVariable!(fg, :testmix, Sphere1)
157+
addFactor!(fg, [:testmix;], mix)
158+
159+
160+
pts = approxConv(fg, :testmixf1, :testmix);
161+
162+
# look at proposal distribution from the only factor on :testmix
163+
_,pts,__, = localProduct(fg, :testmix);
164+
165+
saveDFG("/tmp/fg_mfx", fg)
166+
167+
#
168+
fg_ = loadDFG("/tmp/fg_mfx")
169+
170+
171+
Base.rm("/tmp/fg_mfx.tar.gz")
172+
173+
solveTree!(fg_);
174+
175+
@test 10 < getBelief(fg_, :testmix) |> getPoints |> length
176+
177+
end
178+
179+
180+
181+
182+
141183
#

test/testgraphpackingconverters.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@ using DistributedFactorGraphs
66
using Test
77

88

9+
@testset "Serialization of SamplableBelief types" begin
10+
11+
td = Uniform()
12+
13+
ptd = convert(PackedSamplableBelief, td)
14+
utd = convert(SamplableBelief, td)
15+
16+
@test td.a - utd.a |> abs < 1e-10
17+
@test td.b - utd.b |> abs < 1e-10
18+
19+
20+
end
21+
22+
23+
924
dfg = initfg()
1025

1126
@testset "hard-coded test of PackedPrior to Prior" begin

0 commit comments

Comments
 (0)