Skip to content

Commit edfbefc

Browse files
authored
Merge pull request #961 from JuliaRobotics/feat/4Q20/fluxnn
add FluxModelsDistribution
2 parents 4097cda + 279b364 commit edfbefc

13 files changed

+341
-27
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.16.0"
66

77
[deps]
88
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
9+
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
910
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1011
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1112
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -39,7 +40,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
3940
[compat]
4041
ApproxManifoldProducts = "0.1, 0.2"
4142
Combinatorics = "1.0"
42-
DistributedFactorGraphs = "0.10.5"
43+
DistributedFactorGraphs = "0.10.6"
4344
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24"
4445
DocStringExtensions = "0.8, 0.9, 0.10, 1"
4546
FileIO = "1.0.2, 1.1, 1.2"
@@ -60,8 +61,10 @@ TimeZones = "1.3.1"
6061
julia = "1.4"
6162

6263
[extras]
64+
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
65+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
6366
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
6467
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6568

6669
[targets]
67-
test = ["Test"]
70+
test = ["BSON", "Flux", "Test"]

src/BeliefTypes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct NonparametricMessage <: MessageType end
2525
struct ParametricMessage <: MessageType end
2626

2727

28-
const SamplableBelief = Union{Distributions.Distribution, KernelDensityEstimate.BallTreeDensity, AliasingScalarSampler}
28+
const SamplableBelief = Union{Distributions.Distribution, KernelDensityEstimate.BallTreeDensity, AliasingScalarSampler, FluxModelsDistribution}
2929

3030
abstract type PackedSamplableBelief end
3131

src/Deprecated.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,8 @@ function initVariable!(fgl::AbstractDFG,
858858
end
859859

860860

861+
export uppA
862+
861863
#global pidx
862864
global pidx = 1
863865
global pidl = 1

src/FactorGraph.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,15 @@ end
558558

559559
# import IncrementalInference: prepgenericconvolution, convert
560560

561+
"""
562+
$SIGNATURES
563+
564+
Function to calculate measurement dimension from factor sampling.
565+
566+
Notes
567+
- Will not work in all situations, but good enough so far.
568+
- # TODO standardize
569+
"""
561570
function calcZDim(usrfnc::T, Xi::Vector{<:DFGVariable})::Int where {T <: FunctorInferenceType}
562571
# zdim = T != GenericMarginal ? size(getSample(usrfnc, 2)[1],1) : 0
563572
zdim = if T != GenericMarginal

src/Factors/DefaultPrior.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ mutable struct PackedPrior <: PackedInferenceType
5252
PackedPrior(z::AS) where {AS <: AbstractString} = new(z)
5353
end
5454
function convert(::Type{PackedPrior}, d::Prior)
55-
PackedPrior(string(d.Z))
55+
PackedPrior(convert(PackedSamplableBelief, d.Z))
5656
end
5757
function convert(::Type{Prior}, d::PackedPrior)
58-
Prior(extractdistribution(d.Z))
58+
Prior(convert(SamplableBelief, d.Z))
5959
end
6060

6161

src/Factors/Mixture.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ end
5353

5454
# should not be called in case of Prior
5555
(s::Mixture)( res::AbstractArray{<:Real},
56-
userdata::FactorMetadata,
56+
fmd::FactorMetadata,
5757
idx::Int,
5858
meas::Tuple,
59-
X... ) = s.mechanics(res, userdata, idx, meas, X...)
59+
X... ) = s.mechanics(res, fmd, idx, meas, X...)
6060
#
6161

6262

src/Flux/FluxModelsDistribution.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
2+
3+
@info "IncrementalInference is adding Flux related functionality."
4+
5+
6+
7+
# the factor definitions
8+
export FluxModelsDistribution
9+
# some utilities
10+
11+
12+
# Required packages
13+
using .Flux
14+
15+
using Random, Statistics
16+
# using DistributedFactorGraphs
17+
18+
# import Base: convert
19+
import Random: rand
20+
21+
22+
23+
function rand(nfb::FluxModelsDistribution,
24+
N::Int=1 )
25+
#
26+
27+
# number of predictors to choose from, and choose random subset
28+
numModels = length(nfb.models)
29+
allPreds = 1:numModels |> collect # 1:Npreds |> collect
30+
# TODO -- compensate when there arent enough prediction models
31+
if numModels < N
32+
reps = (N ÷ numModels) + 1
33+
allPreds = repeat(allPreds, reps )
34+
resize!(allPreds, N)
35+
end
36+
# samples for the order in which to use models, dont shuffle if N models
37+
# can suppress shuffle for NN training purposes
38+
1 < numModels && nfb.shuffle[] ? shuffle!(allPreds) : nothing
39+
40+
# generate the measurements
41+
meas = zeros(nfb.outputDim..., N)
42+
for i in 1:N
43+
meas[:,i] = (nfb.models[allPreds[i]])(nfb.data)
44+
end
45+
46+
return meas
47+
end
48+
49+
50+
FluxModelsDistribution( inDim::NTuple{ID,Int},
51+
outDim::NTuple{OD,Int},
52+
models::Vector{P},
53+
data::D,
54+
shuffle::Bool=true,
55+
serializeHollow::Bool=false ) where {ID,OD,P,D<:AbstractArray} = FluxModelsDistribution{ID,OD,P,D}(inDim, outDim, models, data, Ref(shuffle), Ref(serializeHollow) )
56+
#
57+
58+
59+
60+
61+
#

