Skip to content

Commit 05f1bce

Browse files
committed
setadtype --> LogDensityFunction
1 parent f76bb3d commit 05f1bce

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

HISTORY.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,14 @@ LogDensityProblems.logdensity_and_gradient(ldf, params)
155155
156156
without having to construct a separate `ADgradient` object.
157157
158-
If you prefer, you can also use `setadtype` to tack on the AD type afterwards:
158+
If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards.
159+
The model, varinfo, and context will be taken from the original `LogDensityFunction`:
159160
160161
```julia
161162
@model f() = ...
162163
163164
ldf = LogDensityFunction(f()) # by default, no adtype set
164-
ldf_with_ad = setadtype(ldf, AutoForwardDiff())
165+
ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff())
165166
```
166167
167168
## 0.34.2

src/logdensityfunction.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,18 +143,18 @@ struct LogDensityFunction{
143143
end
144144

145145
"""
146-
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
147-
148-
Set the AD type used for evaluation of log density gradient in the given
149-
LogDensityFunction. This function also performs preparation of the gradient,
150-
and sets the `prep` field of the LogDensityFunction.
151-
152-
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
146+
LogDensityFunction(
147+
ldf::LogDensityFunction,
148+
adtype::Union{Nothing,ADTypes.AbstractADType}
149+
)
153150
154-
This function returns a new LogDensityFunction with the updated AD type, i.e. it does
155-
not mutate the input LogDensityFunction.
151+
Create a new LogDensityFunction using the model, varinfo, and context from the given
152+
`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass
153+
`nothing` as the second argument.
156154
"""
157-
function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
155+
function LogDensityFunction(
156+
f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
157+
)
158158
return if adtype === f.adtype
159159
f # Avoid recomputing prep if not needed
160160
else

test/ad.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,21 @@ using DynamicPPL: LogDensityFunction
3939
# Mooncake doesn't work with several combinations of SimpleVarInfo.
4040
if is_mooncake && is_1_11 && is_svi_vnv
4141
# https://github.com/compintell/Mooncake.jl/issues/470
42-
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
42+
@test_throws ArgumentError DynamicPPL.LogDensityFunction(
43+
ref_ldf, adtype
44+
)
4345
elseif is_mooncake && is_1_10 && is_svi_vnv
4446
# TODO: report upstream
45-
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
47+
@test_throws UndefRefError DynamicPPL.LogDensityFunction(
48+
ref_ldf, adtype
49+
)
4650
elseif is_mooncake && is_1_10 && is_svi_od
4751
# TODO: report upstream
48-
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
52+
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
4953
ref_ldf, adtype
5054
)
5155
else
52-
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
56+
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
5357
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
5458
@test grad ref_grad
5559
@test logp ref_logp

0 commit comments

Comments
 (0)