Skip to content

Commit 76e16c1

Browse files
committed
last updates for MixtureFluxModels
1 parent f644be9 commit 76e16c1

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

src/Flux/FluxModelsDistribution.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ FluxModelsDistribution( inDim::NTuple{ID,Int},
5555
serializeHollow::Bool=false ) where {ID,OD,P,D<:AbstractArray} = FluxModelsDistribution{ID,OD,P,D}(inDim, outDim, models, data, Ref(shuffle), Ref(serializeHollow) )
5656
#
5757

58-
# @deprecate
58+
FluxModelsDistribution( models::Vector{P},
59+
inDim::NTuple{ID,Int},
60+
data::D,
61+
outDim::NTuple{OD,Int};
62+
shuffle::Bool=true,
63+
serializeHollow::Bool=false ) where {ID,OD,P,D<:AbstractArray} = FluxModelsDistribution{ID,OD,P,D}(inDim, outDim, models, data, Ref(shuffle), Ref(serializeHollow) )
64+
#
65+
66+
5967

6068

6169

@@ -110,13 +118,19 @@ function MixtureFluxModels( F_::FunctorInferenceType,
110118
otherComp::_IIFListTypes,
111119
diversity::Union{<:AbstractVector, <:NTuple, <:DiscreteNonParametric};
112120
shuffle::Bool=true,
113-
serializeHollow::Bool=false ) where {P,ID,D,OD}
121+
serializeHollow::Bool=false ) where {P,ID,D<:AbstractArray,OD}
114122
#
115123
# must preserve order
116124
allComp = OrderedDict{Symbol, Any}()
117125

118126
# always add the Flux model first
119-
allComp[:fluxnn] = FluxModelsDistribution(inDim,outDim,nnModels,data,shuffle,serializeHollow)
127+
allComp[:fluxnn] = FluxModelsDistribution(nnModels,
128+
inDim,
129+
data,
130+
outDim,
131+
shuffle=shuffle,
132+
serializeHollow=serializeHollow)
133+
#
120134
isNT = otherComp isa NamedTuple
121135
for idx in 1:length(otherComp)
122136
nm = isNT ? keys(otherComp)[idx] : Symbol("c$(idx+1)")
@@ -129,7 +143,9 @@ function MixtureFluxModels( F_::FunctorInferenceType,
129143
return Mixture(F_, ntup, diversity)
130144
end
131145

132-
MixtureFluxModels(::Type{F}, w...;kw...) where F <: FunctorInferenceType = MixtureFluxModels(F(LinearAlgebra.I), w...;kw...)
146+
MixtureFluxModels(::Type{F},
147+
w...;
148+
kw...) where F <: FunctorInferenceType = MixtureFluxModels(F(LinearAlgebra.I),w...;kw...)
133149

134150

135151

src/Flux/FluxModelsSerialization.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,12 @@ function convert( ::Union{Type{<:SamplableBelief},Type{FluxModelsDistribution}},
122122
# @assert obj.mimeTypeData == "application/bson/octet-stream/base64"
123123
data = !obj.serializeHollow ? _deserializeFluxDataBase64.(obj.data) : zeros(0)
124124

125-
return FluxModelsDistribution((obj.inputDim...,),
126-
(obj.outputDim...,),
127-
models,
125+
return FluxModelsDistribution(models,
126+
(obj.inputDim...,),
128127
data,
129-
obj.shuffle,
130-
obj.serializeHollow )
128+
(obj.outputDim...,),
129+
shuffle=obj.shuffle,
130+
serializeHollow=obj.serializeHollow )
131131
end
132132

133133

0 commit comments

Comments
 (0)