src/Flux/FluxModelsSerialization.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Serialization functions for Flux models that depend on BSON
2+
3+
@info "IncrementalInference is adding Flux/BSON serialization functionality."
4+
5+
export PackedFluxModelsDistribution
6+
7+
using .BSON
8+
using Base64
9+
10+
import Base: convert
11+
12+
13+
mutable struct PackedFluxModelsDistribution <: IIF.PackedSamplableBelief
14+
# shape of the input data
15+
inputDim::Vector{Int}
16+
# shape of the output data
17+
outputDim::Vector{Int}
18+
# actual Flux models (Base64 encoded binary)
19+
mimeTypeModel::String
20+
models::Vector{String}
21+
# the data used for prediction, must be <: AbstractArray
22+
mimeTypeData::String
23+
data::String
24+
# shuffle model predictions relative to particle index at each sampling
25+
shuffle::Bool
26+
# false for default serialization with model info, set true for separate storage of models
27+
serializeHollow::Bool
28+
# TODO remove requirement and standardize sampler API
29+
# specialSampler::Symbol
30+
# field name usage to direct the IIF serialization towards JSON method
31+
PackedSamplableTypeJSON::String
32+
end
33+
34+
function _serializeFluxModelBase64(model::Flux.Chain)
35+
io = IOBuffer()
36+
iob64 = Base64EncodePipe(io)
37+
BSON.@save iob64 model
38+
close(iob64)
39+
return String(take!(io))
40+
end
41+
42+
function _deserializeFluxModelBase64(smodel::AbstractString)
43+
iob64 = PipeBuffer(base64decode(smodel))
44+
BSON.@load iob64 model
45+
close(iob64)
46+
return model
47+
end
48+
49+
function _serializeFluxDataBase64(data::AbstractArray)
50+
io = IOBuffer()
51+
iob64 = Base64EncodePipe(io)
52+
BSON.@save iob64 data
53+
close(iob64)
54+
return String(take!(io))
55+
end
56+
57+
function _deserializeFluxDataBase64(sdata::AbstractString)
58+
iob64 = PipeBuffer(base64decode(sdata))
59+
BSON.@load iob64 data
60+
close(iob64)
61+
return data
62+
end
63+
64+
65+
function convert( ::Union{Type{<:PackedSamplableBelief},Type{<:PackedFluxModelsDistribution}},
66+
obj::FluxModelsDistribution)
67+
#
68+
69+
# and the specialSampler function -- likely to be deprecated
70+
# specialSampler = Symbol(obj.specialSampler)
71+
# fields to persist
72+
inputDim = collect(obj.inputDim)
73+
outputDim = collect(obj.outputDim)
74+
models = Vector{String}()
75+
# store all models as Base64 Strings (using BSON)
76+
if !obj.serializeHollow[]
77+
resize!(models, length(obj.models))
78+
# serialize the Vector of Flux models (each one individually)
79+
models .= _serializeFluxModelBase64.(obj.models)
80+
# also store data as Base64 String, using BSON
81+
sdata = _serializeFluxDataBase64(obj.data)
82+
mimeTypeData = "application/bson/octet-stream/base64"
83+
else
84+
# store one just model to preserve the type (allows resizing on immutable Ref after deserialize)
85+
push!(models,_serializeFluxModelBase64(obj.models[1]))
86+
# at least capture the type of how the data looks for future deserialization
87+
sdata = string(typeof(obj.data))
88+
mimeTypeData = "application/text"
89+
end
90+
mimeTypeModel = "application/bson/octet-stream/base64"
91+
92+
# and build the JSON-able object
93+
packed = PackedFluxModelsDistribution(inputDim,
94+
outputDim,
95+
mimeTypeModel,
96+
models,
97+
mimeTypeData,
98+
sdata,
99+
obj.shuffle[],
100+
obj.serializeHollow[],
101+
# specialSampler,
102+
"IncrementalInference.PackedFluxModelsDistribution" )
103+
#
104+
return JSON2.write(packed)
105+
end
106+
107+
108+
109+
function convert( ::Union{Type{<:SamplableBelief},Type{FluxModelsDistribution}},
110+
obj::PackedFluxModelsDistribution)
111+
#
112+
113+
obj.serializeHollow && @warn("Deserialization of FluxModelsDistribution.serializationHollow=true is not yet well developed, please open issues at IncrementalInference.jl accordingly.")
114+
115+
# specialSampler likely to be deprecated
116+
# specialSampler = getfield(Main, obj.specialSampler)
117+
118+
# deserialize
119+
# @assert obj.mimeTypeModel == "application/bson/octet-stream/base64"
120+
models = _deserializeFluxModelBase64.(obj.models)
121+
122+
# @assert obj.mimeTypeData == "application/bson/octet-stream/base64"
123+
data = !obj.serializeHollow ? _deserializeFluxDataBase64.(obj.data) : zeros(0)
124+
125+
return FluxModelsDistribution((obj.inputDim...,),
126+
(obj.outputDim...,),
127+
models,
128+
data,
129+
obj.shuffle,
130+
obj.serializeHollow )
131+
end
132+
133+
134+
#

