Skip to content

Commit 460f19f

Browse files
committed
Start spllitting out IIFTypes
1 parent 7dc8f3c commit 460f19f

File tree

11 files changed

+244
-208
lines changed

11 files changed

+244
-208
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name = "IncrementalInferenceTypes"
2+
uuid = "9808408f-4dbc-47e4-913c-6068b950e289"
3+
authors = ["Johannes Terblanche <Affie@users.noreply.github.com>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
8+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
10+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
11+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
14+
[compat]
15+
DistributedFactorGraphs = "0.27.0"
16+
Distributions = "0.25.120"
17+
DocStringExtensions = "0.9.5"
18+
Manifolds = "=0.10.16"
19+
RecursiveArrayTools = "3.33.0"
20+
StaticArrays = "1.9.13"
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module IncrementalInferenceTypes
2+
3+
using DistributedFactorGraphs
4+
using DocStringExtensions
5+
using Manifolds
6+
using Distributions
7+
using StaticArrays
8+
# using RecursiveArrayTools
9+
10+
# export variable types
11+
export
12+
Position,
13+
Position1,
14+
Position2,
15+
Position3,
16+
Position4,
17+
ContinuousScalar,
18+
ContinuousEuclid,
19+
Ciruclar
20+
21+
#export factor types
22+
export
23+
CircularCircular,
24+
PriorCircular,
25+
PackedCircularCircular,
26+
PackedPriorCircular
27+
28+
# export packed distributions
29+
export
30+
PackedCategorical,
31+
PackedUniform,
32+
PackedNormal,
33+
PackedZeroMeanDiagNormal,
34+
PackedZeroMeanFullNormal,
35+
PackedDiagNormal,
36+
PackedFullNormal,
37+
PackedRayleigh
38+
39+
40+
const IIFTypes = IncrementalInferenceTypes
41+
export IIFTypes
42+
43+
# Variable Definitions
44+
include("variables/DefaultVariableTypes.jl")
45+
46+
# Factor Definitions
47+
include("factors/Circular.jl")
48+
49+
# Distribution Serialization
50+
include("serialization/entities/SerializingDistributions.jl")
51+
include("serialization/services/SerializingDistributions.jl")
52+
53+
end # module IncrementalInferenceTypes
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
"""
3+
$(TYPEDEF)
4+
5+
Factor between two Sphere1 variables.
6+
7+
Related
8+
9+
[`Sphere1`](@ref), [`PriorSphere1`](@ref), [`Polar`](@ref), [`ContinuousEuclid`](@ref)
10+
"""
11+
DFG.@defFactorType CircularCircular AbstractManifoldMinimize Manifolds.RealCircleGroup()
12+
13+
14+
"""
15+
$(TYPEDEF)
16+
17+
Introduce direct observations on all dimensions of a Circular variable:
18+
19+
Example:
20+
--------
21+
```julia
22+
PriorCircular( MvNormal([10; 10; pi/6.0], diagm([0.1;0.1;0.05].^2)) )
23+
```
24+
25+
Related
26+
27+
[`Circular`](@ref), [`Prior`](@ref), [`PartialPrior`](@ref)
28+
"""
29+
DFG.@defFactorType PriorCircular AbstractPrior Manifolds.RealCircleGroup()
30+
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
Base.@kwdef struct PackedCategorical <: PackedSamplableBelief
3+
_type::String = "IncrementalInferenceTypes.PackedCategorical"
4+
p::Vector{Float64} = [1.0;]
5+
end
6+
7+
Base.@kwdef mutable struct PackedUniform <: PackedSamplableBelief
8+
_type::String = "IncrementalInferenceTypes.PackedUniform"
9+
a::Float64 = 0.0
10+
b::Float64 = 1.0
11+
PackedSamplableTypeJSON::String = "IncrementalInferenceTypes.PackedUniform"
12+
end
13+
14+
Base.@kwdef struct PackedNormal <: PackedSamplableBelief
15+
_type::String = "IncrementalInferenceTypes.PackedNormal"
16+
mu::Float64 = 0.0
17+
sigma::Float64 = 1.0
18+
end
19+
20+
Base.@kwdef struct PackedZeroMeanDiagNormal <: PackedSamplableBelief
21+
_type::String = "IncrementalInferenceTypes.PackedZeroMeanDiagNormal"
22+
diag::Vector{Float64} = ones(1)
23+
end
24+
25+
Base.@kwdef struct PackedZeroMeanFullNormal <: PackedSamplableBelief
26+
_type::String = "IncrementalInferenceTypes.PackedZeroMeanFullNormal"
27+
cov::Vector{Float64} = ones(1)
28+
end
29+
30+
Base.@kwdef mutable struct PackedDiagNormal <: PackedSamplableBelief
31+
_type::String = "IncrementalInferenceTypes.PackedDiagNormal"
32+
mu::Vector{Float64} = zeros(1)
33+
diag::Vector{Float64} = ones(1)
34+
end
35+
36+
Base.@kwdef struct PackedFullNormal <: PackedSamplableBelief
37+
_type::String = "IncrementalInferenceTypes.PackedFullNormal"
38+
mu::Vector{Float64} = zeros(1)
39+
cov::Vector{Float64} = ones(1)
40+
end
41+
42+
Base.@kwdef struct PackedRayleigh <: PackedSamplableBelief
43+
_type::String = "IncrementalInferenceTypes.PackedRayleigh"
44+
sigma::Float64 = 1.0
45+
end
46+
47+
#
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
## Distributions to JSON/Packed types
3+
4+
packDistribution(dtr::Categorical) = PackedCategorical(; p = dtr.p)
5+
packDistribution(dtr::Uniform) = PackedUniform(; a = dtr.a, b = dtr.b)
6+
packDistribution(dtr::Normal) = PackedNormal(; mu = dtr.μ, sigma = dtr.σ)
7+
packDistribution(dtr::ZeroMeanDiagNormal) = PackedZeroMeanDiagNormal(; diag = dtr.Σ.diag)
8+
packDistribution(dtr::ZeroMeanFullNormal) = PackedZeroMeanFullNormal(; cov = dtr.Σ.mat[:])
9+
packDistribution(dtr::DiagNormal) = PackedDiagNormal(; mu = dtr.μ, diag = dtr.Σ.diag)
10+
packDistribution(dtr::FullNormal) = PackedFullNormal(; mu = dtr.μ, cov = dtr.Σ.mat[:])
11+
packDistribution(dtr::Rayleigh) = PackedRayleigh(; sigma = dtr.σ)
12+
13+
## Unpack JSON/Packed to Distribution types
14+
15+
unpackDistribution(dtr::PackedCategorical) = Categorical(dtr.p ./ sum(dtr.p))
16+
unpackDistribution(dtr::PackedUniform) = Uniform(dtr.a, dtr.b)
17+
unpackDistribution(dtr::PackedNormal) = Normal(dtr.mu, dtr.sigma)
18+
function unpackDistribution(dtr::PackedZeroMeanDiagNormal)
19+
return MvNormal(LinearAlgebra.Diagonal(map(abs2, sqrt.(dtr.diag))))
20+
end # sqrt.(dtr.diag)
21+
function unpackDistribution(dtr::PackedZeroMeanFullNormal)
22+
d = round(Int, sqrt(size(dtr.cov)[1]))
23+
return MvNormal(reshape(dtr.cov, d, d))
24+
end
25+
unpackDistribution(dtr::PackedDiagNormal) = MvNormal(dtr.mu, sqrt.(dtr.diag))
26+
function unpackDistribution(dtr::PackedFullNormal)
27+
return MvNormal(dtr.mu, reshape(dtr.cov, length(dtr.mu), :))
28+
end
29+
unpackDistribution(dtr::PackedRayleigh) = Rayleigh(dtr.sigma)
30+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#FIXME This is discouraged in the julia style guide, rather standardize to instance or type
2+
const InstanceType{T} = Union{Type{<:T}, <:T}
3+
4+
## Euclid 1
5+
6+
"""
7+
$TYPEDEF
8+
9+
Continuous Euclidean variable of dimension `N` representing a Position in cartesian space.
10+
"""
11+
struct Position{N} <: InferenceVariable end
12+
13+
Position(N::Int) = Position{N}()
14+
15+
# not sure if these overloads are necessary since DFG 775?
16+
DFG.getManifold(::InstanceType{Position{N}}) where {N} = TranslationGroup(N)
17+
function DFG.getDimension(val::InstanceType{Position{N}}) where {N}
18+
return manifold_dimension(getManifold(val))
19+
end
20+
DFG.getPointType(::Type{Position{N}}) where {N} = SVector{N, Float64}
21+
DFG.getPointIdentity(M_::Type{Position{N}}) where {N} = @SVector(zeros(N)) # identity_element(getManifold(M_), zeros(N))
22+
23+
24+
#
25+
26+
"""
27+
$(TYPEDEF)
28+
29+
Most basic continuous scalar variable in a `::DFG.AbstractDFG` object.
30+
31+
Alias of `Position{1}`
32+
"""
33+
const ContinuousScalar = Position{1}
34+
const ContinuousEuclid{N} = Position{N}
35+
36+
const Position1 = Position{1}
37+
const Position2 = Position{2}
38+
const Position3 = Position{3}
39+
const Position4 = Position{4}
40+
41+
## Circular
42+
43+
"""
44+
$(TYPEDEF)
45+
46+
Circular is a `Manifolds.Circle{ℝ}` mechanization of one rotation, with `theta in [-pi,pi)`.
47+
"""
48+
@defVariable Circular RealCircleGroup() [0.0;]
49+
#TODO This is an example of what we want working, possible issue upstream in Manifolds.jl
50+
# @defVariable Circular RealCircleGroup() Scalar(0.0)
51+
52+
#

src/Factors/Circular.jl

Lines changed: 8 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,8 @@
1-
2-
export CircularCircular, PriorCircular, PackedCircularCircular, PackedPriorCircular
3-
4-
"""
5-
$(TYPEDEF)
6-
7-
Factor between two Sphere1 variables.
8-
9-
Related
10-
11-
[`Sphere1`](@ref), [`PriorSphere1`](@ref), [`Polar`](@ref), [`ContinuousEuclid`](@ref)
12-
"""
13-
mutable struct CircularCircular{T <: SamplableBelief} <: AbstractManifoldMinimize
14-
Z::T
15-
# Sphere1Sphere1(z::T=Normal()) where {T <: SamplableBelief} = new{T}(z)
16-
end
17-
18-
const Sphere1Sphere1 = CircularCircular
19-
20-
CircularCircular(::UniformScaling) = CircularCircular(Normal())
21-
22-
DFG.getManifold(::CircularCircular) = RealCircleGroup()
1+
# ---------------------------------------------
2+
# CircularCircular
3+
# ---------------------------------------------
234

245
function (cf::CalcFactor{<:CircularCircular})(X, p, q)
25-
#
266
M = getManifold(cf)
277
return distanceTangent2Point(M, X, p, q)
288
end
@@ -36,30 +16,13 @@ function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{CircularCirc
3616
return Manifolds.RealCircleGroup()
3717
end
3818

39-
"""
40-
$(TYPEDEF)
41-
42-
Introduce direct observations on all dimensions of a Circular variable:
43-
44-
Example:
45-
--------
46-
```julia
47-
PriorCircular( MvNormal([10; 10; pi/6.0], diagm([0.1;0.1;0.05].^2)) )
48-
```
49-
50-
Related
19+
IIFTypes.CircularCircular(::UniformScaling) = CircularCircular(Normal())
5120

52-
[`Circular`](@ref), [`Prior`](@ref), [`PartialPrior`](@ref)
53-
"""
54-
DFG.@defFactorType PriorCircular AbstractPrior Manifolds.RealCircleGroup()
21+
# ---------------------------------------------
22+
# PriorCircular
23+
# ---------------------------------------------
5524

56-
# mutable struct PriorCircular{T <: SamplableBelief} <: AbstractPrior
57-
# Z::T
58-
# end
59-
60-
PriorCircular(::UniformScaling) = PriorCircular(Normal())
61-
62-
# DFG.getManifold(::PriorCircular) = RealCircleGroup()
25+
IIFTypes.PriorCircular(::UniformScaling) = PriorCircular(Normal())
6326

6427
function getSample(cf::CalcFactor{<:PriorCircular})
6528
# FIXME workaround for issue #TBD with manifolds CircularGroup,
@@ -79,38 +42,4 @@ function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{PriorCircula
7942
return Manifolds.RealCircleGroup()
8043
end
8144

82-
# """
83-
# $(TYPEDEF)
84-
85-
# Serialized object for storing PriorCircular.
86-
# """
87-
# Base.@kwdef struct PackedPriorCircular <: AbstractPackedFactor
88-
# Z::PackedSamplableBelief
89-
# end
90-
91-
# function convert(::Type{PackedPriorCircular}, d::PriorCircular)
92-
# return PackedPriorCircular(convert(PackedSamplableBelief, d.Z))
93-
# end
94-
# function convert(::Type{PriorCircular}, d::PackedPriorCircular)
95-
# distr = convert(SamplableBelief, d.Z)
96-
# return PriorCircular{typeof(distr)}(distr)
97-
# end
98-
99-
# --------------------------------------------
100-
101-
"""
102-
$(TYPEDEF)
103-
104-
Serialized object for storing CircularCircular.
105-
"""
106-
Base.@kwdef struct PackedCircularCircular <: AbstractPackedFactor
107-
Z::PackedSamplableBelief
108-
end
109-
function convert(::Type{CircularCircular}, d::PackedCircularCircular)
110-
return CircularCircular(convert(SamplableBelief, d.Z))
111-
end
112-
function convert(::Type{PackedCircularCircular}, d::CircularCircular)
113-
return PackedCircularCircular(convert(PackedSamplableBelief, d.Z))
114-
end
115-
11645
# --------------------------------------------

src/IncrementalInference.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ const BeliefArray{T} = Union{<:AbstractMatrix{<:T}, <:Adjoint{<:T, AbstractMatri
102102
# FIXME, remove this and let the user do either import or const definitions
103103
export KDE, AMP, DFG, FSM, IIF
104104

105+
include("../IncrementalInferenceTypes/src/IncrementalInferenceTypes.jl")
106+
using ..IncrementalInferenceTypes
107+
105108
# TODO temporary for initial version of on-manifold products
106109
KDE.setForceEvalDirect!(true)
107110

0 commit comments

Comments
 (0)