-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathDynamicPPLDifferentiationInterfaceTestExt.jl
More file actions
86 lines (76 loc) · 2.62 KB
/
DynamicPPLDifferentiationInterfaceTestExt.jl
File metadata and controls
86 lines (76 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
module DynamicPPLDifferentiationInterfaceTestExt
using DynamicPPL:
DynamicPPL,
ADTypes,
LogDensityProblems,
Model,
DI, # DifferentiationInterface
AbstractVarInfo,
VarInfo,
LogDensityFunction
import DifferentiationInterfaceTest as DIT
"""
REFERENCE_ADTYPE
Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff()
"""
make_scenario(
model::Model,
adtype::ADTypes.AbstractADType,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
)
Construct a DifferentiationInterfaceTest.Scenario for the given `model` and `adtype`.
More docs to follow.
"""
function make_scenario(
model::Model,
adtype::ADTypes.AbstractADType;
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
)
params = map(identity, params)
context = DynamicPPL.DefaultContext()
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context)
if DynamicPPL.use_closure(adtype)
f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context)
di_contexts = ()
else
f = DynamicPPL.logdensity_at
di_contexts = (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context))
end
# Calculate ground truth to compare against
grad_true = if expected_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2]
else
expected_grad
end
return DIT.Scenario{:gradient,:out}(
f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)"
)
end
"""
    DynamicPPL.TestUtils.AD.run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
        kwargs...,
    )

Test the gradient of `model`'s log density at `params` with `adtype` using
DifferentiationInterfaceTest. Remaining `kwargs` are forwarded to
`DifferentiationInterfaceTest.test_differentiation`.
"""
function DynamicPPL.TestUtils.AD.run_ad(
    model::Model,
    adtype::ADTypes.AbstractADType;
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
    expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
    kwargs...,
)
    # Forward ALL relevant keyword arguments; previously `params` and
    # `reference_adtype` were accepted but silently dropped here.
    scen = make_scenario(
        model,
        adtype;
        varinfo=varinfo,
        params=params,
        reference_adtype=reference_adtype,
        expected_grad=expected_grad,
    )
    # The scenario was built with the tweaked adtype, so test with the same one.
    tweaked_adtype = DynamicPPL.tweak_adtype(
        adtype, model, varinfo, DynamicPPL.DefaultContext()
    )
    return DIT.test_differentiation(
        tweaked_adtype, [scen]; scenario_intact=false, kwargs...
    )
end
end