Skip to content

Commit 1a11336

Browse files
authored
Improve Hamiltonian constructors (#307)
* Improve Hamiltonian constructors * Update Project.toml * Simplify code * Update Project.toml * Update integrator.jl * Update demo.jl * Update common.jl * Update README.md
1 parent 3dc2822 commit 1a11336

File tree

6 files changed

+41
-19
lines changed

6 files changed

+41
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.4"
3+
version = "0.4.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct LogTargetDensity
6060
end
6161
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
6262
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim
63+
LogDensityProblems.capabilities(::Type{LogTargetDensity}) = LogDensityProblems.LogDensityOrder{0}()
6364

6465
# Choose parameter dimensionality and initial parameter value
6566
D = 10; initial_θ = rand(D)

src/AdvancedHMC.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,34 +144,43 @@ export sample
144144

145145
include("abstractmcmc.jl")
146146

147-
function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel)
148-
ℓπ =.logdensity
149-
150-
# Check we're capable of computing gradients.
151-
cap = LogDensityProblems.capabilities(ℓπ)
147+
## Without explicit AD backend
148+
function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...)
149+
return Hamiltonian(metric, ℓ.logdensity; kwargs...)
150+
end
151+
function Hamiltonian(metric::AbstractMetric, ℓ; kwargs...)
152+
cap = LogDensityProblems.capabilities(ℓ)
152153
if cap === nothing
153154
throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface"))
154155
end
155-
156-
if cap === LogDensityProblems.LogDensityOrder{0}()
157-
throw(ArgumentError("The gradient of the log density function is not defined: Implement `LogDensityProblems.logdensity_and_gradient` or use automatic differentiation by calling `Hamiltionian(metric, model, AD; kwargs...)` where AD is one of the backends supported by LogDensityProblemsAD.jl"))
156+
# Check if we're capable of computing gradients.
157+
ℓπ = if cap === LogDensityProblems.LogDensityOrder{0}()
158+
# In this case ℓ does not support evaluation of the gradient of the log density function
159+
# We use ForwardDiff to compute the gradient
160+
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; kwargs...)
161+
else
162+
# In this case ℓ already supports evaluation of the gradient of the log density function
163+
158164
end
159-
160165
return Hamiltonian(
161166
metric,
162-
Base.Fix1(LogDensityProblems.logdensity, .logdensity),
163-
Base.Fix1(LogDensityProblems.logdensity_and_gradient, .logdensity),
167+
Base.Fix1(LogDensityProblems.logdensity, ℓπ),
168+
Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓπ),
164169
)
165170
end
166-
function Hamiltonian(metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val}; kwargs...)
167-
= LogDensityModel(LogDensityProblemsAD.ADgradient(kind, ℓπ.logdensity; kwargs...))
168-
return Hamiltonian(metric, ℓ)
171+
172+
## With explicit AD specification
173+
function Hamiltonian(metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val,Module}; kwargs...)
174+
return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...)
169175
end
170-
function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val} = Val{:ForwardDiff}(); kwargs...)
171-
= LogDensityModel(LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...))
176+
Hamiltonian(metric::AbstractMetric, ℓπ, m::Module; kwargs...) = Hamiltonian(metric, ℓπ, Val(Symbol(m)); kwargs...)
177+
function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val}; kwargs...)
178+
if LogDensityProblems.capabilities(ℓπ) === nothing
179+
throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface"))
180+
end
181+
= LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...)
172182
return Hamiltonian(metric, ℓ)
173183
end
174-
Hamiltonian(metric::AbstractMetric, ℓπ, m::Module; kwargs...) = Hamiltonian(metric, ℓπ, Val(Symbol(m)); kwargs...)
175184

176185
### Init
177186

test/common.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ using Bijectors: Bijectors
1818
struct LogDensityDistribution{D<:Distributions.Distribution}
1919
dist::D
2020
end
21+
2122
LogDensityProblems.dimension(d::LogDensityDistribution) = length(d.dist)
2223
function LogDensityProblems.logdensity(ld::LogDensityDistribution, y)
2324
d = ld.dist
2425
b = Bijectors.inverse(Bijectors.bijector(d))
2526
x, logjac = Bijectors.with_logabsdet_jacobian(b, y)
2627
return logpdf(d, x) + logjac
2728
end
29+
LogDensityProblems.capabilities(::Type{<:LogDensityDistribution}) = LogDensityProblems.LogDensityOrder{0}()
2830

2931
# Hand-coded multivariate Gaussian
3032

@@ -41,6 +43,7 @@ end
4143

4244
LogDensityProblems.dimension(g::Gaussian) = dim(g.m)
4345
LogDensityProblems.logdensity(g::Gaussian, x) = ℓπ_gaussian(g.m. g.s, x)
46+
LogDensityProblems.capabilities(::Type{<:Gaussian}) = LogDensityProblems.LogDensityOrder{0}()
4447

4548
function ∇ℓπ_gaussianl(m, s, x)
4649
g = m .- x
@@ -97,6 +100,7 @@ end
97100
# Make compat with `LogDensityProblems`.
98101
LogDensityProblems.dimension(::typeof(ℓπ_gdemo)) = 2
99102
LogDensityProblems.logdensity(::typeof(ℓπ_gdemo), θ) = ℓπ_gdemo(θ)
103+
LogDensityProblems.capabilities(::Type{typeof(ℓπ_gdemo)}) = LogDensityProblems.LogDensityOrder{0}()
100104

101105
test_show(x) = test_show(s -> length(s) > 0, x)
102106
function test_show(pred, x)

test/demo.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ using LinearAlgebra
77
struct DemoProblem
88
dim::Int
99
end
10+
1011
LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ)
1112
LogDensityProblems.dimension(p::DemoProblem) = p.dim
13+
LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{0}()
1214

1315
# Choose parameter dimensionality and initial parameter value
1416
D = 10
@@ -50,10 +52,13 @@ end
5052
# target distribution parametrized by ComponentsArray
5153
p1 = ComponentVector=2.0, σ=1)
5254
struct DemoProblemComponentArrays end
55+
5356
function LogDensityProblems.logdensity(::DemoProblemComponentArrays, p::ComponentArray)
5457
return -((1 - p.μ) / p.σ)^2
5558
end
5659
LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2
60+
LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{0}()
61+
5762
ℓπ = DemoProblemComponentArrays()
5863

5964
# Define a Hamiltonian system

test/integrator.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ using Statistics: mean
110110
struct NegU
111111
dim::Int
112112
end
113-
LogDensityProblems.logdensity(d::NegU, x) = -dot(x, x) / 2
113+
114+
LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2
114115
LogDensityProblems.dimension(d::NegU) = d.dim
116+
LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{0}()
117+
115118
negU = NegU(1)
116119

117120
ϵ = 0.01

0 commit comments

Comments
 (0)