remove LogDensityProblemsAD #2490

Merged: 10 commits, Feb 25, 2025
Changes from 2 commits
24 changes: 9 additions & 15 deletions ext/TuringDynamicHMCExt.jl
@@ -3,17 +3,10 @@
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
###

if isdefined(Base, :get_extension)
using DynamicHMC: DynamicHMC
using Turing
using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
else
import ..DynamicHMC
using ..Turing
using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using ..Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
end
using DynamicHMC: DynamicHMC
using Turing
using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
using Turing.Inference: ADTypes, TYPEDFIELDS

"""
DynamicNUTS
@@ -69,10 +62,11 @@
end

# Define log-density function.
ℓ = LogDensityProblemsAD.ADgradient(
Turing.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)

# Perform initial step.
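The hunk above is the whole migration in miniature: the `LogDensityProblemsAD.ADgradient` wrapper is dropped, the AD backend travels on the `adtype` keyword of `DynamicPPL.LogDensityFunction`, and the resulting object answers `LogDensityProblems` queries directly. A minimal sketch of that call pattern, assuming `DynamicPPL`, `LogDensityProblems`, and `ADTypes` are available in the environment; the `demo` model is hypothetical and only for illustration:

```julia
using Turing
using Turing: DynamicPPL
using LogDensityProblems
using ADTypes: AutoForwardDiff

# Hypothetical toy model, purely to illustrate the call pattern.
@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)
vi = DynamicPPL.VarInfo(model)

# The AD backend is now a property of the log-density function itself.
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())

θ = vi[:]                                            # current parameter vector
LogDensityProblems.logdensity(ldf, θ)                # log density
LogDensityProblems.logdensity_and_gradient(ldf, θ)   # gradient supplied via `adtype`
```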
41 changes: 21 additions & 20 deletions ext/TuringOptimExt.jl
@@ -1,14 +1,8 @@
module TuringOptimExt

if isdefined(Base, :get_extension)
using Turing: Turing
import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
using Optim: Optim
else
import ..Turing
import ..Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
import ..Optim
end
using Turing: Turing
import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
using Optim: Optim

####################
# Optim.jl methods #
@@ -42,7 +36,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -65,7 +59,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(
@@ -112,7 +106,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
optimizer = Optim.LBFGS()
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
@@ -135,7 +129,7 @@
)
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
f = Optimisation.OptimLogDensity(model, ctx)
init_vals = DynamicPPL.getparams(f)
init_vals = DynamicPPL.getparams(f.ldf)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
end
function Optim.optimize(
@@ -162,17 +156,20 @@
function _optimize(
model::DynamicPPL.Model,
f::Optimisation.OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
init_vals::AbstractArray=DynamicPPL.getparams(f.ldf),
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
options::Optim.Options=Optim.Options(),
args...;
kwargs...,
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)
# TODO(penelopeysm): As with in src/optimisation/Optimisation.jl, unclear
# whether initialisation is really necessary at all
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
vi = DynamicPPL.link(vi, f.ldf.model)
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
init_vals = DynamicPPL.getparams(f.ldf)

# Optimize!
M = Optim.optimize(Optim.only_fg!(f), init_vals, optimizer, options, args...; kwargs...)
@@ -186,12 +183,16 @@
end

# Get the optimum in unconstrained space. `getparams` does the invlinking.
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
vns_vals_iter = Turing.Inference.getparams(model, f.varinfo)
vi = f.ldf.varinfo
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
logdensity_optimum = Optimisation.OptimLogDensity(
f.ldf.model, vi_optimum, f.ldf.context
)
vns_vals_iter = Turing.Inference.getparams(model, vi_optimum)
varnames = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)
vmat = NamedArrays.NamedArray(vals, varnames)
return Optimisation.ModeResult(vmat, M, -M.minimum, f)
return Optimisation.ModeResult(vmat, M, -M.minimum, logdensity_optimum)
end

end # module
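Downstream of this change, `Optimisation.OptimLogDensity` no longer exposes `varinfo`, `model`, and `context` as top-level fields; they live on the wrapped `DynamicPPL.LogDensityFunction` stored in its `ldf` field, which is why every `DynamicPPL.getparams(f)` above becomes `DynamicPPL.getparams(f.ldf)`. A small sketch of the field access, assuming Turing's internal `Optimisation` module keeps this layout; the one-variable model is a placeholder:

```julia
using Turing
using Turing: DynamicPPL, Optimisation

# Placeholder model, purely for illustration.
@model demo() = x ~ Normal()

ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
f = Optimisation.OptimLogDensity(demo(), ctx)

DynamicPPL.getparams(f.ldf)   # was DynamicPPL.getparams(f) before this PR
f.ldf.model                   # model, varinfo, context, and adtype now hang off f.ldf
f.ldf.varinfo
f.ldf.context
f.ldf.adtype
```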
24 changes: 0 additions & 24 deletions src/mcmc/Inference.jl
@@ -48,7 +48,6 @@ import AdvancedPS
import Accessors
import EllipticalSliceSampling
import LogDensityProblems
import LogDensityProblemsAD
import Random
import MCMCChains
import StatsBase: predict
@@ -160,29 +159,6 @@ function externalsampler(
return ExternalSampler(sampler, adtype, Val(unconstrained))
end

getADType(spl::Sampler) = getADType(spl.alg)
getADType(::SampleFromPrior) = Turing.DEFAULT_ADTYPE

getADType(ctx::DynamicPPL.SamplingContext) = getADType(ctx.sampler)
getADType(ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.NodeTrait(ctx), ctx)
getADType(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = Turing.DEFAULT_ADTYPE
function getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext)
return getADType(DynamicPPL.childcontext(ctx))
end

getADType(alg::Hamiltonian) = alg.adtype

function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
return LogDensityProblemsAD.ADgradient(getADType(ℓ.context), ℓ)
end

function LogDensityProblems.logdensity(
f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
x::NamedTuple,
)
return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x))
end

# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
set_namedtuple!(deepcopy(vi), θ)
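This hunk removes the `getADType` helpers (the AD type is now passed explicitly wherever a `LogDensityFunction` is built) together with the `NamedTuple` method of `LogDensityProblems.logdensity`. For code that relied on the deleted method, an equivalent direct call is just its old body; a sketch with a hypothetical single-parameter model:

```julia
using Turing
using Turing: DynamicPPL

# Hypothetical model with one parameter `m`.
@model function demo()
    m ~ Normal(0, 1)
end

model = demo()
vi = DynamicPPL.VarInfo(model)

# Equivalent of the removed `logdensity(f, x::NamedTuple)` method:
x = (; m = 0.3)
DynamicPPL.logjoint(model, DynamicPPL.unflatten(vi, x))
```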
51 changes: 11 additions & 40 deletions src/mcmc/abstractmcmc.jl
@@ -1,6 +1,6 @@
struct TuringState{S,F}
struct TuringState{S,M,V,C}
state::S
logdensity::F
ldf::DynamicPPL.LogDensityFunction{M,V,C}
end

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
@@ -12,20 +12,10 @@
return Transition(f.model, varinfo, transition)
end

state_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, state) = TuringState(state, f)
function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transition)
return transition_to_turing(parent(f), transition)
end

function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
return varinfo_from_logdensityfn(parent(f))
end
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo

function varinfo(state::TuringState)
θ = getparams(DynamicPPL.getmodel(state.logdensity), state.state)
θ = getparams(state.ldf.model, state.state)
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
return DynamicPPL.unflatten(state.ldf.varinfo, θ)
end
varinfo(state::AbstractVarInfo) = state

@@ -40,23 +30,6 @@

getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
function getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper)
return getvarinfo(LogDensityProblemsAD.parent(f))
end

function setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo)
return DynamicPPL.LogDensityFunction(f.model, varinfo, f.context; adtype=f.adtype)
end

function setvarinfo(
f::LogDensityProblemsAD.ADGradientWrapper, varinfo, adtype::ADTypes.AbstractADType
)
return LogDensityProblemsAD.ADgradient(
adtype, setvarinfo(LogDensityProblemsAD.parent(f), varinfo)
)
end

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
@@ -69,12 +42,8 @@
alg = sampler_wrapper.alg
sampler = alg.sampler

# Create a log-density function with an implementation of the
# gradient so we ensure that we're using the same AD backend as in Turing.
f = LogDensityProblemsAD.ADgradient(alg.adtype, DynamicPPL.LogDensityFunction(model))

# Link the varinfo if needed.
varinfo = getvarinfo(f)
# Initialise varinfo with initial params and link the varinfo if needed.
varinfo = DynamicPPL.VarInfo(model)
if requires_unconstrained_space(alg)
if initial_params !== nothing
# If we have initial parameters, we need to set the varinfo before linking.
@@ -85,9 +54,11 @@
varinfo = DynamicPPL.link(varinfo, model)
end
end
f = setvarinfo(f, varinfo, alg.adtype)

# Then just call `AdvancedHMC.step` with the right arguments.
# Construct LogDensityFunction
f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype)

# Then just call `AbstractMCMC.step` with the right arguments.
if initial_state === nothing
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
@@ -114,7 +85,7 @@
kwargs...,
)
sampler = sampler_wrapper.alg.sampler
f = state.logdensity
f = state.ldf

# Then just call `AdvancedHMC.step` with the right arguments.
transition_inner, state_inner = AbstractMCMC.step(
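With the `ADGradientWrapper` indirection gone, `TuringState` now stores the `DynamicPPL.LogDensityFunction` itself (the new `ldf` field), and the first `step` builds that function from a fresh, possibly linked, `VarInfo` plus the sampler's `adtype`. A runnable sketch of the user-facing path these internals serve, via `externalsampler` with a random-walk Metropolis–Hastings sampler; it assumes AdvancedMH and LinearAlgebra are installed, and the `demo` model is made up:

```julia
using Turing
using AdvancedMH
using LinearAlgebra

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

# RWMH needs no gradients, but it exercises the ExternalSampler
# step path changed in this file.
spl = externalsampler(AdvancedMH.RWMH(MvNormal(zeros(1), I)))
chain = sample(demo(1.5), spl, 200)
```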
45 changes: 24 additions & 21 deletions src/mcmc/hmc.jl
@@ -156,19 +156,19 @@
# Create a Hamiltonian.
metricT = getmetricT(spl.alg)
metric = metricT(length(theta))
ℓ = LogDensityProblemsAD.ADgradient(
Turing.LogDensityFunction(
model,
vi,
# Use the leaf-context from the `model` in case the user has
# contextualized the model with something like `PriorContext`
# to sample from the prior.
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
),
ldf = DynamicPPL.LogDensityFunction(
model,
vi,
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
# need to pass in the sampler? (In fact LogDensityFunction defaults to
# using leafcontext(model.context) so could we just remove the argument
# entirely?)
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context));
adtype=spl.alg.adtype,
)
logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x)
hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)

# Compute phase point z.
z = AHMC.phasepoint(rng, theta, hamiltonian)
@@ -287,16 +287,19 @@

function get_hamiltonian(model, spl, vi, state, n)
metric = gen_metric(n, spl, state)
ℓ = LogDensityProblemsAD.ADgradient(
Turing.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context)),
),
ldf = DynamicPPL.LogDensityFunction(
model,
vi,
# TODO(penelopeysm): Can we just use leafcontext(model.context)? Do we
# need to pass in the sampler? (In fact LogDensityFunction defaults to
# using leafcontext(model.context) so could we just remove the argument
# entirely?)
DynamicPPL.SamplingContext(spl, DynamicPPL.leafcontext(model.context));
adtype=spl.alg.adtype,
)
ℓπ = Base.Fix1(LogDensityProblems.logdensity, ℓ)
∂ℓπ∂θ = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ)
return AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
return AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
end

"""
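The two hunks above are the same change in two places: construct the `LogDensityFunction` with `adtype`, then hand AdvancedHMC a pair of `Base.Fix1` closures for the log density and its gradient. A standalone sketch of that wiring, assuming AdvancedHMC, LogDensityProblems, and ADTypes are available; the `demo` model is hypothetical:

```julia
using Turing
using Turing: DynamicPPL
using AdvancedHMC, LogDensityProblems
using ADTypes: AutoForwardDiff

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)
vi = DynamicPPL.VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())

metric = AdvancedHMC.DiagEuclideanMetric(LogDensityProblems.dimension(ldf))
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)

# The Hamiltonian only sees plain callables; no LogDensityProblemsAD wrapper involved.
hamiltonian = AdvancedHMC.Hamiltonian(metric, lp_func, lp_grad_func)
```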
18 changes: 10 additions & 8 deletions src/mcmc/sghmc.jl
@@ -66,10 +66,11 @@

# Compute initial sample and state.
sample = Transition(model, vi)
ℓ = LogDensityProblemsAD.ADgradient(
Turing.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)
state = SGHMCState(ℓ, vi, zero(vi[spl]))

@@ -228,10 +229,11 @@

# Create first sample and state.
sample = SGLDTransition(model, vi, zero(spl.alg.stepsize(0)))
ℓ = LogDensityProblemsAD.ADgradient(
Turing.LogDensityFunction(
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
),
ℓ = DynamicPPL.LogDensityFunction(
model,
vi,
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
adtype=spl.alg.adtype,
)
state = SGLDState(ℓ, vi, 1)
