|
1 | 1 | import DifferentiationInterface as DI
|
2 | 2 |
|
3 | 3 | """
|
4 |
| - LogDensityFunction |
5 |
| -
|
6 |
| -A callable representing a log density function of a `model`. |
7 |
| -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface, |
8 |
| -but only to 0th-order, i.e. it is only possible to calculate the log density, |
9 |
| -and not its gradient. If you need to calculate the gradient as well, you have |
10 |
| -to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object. |
| 4 | + LogDensityFunction( |
| 5 | + model::Model, |
| 6 | + varinfo::AbstractVarInfo=VarInfo(model), |
| 7 | + context::AbstractContext=DefaultContext(); |
| 8 | + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing |
| 9 | + ) |
| 10 | +
|
| 11 | +A struct which contains a model, along with all the information necessary to: |
| 12 | +
|
| 13 | + - calculate its log density at a given point; |
| 14 | + - and if `adtype` is provided, calculate the gradient of the log density at |
| 15 | + that point. |
| 16 | +
|
| 17 | +At its most basic level, a LogDensityFunction wraps the model together with its |
| 18 | +the type of varinfo to be used, as well as the evaluation context. These must |
| 19 | +be known in order to calculate the log density (using |
| 20 | +[`DynamicPPL.evaluate!!`](@ref)). |
| 21 | +
|
| 22 | +If the `adtype` keyword argument is provided, then this struct will also store |
| 23 | +the adtype along with other information for efficient calculation of the |
| 24 | +gradient of the log density. Note that preparing a `LogDensityFunction` with an |
| 25 | +AD type `AutoBackend()` requires the AD backend itself to have been loaded |
| 26 | +(e.g. with `import Backend`). |
| 27 | +
|
| 28 | +`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. |
| 29 | +If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a |
| 30 | +concrete AD backend type, then `logdensity_and_gradient` is also implemented. |
11 | 31 |
|
12 | 32 | # Fields
|
13 | 33 | $(FIELDS)
|
14 | 34 |
|
15 | 35 | # Examples
|
| 36 | +
|
16 | 37 | ```jldoctest
|
17 | 38 | julia> using Distributions
|
18 | 39 |
|
@@ -48,66 +69,150 @@ julia> # This also respects the context in `model`.
|
48 | 69 |
|
49 | 70 | julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
|
50 | 71 | true
|
| 72 | +
|
| 73 | +julia> # If we also need to calculate the gradient, we can specify an AD backend. |
| 74 | + import ForwardDiff, ADTypes |
| 75 | +
|
| 76 | +julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); |
| 77 | +
|
| 78 | +julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) |
| 79 | +(-2.3378770664093453, [1.0]) |
51 | 80 | ```
|
52 | 81 | """
|
53 |
| -struct LogDensityFunction{V,M,C} |
54 |
| - "varinfo used for evaluation" |
55 |
| - varinfo::V |
| 82 | +struct LogDensityFunction{ |
| 83 | + M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} |
| 84 | +} |
56 | 85 | "model used for evaluation"
|
57 | 86 | model::M
|
| 87 | + "varinfo used for evaluation" |
| 88 | + varinfo::V |
58 | 89 | "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
|
59 | 90 | context::C
|
60 |
| -end |
61 |
| - |
62 |
| -function LogDensityFunction( |
63 |
| - model::Model, |
64 |
| - varinfo::AbstractVarInfo=VarInfo(model), |
65 |
| - context::Union{Nothing,AbstractContext}=nothing, |
66 |
| -) |
67 |
| - return LogDensityFunction(varinfo, model, context) |
68 |
| -end |
| 91 | + "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" |
| 92 | + adtype::AD |
| 93 | + "(internal use only) gradient preparation object for the model" |
| 94 | + prep::Union{Nothing,DI.GradientPrep} |
| 95 | + "(internal use only) whether a closure was used for the gradient preparation" |
| 96 | + with_closure::Bool |
69 | 97 |
|
70 |
| -# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. |
71 |
| -function getcontext(f::LogDensityFunction) |
72 |
| - return f.context === nothing ? leafcontext(f.model.context) : f.context |
| 98 | + function LogDensityFunction( |
| 99 | + model::Model, |
| 100 | + varinfo::AbstractVarInfo=VarInfo(model), |
| 101 | + context::AbstractContext=leafcontext(model.context); |
| 102 | + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, |
| 103 | + ) |
| 104 | + if adtype === nothing |
| 105 | + prep = nothing |
| 106 | + with_closure = false |
| 107 | + else |
| 108 | + # Get a set of dummy params to use for prep |
| 109 | + x = map(identity, varinfo[:]) |
| 110 | + with_closure = use_closure(adtype) |
| 111 | + if with_closure |
| 112 | + prep = DI.prepare_gradient( |
| 113 | + x -> logdensity_at(x, model, varinfo, context), adtype, x |
| 114 | + ) |
| 115 | + else |
| 116 | + prep = DI.prepare_gradient( |
| 117 | + logdensity_at, |
| 118 | + adtype, |
| 119 | + x, |
| 120 | + DI.Constant(model), |
| 121 | + DI.Constant(varinfo), |
| 122 | + DI.Constant(context), |
| 123 | + ) |
| 124 | + end |
| 125 | + with_closure = with_closure |
| 126 | + end |
| 127 | + return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( |
| 128 | + model, varinfo, context, adtype, prep, with_closure |
| 129 | + ) |
| 130 | + end |
73 | 131 | end
|
74 | 132 |
|
75 | 133 | """
|
76 |
| - getmodel(f) |
| 134 | + setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) |
77 | 135 |
|
78 |
| -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. |
79 |
| -""" |
80 |
| -getmodel(f::DynamicPPL.LogDensityFunction) = f.model |
| 136 | +Set the AD type used for evaluation of log density gradient in the given LogDensityFunction. |
| 137 | +This function also performs preparation of the gradient, and sets the `prep` |
| 138 | +and `with_closure` fields of the LogDensityFunction. |
81 | 139 |
|
82 |
| -""" |
83 |
| - setmodel(f, model[, adtype]) |
| 140 | +If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well. |
84 | 141 |
|
85 |
| -Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. |
| 142 | +This function returns a new LogDensityFunction with the updated AD type, i.e. it does |
| 143 | +not mutate the input LogDensityFunction. |
86 | 144 | """
|
87 |
| -function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) |
88 |
| - return Accessors.@set f.model = model |
| 145 | +function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}) |
| 146 | + return if adtype === f.adtype |
| 147 | + f # Avoid recomputing prep if not needed |
| 148 | + else |
| 149 | + LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) |
| 150 | + end |
89 | 151 | end
|
90 | 152 |
|
91 | 153 | """
|
92 |
| - getparams(f::LogDensityFunction) |
93 |
| -
|
94 |
| -Return the parameters of the wrapped varinfo as a vector. |
| 154 | + logdensity_at( |
| 155 | + x::AbstractVector, |
| 156 | + model::Model, |
| 157 | + varinfo::AbstractVarInfo, |
| 158 | + context::AbstractContext |
| 159 | + ) |
| 160 | +
|
| 161 | +Evaluate the log density of the given `model` at the given parameter values `x`, |
| 162 | +using the given `varinfo` and `context`. Note that the `varinfo` argument is provided |
| 163 | +only for its structure, in the sense that the parameters from the vector `x` are inserted into |
| 164 | +it, and its own parameters are discarded. |
95 | 165 | """
|
96 |
| -getparams(f::LogDensityFunction) = f.varinfo[:] |
| 166 | +function logdensity_at( |
| 167 | + x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext |
| 168 | +) |
| 169 | + varinfo_new = unflatten(varinfo, x) |
| 170 | + return getlogp(last(evaluate!!(model, varinfo_new, context))) |
| 171 | +end |
| 172 | + |
| 173 | +### LogDensityProblems interface |
97 | 174 |
|
98 |
| -# LogDensityProblems interface: logp (0th order) |
| 175 | +function LogDensityProblems.capabilities( |
| 176 | + ::Type{<:LogDensityFunction{M,V,C,Nothing}} |
| 177 | +) where {M,V,C} |
| 178 | + return LogDensityProblems.LogDensityOrder{0}() |
| 179 | +end |
| 180 | +function LogDensityProblems.capabilities( |
| 181 | + ::Type{<:LogDensityFunction{M,V,C,AD}} |
| 182 | +) where {M,V,C,AD<:ADTypes.AbstractADType} |
| 183 | + return LogDensityProblems.LogDensityOrder{1}() |
| 184 | +end |
99 | 185 | function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
|
100 |
| - context = getcontext(f) |
101 |
| - vi_new = unflatten(f.varinfo, x) |
102 |
| - return getlogp(last(evaluate!!(f.model, vi_new, context))) |
| 186 | + return logdensity_at(x, f.model, f.varinfo, f.context) |
103 | 187 | end
|
104 |
| -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) |
105 |
| - return LogDensityProblems.LogDensityOrder{0}() |
| 188 | +function LogDensityProblems.logdensity_and_gradient( |
| 189 | + f::LogDensityFunction{M,V,C,AD}, x::AbstractVector |
| 190 | +) where {M,V,C,AD<:ADTypes.AbstractADType} |
| 191 | + f.prep === nothing && |
| 192 | + error("Gradient preparation not available; this should not happen") |
| 193 | + x = map(identity, x) # Concretise type |
| 194 | + return if f.with_closure |
| 195 | + DI.value_and_gradient( |
| 196 | + x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x |
| 197 | + ) |
| 198 | + else |
| 199 | + DI.value_and_gradient( |
| 200 | + logdensity_at, |
| 201 | + f.prep, |
| 202 | + f.adtype, |
| 203 | + x, |
| 204 | + DI.Constant(f.model), |
| 205 | + DI.Constant(f.varinfo), |
| 206 | + DI.Constant(f.context), |
| 207 | + ) |
| 208 | + end |
106 | 209 | end
|
| 210 | + |
107 | 211 | # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
|
108 | 212 | LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
|
109 | 213 |
|
110 |
| -# LogDensityProblems interface: gradient (1st order) |
| 214 | +### Utils |
| 215 | + |
111 | 216 | """
|
112 | 217 | use_closure(adtype::ADTypes.AbstractADType)
|
113 | 218 |
|
@@ -139,75 +244,24 @@ use_closure(::ADTypes.AutoMooncake) = false
|
139 | 244 | use_closure(::ADTypes.AutoReverseDiff) = true
|
140 | 245 |
|
141 | 246 | """
|
142 |
| - _flipped_logdensity(f::LogDensityFunction, x::AbstractVector) |
| 247 | + getmodel(f) |
143 | 248 |
|
144 |
| -This function is the same as `LogDensityProblems.logdensity(f, x)` but with the |
145 |
| -arguments flipped. It is used in the 'constant' approach to DifferentiationInterface |
146 |
| -(see `use_closure` for more information). |
| 249 | +Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. |
147 | 250 | """
|
148 |
| -function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction) |
149 |
| - return LogDensityProblems.logdensity(f, x) |
150 |
| -end |
| 251 | +getmodel(f::DynamicPPL.LogDensityFunction) = f.model |
151 | 252 |
|
152 | 253 | """
|
153 |
| - LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType) |
154 |
| -
|
155 |
| -A callable representing a log density function of a `model`. |
156 |
| -`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl |
157 |
| -interface to 1st-order, meaning that you can both calculate the log density |
158 |
| -using |
159 |
| -
|
160 |
| - LogDensityProblems.logdensity(f, x) |
161 |
| -
|
162 |
| -and its gradient using |
163 |
| -
|
164 |
| - LogDensityProblems.logdensity_and_gradient(f, x) |
| 254 | + setmodel(f, model[, adtype]) |
165 | 255 |
|
166 |
| -where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters. |
| 256 | +Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. |
| 257 | +""" |
| 258 | +function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) |
| 259 | + return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) |
| 260 | +end |
167 | 261 |
|
168 |
| -# Fields |
169 |
| -$(FIELDS) |
170 | 262 | """
|
171 |
| -struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType} |
172 |
| - ldf::LogDensityFunction{V,M,C} |
173 |
| - adtype::TAD |
174 |
| - prep::DI.GradientPrep |
175 |
| - with_closure::Bool |
| 263 | + getparams(f::LogDensityFunction) |
176 | 264 |
|
177 |
| - function LogDensityFunctionWithGrad( |
178 |
| - ldf::LogDensityFunction{V,M,C}, adtype::TAD |
179 |
| - ) where {V,M,C,TAD} |
180 |
| - # Get a set of dummy params to use for prep |
181 |
| - x = map(identity, getparams(ldf)) |
182 |
| - with_closure = use_closure(adtype) |
183 |
| - if with_closure |
184 |
| - prep = DI.prepare_gradient( |
185 |
| - Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x |
186 |
| - ) |
187 |
| - else |
188 |
| - prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf)) |
189 |
| - end |
190 |
| - # Store the prep with the struct. We also store whether a closure was used because |
191 |
| - # we need to know this when calling `DI.value_and_gradient`. In practice we could |
192 |
| - # recalculate it, but this runs the risk of introducing inconsistencies. |
193 |
| - return new{V,M,C,TAD}(ldf, adtype, prep, with_closure) |
194 |
| - end |
195 |
| -end |
196 |
| -function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad) |
197 |
| - return LogDensityProblems.logdensity(f.ldf) |
198 |
| -end |
199 |
| -function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad}) |
200 |
| - return LogDensityProblems.LogDensityOrder{1}() |
201 |
| -end |
202 |
| -function LogDensityProblems.logdensity_and_gradient( |
203 |
| - f::LogDensityFunctionWithGrad, x::AbstractVector |
204 |
| -) |
205 |
| - x = map(identity, x) # Concretise type |
206 |
| - return if f.with_closure |
207 |
| - DI.value_and_gradient( |
208 |
| - Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x |
209 |
| - ) |
210 |
| - else |
211 |
| - DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf)) |
212 |
| - end |
213 |
| -end |
| 265 | +Return the parameters of the wrapped varinfo as a vector. |
| 266 | +""" |
| 267 | +getparams(f::LogDensityFunction) = f.varinfo[:] |
0 commit comments