Skip to content

Commit 8026e4a

Browse files
committed
Start spllitting out IIFTypes
1 parent 7dc8f3c commit 8026e4a

File tree

14 files changed

+341
-213
lines changed

14 files changed

+341
-213
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name = "IncrementalInferenceTypes"
2+
uuid = "9808408f-4dbc-47e4-913c-6068b950e289"
3+
authors = ["Johannes Terblanche <Affie@users.noreply.github.com>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
8+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
9+
DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
10+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
11+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
12+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
13+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
14+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
15+
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
16+
17+
[compat]
18+
Dates = "1.11.0"
19+
Distributed = "1.11.0"
20+
DistributedFactorGraphs = "0.27.0"
21+
Distributions = "0.25.120"
22+
DocStringExtensions = "0.9.5"
23+
Manifolds = "=0.10.16"
24+
RecursiveArrayTools = "3.33.0"
25+
StaticArrays = "1.9.13"
26+
StructTypes = "1.11.0"
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
module IncrementalInferenceTypes
2+
3+
using DistributedFactorGraphs
4+
using DocStringExtensions
5+
using Manifolds
6+
using Distributions
7+
using StaticArrays
8+
import StructTypes
9+
10+
using Dates: now
11+
using Distributed: nprocs
12+
# using RecursiveArrayTools
13+
14+
# export variable types
15+
export
16+
Position,
17+
Position1,
18+
Position2,
19+
Position3,
20+
Position4,
21+
ContinuousScalar,
22+
ContinuousEuclid,
23+
Ciruclar
24+
25+
#export factor types
26+
export
27+
Prior,
28+
PackedPrior,
29+
LinearRelative,
30+
PackedLinearRelative,
31+
CircularCircular,
32+
PriorCircular,
33+
PackedCircularCircular,
34+
PackedPriorCircular
35+
36+
# export packed distributions
37+
export
38+
PackedCategorical,
39+
PackedUniform,
40+
PackedNormal,
41+
PackedZeroMeanDiagNormal,
42+
PackedZeroMeanFullNormal,
43+
PackedDiagNormal,
44+
PackedFullNormal,
45+
PackedRayleigh
46+
47+
export
48+
SolverParams
49+
50+
const IIFTypes = IncrementalInferenceTypes
51+
export IIFTypes
52+
53+
# Variable Definitions
54+
include("variables/DefaultVariableTypes.jl")
55+
56+
# Factor Definitions
57+
include("factors/DefaultPrior.jl")
58+
#FIXME maybe upgrade linear relative to this
59+
# include("factors/LinearRelative.jl")
60+
include("factors/Circular.jl")
61+
62+
# Distribution Serialization
63+
include("serialization/entities/SerializingDistributions.jl")
64+
include("serialization/services/SerializingDistributions.jl")
65+
66+
# solver params
67+
include("solverparams/SolverParams.jl")
68+
69+
end # module IncrementalInferenceTypes
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
"""
3+
$(TYPEDEF)
4+
5+
Factor between two Sphere1 variables.
6+
7+
Related
8+
9+
[`Sphere1`](@ref), [`PriorSphere1`](@ref), [`Polar`](@ref), [`ContinuousEuclid`](@ref)
10+
"""
11+
DFG.@defFactorType CircularCircular AbstractManifoldMinimize Manifolds.RealCircleGroup()
12+
13+
14+
"""
15+
$(TYPEDEF)
16+
17+
Introduce direct observations on all dimensions of a Circular variable:
18+
19+
Example:
20+
--------
21+
```julia
22+
PriorCircular( MvNormal([10; 10; pi/6.0], diagm([0.1;0.1;0.05].^2)) )
23+
```
24+
25+
Related
26+
27+
[`Circular`](@ref), [`Prior`](@ref), [`PartialPrior`](@ref)
28+
"""
29+
DFG.@defFactorType PriorCircular AbstractPrior Manifolds.RealCircleGroup()
30+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
Default prior on all dimensions of a variable node in the factor graph. `Prior` is
5+
not recommended when non-Euclidean dimensions are used in variables.
6+
"""
7+
struct Prior{T} <: AbstractPrior
8+
Z::T
9+
end
10+
DFG.getManifold(pr::Prior) = TranslationGroup(getDimension(pr.Z))
11+
12+
"""
13+
$(TYPEDEF)
14+
15+
Serialization type for Prior.
16+
"""
17+
Base.@kwdef mutable struct PackedPrior <: AbstractPackedFactor
18+
Z::PackedSamplableBelief
19+
end
20+
21+
function DFG.pack(d::Prior)
22+
return PackedPrior(DFG.packDistribution(d.Z))
23+
end
24+
25+
function DFG.unpack(d::PackedPrior)
26+
return Prior(DFG.unpackDistribution(d.Z))
27+
end
28+
29+
#
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
"""
3+
$(TYPEDEF)
4+
5+
Default linear offset between two scalar variables.
6+
7+
```math
8+
X_2 = X_1 + η_Z
9+
```
10+
"""
11+
struct LinearRelative{T} <: AbstractManifoldMinimize
12+
Z::T
13+
end
14+
15+
DFG.getManifold(obs::LinearRelative) = TranslationGroup(getDimension(obs.Z))
16+
17+
"""
18+
$(TYPEDEF)
19+
Serialization type for `LinearRelative` binary factor.
20+
"""
21+
Base.@kwdef mutable struct PackedLinearRelative <: AbstractPackedFactor
22+
Z::PackedSamplableBelief
23+
end
24+
25+
function DFG.pack(d::LinearRelative)
26+
return PackedLinearRelative(DFG.packDistribution(d.Z))
27+
end
28+
29+
function DFG.unpack(d::PackedLinearRelative)
30+
return LinearRelative(DFG.unpackDistribution(d.Z))
31+
end
32+
33+
#
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
Base.@kwdef struct PackedCategorical <: PackedSamplableBelief
3+
_type::String = "IncrementalInferenceTypes.PackedCategorical"
4+
p::Vector{Float64} = [1.0;]
5+
end
6+
7+
Base.@kwdef mutable struct PackedUniform <: PackedSamplableBelief
8+
_type::String = "IncrementalInferenceTypes.PackedUniform"
9+
a::Float64 = 0.0
10+
b::Float64 = 1.0
11+
PackedSamplableTypeJSON::String = "IncrementalInferenceTypes.PackedUniform"
12+
end
13+
14+
Base.@kwdef struct PackedNormal <: PackedSamplableBelief
15+
_type::String = "IncrementalInferenceTypes.PackedNormal"
16+
mu::Float64 = 0.0
17+
sigma::Float64 = 1.0
18+
end
19+
20+
Base.@kwdef struct PackedZeroMeanDiagNormal <: PackedSamplableBelief
21+
_type::String = "IncrementalInferenceTypes.PackedZeroMeanDiagNormal"
22+
diag::Vector{Float64} = ones(1)
23+
end
24+
25+
Base.@kwdef struct PackedZeroMeanFullNormal <: PackedSamplableBelief
26+
_type::String = "IncrementalInferenceTypes.PackedZeroMeanFullNormal"
27+
cov::Vector{Float64} = ones(1)
28+
end
29+
30+
Base.@kwdef mutable struct PackedDiagNormal <: PackedSamplableBelief
31+
_type::String = "IncrementalInferenceTypes.PackedDiagNormal"
32+
mu::Vector{Float64} = zeros(1)
33+
diag::Vector{Float64} = ones(1)
34+
end
35+
36+
Base.@kwdef struct PackedFullNormal <: PackedSamplableBelief
37+
_type::String = "IncrementalInferenceTypes.PackedFullNormal"
38+
mu::Vector{Float64} = zeros(1)
39+
cov::Vector{Float64} = ones(1)
40+
end
41+
42+
Base.@kwdef struct PackedRayleigh <: PackedSamplableBelief
43+
_type::String = "IncrementalInferenceTypes.PackedRayleigh"
44+
sigma::Float64 = 1.0
45+
end
46+
47+
#
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
## Distributions to JSON/Packed types
3+
4+
DFG.packDistribution(dtr::Categorical) = PackedCategorical(; p = dtr.p)
5+
DFG.packDistribution(dtr::Uniform) = PackedUniform(; a = dtr.a, b = dtr.b)
6+
DFG.packDistribution(dtr::Normal) = PackedNormal(; mu = dtr.μ, sigma = dtr.σ)
7+
DFG.packDistribution(dtr::ZeroMeanDiagNormal) = PackedZeroMeanDiagNormal(; diag = dtr.Σ.diag)
8+
DFG.packDistribution(dtr::ZeroMeanFullNormal) = PackedZeroMeanFullNormal(; cov = dtr.Σ.mat[:])
9+
DFG.packDistribution(dtr::DiagNormal) = PackedDiagNormal(; mu = dtr.μ, diag = dtr.Σ.diag)
10+
DFG.packDistribution(dtr::FullNormal) = PackedFullNormal(; mu = dtr.μ, cov = dtr.Σ.mat[:])
11+
DFG.packDistribution(dtr::Rayleigh) = PackedRayleigh(; sigma = dtr.σ)
12+
13+
## Unpack JSON/Packed to Distribution types
14+
15+
DFG.unpackDistribution(dtr::PackedCategorical) = Categorical(dtr.p ./ sum(dtr.p))
16+
DFG.unpackDistribution(dtr::PackedUniform) = Uniform(dtr.a, dtr.b)
17+
DFG.unpackDistribution(dtr::PackedNormal) = Normal(dtr.mu, dtr.sigma)
18+
function DFG.unpackDistribution(dtr::PackedZeroMeanDiagNormal)
19+
return MvNormal(LinearAlgebra.Diagonal(map(abs2, sqrt.(dtr.diag))))
20+
end # sqrt.(dtr.diag)
21+
function DFG.unpackDistribution(dtr::PackedZeroMeanFullNormal)
22+
d = round(Int, sqrt(size(dtr.cov)[1]))
23+
return MvNormal(reshape(dtr.cov, d, d))
24+
end
25+
DFG.unpackDistribution(dtr::PackedDiagNormal) = MvNormal(dtr.mu, sqrt.(dtr.diag))
26+
function DFG.unpackDistribution(dtr::PackedFullNormal)
27+
return MvNormal(dtr.mu, reshape(dtr.cov, length(dtr.mu), :))
28+
end
29+
DFG.unpackDistribution(dtr::PackedRayleigh) = Rayleigh(dtr.sigma)
30+

src/entities/SolverParams.jl renamed to IncrementalInferenceTypes/src/solverparams/SolverParams.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Dev Notes
1111
"""
1212
Base.@kwdef mutable struct SolverParams <: DFG.AbstractParams
1313
dimID::Int = 0
14-
reference::NothingUnion{Dict{Symbol, Tuple{Symbol, Vector{Float64}}}} = nothing
14+
reference::Union{Nothing, Dict{Symbol, Tuple{Symbol, Vector{Float64}}}} = nothing
1515
stateless::Bool = false
1616
""" Quasi fixed length """
1717
qfl::Int = (2^(Sys.WORD_SIZE - 1) - 1)
@@ -45,7 +45,7 @@ Base.@kwdef mutable struct SolverParams <: DFG.AbstractParams
4545
""" should Distributed.jl tree solve compute features be used """
4646
multiproc::Bool = 1 < nprocs()
4747
""" "/tmp/caesar/logs/$(now())" # unique temporary file storage location for a solve """
48-
logpath::String = joinpath(tempdir(),"caesar","logs","$(now(UTC))")
48+
logpath::String = joinpath(tempdir(),"caesar","logs","$(now(DFG.UTC))")
4949
""" default to graph-based initialization of variables """
5050
graphinit::Bool = true
5151
""" init variables on the tree """
@@ -76,6 +76,9 @@ end
7676

7777
StructTypes.omitempties(::Type{SolverParams}) = (:reference,)
7878

79-
80-
convert(::Type{SolverParams}, ::NoSolverParams) = SolverParams()
8179
#
80+
Base.convert(::Type{SolverParams}, ::NoSolverParams) = begin
81+
@warn "FIXME Why converting NoSolverParams to SolverParams?"
82+
SolverParams()
83+
end
84+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#FIXME This is discouraged in the julia style guide, rather standardize to instance or type
2+
const InstanceType{T} = Union{Type{<:T}, <:T}
3+
4+
## Euclid 1
5+
6+
"""
7+
$TYPEDEF
8+
9+
Continuous Euclidean variable of dimension `N` representing a Position in cartesian space.
10+
"""
11+
struct Position{N} <: InferenceVariable end
12+
13+
Position(N::Int) = Position{N}()
14+
15+
# not sure if these overloads are necessary since DFG 775?
16+
DFG.getManifold(::InstanceType{Position{N}}) where {N} = TranslationGroup(N)
17+
function DFG.getDimension(val::InstanceType{Position{N}}) where {N}
18+
return manifold_dimension(getManifold(val))
19+
end
20+
DFG.getPointType(::Type{Position{N}}) where {N} = SVector{N, Float64}
21+
DFG.getPointIdentity(M_::Type{Position{N}}) where {N} = @SVector(zeros(N)) # identity_element(getManifold(M_), zeros(N))
22+
23+
24+
#
25+
26+
"""
27+
$(TYPEDEF)
28+
29+
Most basic continuous scalar variable in a `::DFG.AbstractDFG` object.
30+
31+
Alias of `Position{1}`
32+
"""
33+
const ContinuousScalar = Position{1}
34+
const ContinuousEuclid{N} = Position{N}
35+
36+
const Position1 = Position{1}
37+
const Position2 = Position{2}
38+
const Position3 = Position{3}
39+
const Position4 = Position{4}
40+
41+
#TODO maybe just use @defVariable for all Position types?
42+
# @defVariable Position1 TranslationGroup(1) @SVector(zeros(1))
43+
# @defVariable Position2 TranslationGroup(2) @SVector(zeros(2))
44+
# @defVariable Position3 TranslationGroup(3) @SVector(zeros(3))
45+
# @defVariable Position4 TranslationGroup(4) @SVector(zeros(4))
46+
47+
## Circular
48+
49+
"""
50+
$(TYPEDEF)
51+
52+
Circular is a `Manifolds.Circle{ℝ}` mechanization of one rotation, with `theta in [-pi,pi)`.
53+
"""
54+
@defVariable Circular RealCircleGroup() [0.0;]
55+
#TODO This is an example of what we want working, possible issue upstream in Manifolds.jl
56+
# @defVariable Circular RealCircleGroup() Scalar(0.0)
57+
58+
#

0 commit comments

Comments
 (0)