|
1 | 1 | module DynamicPPLForwardDiffExt |
2 | 2 |
|
3 | | -if isdefined(Base, :get_extension) |
4 | | - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD |
5 | | - using ForwardDiff |
6 | | -else |
7 | | - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD |
8 | | - using ..ForwardDiff |
9 | | -end |
10 | | - |
11 | | -getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk |
12 | | - |
13 | | -standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true |
14 | | -standardtag(::ADTypes.AutoForwardDiff) = false |
15 | | - |
16 | | -function LogDensityProblemsAD.ADgradient( |
17 | | - ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction |
18 | | -) |
19 | | - θ = DynamicPPL.getparams(ℓ) |
20 | | - f = Base.Fix1(LogDensityProblems.logdensity, ℓ) |
21 | | - |
22 | | - # Define configuration for ForwardDiff. |
23 | | - tag = if standardtag(ad) |
24 | | - ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) |
| 3 | +using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems |
| 4 | +using ForwardDiff |
| 5 | + |
| 6 | +# check if the AD type already has a tag |
| 7 | +use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true |
| 8 | +use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false |
| 9 | + |
| 10 | +function DynamicPPL.tweak_adtype( |
| 11 | + ad::ADTypes.AutoForwardDiff{chunk_size}, |
| 12 | + ::DynamicPPL.Model, |
| 13 | + vi::DynamicPPL.AbstractVarInfo, |
| 14 | + ::DynamicPPL.AbstractContext, |
| 15 | +) where {chunk_size} |
| 16 | + params = vi[:] |
| 17 | + |
| 18 | + # Use DynamicPPL tag to improve stack traces |
| 19 | + # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ |
| 20 | + # NOTE: DifferentiationInterface disables tag checking if the |
| 21 | + # tag inside the AutoForwardDiff type is not nothing. See |
| 22 | + # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350. |
| 23 | + # So we don't currently need to override ForwardDiff.checktag as well. |
| 24 | + tag = if use_dynamicppl_tag(ad) |
| 25 | + ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params)) |
25 | 26 | else |
26 | | - ForwardDiff.Tag(f, eltype(θ)) |
| 27 | + ad.tag |
27 | 28 | end |
28 | | - chunk_size = getchunksize(ad) |
| 29 | + |
| 30 | + # Optimise chunk size according to size of model |
29 | 31 | chunk = if chunk_size == 0 || chunk_size === nothing |
30 | | - ForwardDiff.Chunk(θ) |
| 32 | + ForwardDiff.Chunk(params) |
31 | 33 | else |
32 | | - ForwardDiff.Chunk(length(θ), chunk_size) |
| 34 | + ForwardDiff.Chunk(length(params), chunk_size) |
33 | 35 | end |
34 | 36 |
|
35 | | - return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) |
36 | | -end |
37 | | - |
38 | | -# Allow Turing tag in gradient etc. calls of the log density function |
39 | | -function ForwardDiff.checktag( |
40 | | - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, |
41 | | - ::DynamicPPL.LogDensityFunction, |
42 | | - ::AbstractArray{W}, |
43 | | -) where {V,W} |
44 | | - return true |
45 | | -end |
46 | | -function ForwardDiff.checktag( |
47 | | - ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, |
48 | | - ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, |
49 | | - ::AbstractArray{W}, |
50 | | -) where {V,W} |
51 | | - return true |
| 37 | + return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag) |
52 | 38 | end |
53 | 39 |
|
54 | 40 | end # module |
0 commit comments