-
Notifications
You must be signed in to change notification settings - Fork 35
Implement AD testing and benchmarking (with DITest) #883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0aedf08
4539a80
70e1aa9
e1a34e1
236f279
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
module DynamicPPLDifferentiationInterfaceTestExt | ||
|
||
using DynamicPPL: | ||
DynamicPPL, | ||
ADTypes, | ||
LogDensityProblems, | ||
Model, | ||
DI, # DifferentiationInterface | ||
AbstractVarInfo, | ||
VarInfo, | ||
LogDensityFunction | ||
import DifferentiationInterfaceTest as DIT | ||
|
||
""" | ||
REFERENCE_ADTYPE | ||
|
||
Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since | ||
it's the default AD backend used in Turing.jl. | ||
""" | ||
const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff() | ||
|
||
""" | ||
make_scenario( | ||
model::Model, | ||
adtype::ADTypes.AbstractADType, | ||
varinfo::AbstractVarInfo=VarInfo(model), | ||
params::Vector{<:Real}=varinfo[:], | ||
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, | ||
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, | ||
) | ||
|
||
Construct a DifferentiationInterfaceTest.Scenario for the given `model` and `adtype`. | ||
|
||
More docs to follow. | ||
""" | ||
function make_scenario( | ||
model::Model, | ||
adtype::ADTypes.AbstractADType; | ||
varinfo::AbstractVarInfo=VarInfo(model), | ||
params::Vector{<:Real}=varinfo[:], | ||
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, | ||
expected_grad::Union{Nothing,Vector{<:Real}}=nothing, | ||
) | ||
params = map(identity, params) | ||
context = DynamicPPL.DefaultContext() | ||
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context) | ||
# Below is a performance optimisation, see: https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658049143 | ||
if DynamicPPL.use_closure(adtype) | ||
penelopeysm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context) | ||
di_contexts = () | ||
else | ||
f = DynamicPPL.logdensity_at | ||
di_contexts = (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context)) | ||
end | ||
|
||
# Calculate ground truth to compare against | ||
grad_true = if expected_grad === nothing | ||
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype) | ||
LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2] | ||
else | ||
expected_grad | ||
end | ||
|
||
return DIT.Scenario{:gradient,:out}( | ||
f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)" | ||
) | ||
end | ||
|
||
function DynamicPPL.TestUtils.AD.run_ad( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One alternative is to overload There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DIT.Scenario doesn't contain a type parameter that we can dispatch on, though. (While putting the ADTests bit together, I also found out that Scenarios can't be prepared with a specific value of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we could do something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I like this idea. One could have both
Having unified interfaces across DI and Turing for these autodiff tests would be nice. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please keep in mind that DIT was designed mostly to test DI itself, so its interface is still rather dirty and unstable. Also, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, @gdalle, for the context. It is an excellent idea to improve DIT so it can become a community resource like DI. It's very helpful to have a standard interface where
It would help the autodiff dev community discover bugs more quickly. It would also inform the general users which AD backend is likely compatible with the library (e.g. Lux, Turing) they want to use (see, e.g. https://turinglang.org/ADTests/) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DIT is in the weird position where it simultaneously does much more than what we need and also doesn't do some of the things we need. I've said this elsewhere (in meetings etc) but this isn't a criticism of DIT, it's just about choosing the right tool for the job IMO. |
||
model::Model, | ||
adtype::ADTypes.AbstractADType; | ||
varinfo::AbstractVarInfo=VarInfo(model), | ||
params::Vector{<:Real}=varinfo[:], | ||
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, | ||
expected_grad::Union{Nothing,Vector{<:Real}}=nothing, | ||
kwargs..., | ||
) | ||
scen = make_scenario(model, adtype; varinfo=varinfo, expected_grad=expected_grad) | ||
tweaked_adtype = DynamicPPL.tweak_adtype( | ||
adtype, model, varinfo, DynamicPPL.DefaultContext() | ||
) | ||
return DIT.test_differentiation( | ||
tweaked_adtype, [scen]; scenario_intact=false, kwargs... | ||
) | ||
end | ||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.