src/Flux/entities.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# entities immediately available as private members in IIF, but require Flux for actual use
2+
3+
4+
struct FluxModelsDistribution{ID,OD,P,D<:AbstractArray}
5+
# shape of the input data
6+
inputDim::NTuple{ID,Int}
7+
# shape of the output data
8+
outputDim::NTuple{OD,Int}
9+
# actual Flux models
10+
models::Vector{P}
11+
# the data used for prediction, must be <: AbstractArray
12+
data::D
13+
# shuffle model predictions relative to particle index at each sampling
14+
shuffle::Base.RefValue{Bool}
15+
# false for default serialization with model info, set true for separate storage of models
16+
serializeHollow::Base.RefValue{Bool}
17+
# # TODO remove requirement and standardize sampler API
18+
# specialSampler::Function
19+
end

src/IncrementalInference.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ export AbstractDFG,
8484
listDataEntries,
8585
FolderStore,
8686
addBlobStore!,
87+
addData!,
8788
getData,
8889
DFGVariable,
8990
DFGVariableSummary,
@@ -96,6 +97,9 @@ export AbstractDFG,
9697
export FunctorInferenceType, PackedInferenceType
9798
export AbstractPrior, AbstractRelativeFactor, AbstractRelativeFactorMinimize
9899

100+
# not sure if this is necessary
101+
export convert
102+
99103
export *,
100104
notifyCSMCondition,
101105
CSMHistory,
@@ -367,10 +371,6 @@ export *,
367371
GenericMarginal,
368372
PackedGenericMarginal,
369373

370-
uppA,
371-
convert,
372-
extractdistribution,
373-
374374
# factor graph operating system utils (fgos)
375375
saveTree,
376376
loadTree,
@@ -422,6 +422,7 @@ getFactorOperationalMemoryType(dfg::SolverParams) = CommonConvWrapper
422422

423423
include("AliasScalarSampling.jl")
424424
include("CliqueTypes.jl")
425+
include("Flux/entities.jl")
425426
include("BeliefTypes.jl")
426427
include("JunctionTreeTypes.jl")
427428
include("FactorGraph.jl")
@@ -492,9 +493,15 @@ include("Deprecated.jl")
492493

493494
exportimg(pl) = error("Please do `using Gadfly` before IncrementalInference is used to allow image export.")
494495
function __init__()
495-
@require InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" include("RequireInteractiveUtils.jl")
496+
@require InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" include("RequireInteractiveUtils.jl")
497+
498+
@require Gadfly="c91e804a-d5a3-530f-b6f0-dfbca275c004" include("EmbeddedPlottingUtils.jl")
496499

497-
@require Gadfly="c91e804a-d5a3-530f-b6f0-dfbca275c004" include("EmbeddedPlottingUtils.jl")
500+
# combining neural networks natively into the non-Gaussian factor graph object
501+
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
502+
include("Flux/FluxModelsDistribution.jl")
503+
@require BSON="fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" include("Flux/FluxModelsSerialization.jl")
504+
end
498505
end
499506

500507
# Old code that might be used again

0 commit comments

Comments
 (0)