1- module DynamicPPLBenchmarks
1+ module DynamicPPLBenchmarkToolsExt
22
3- using DynamicPPL: VarInfo, SimpleVarInfo, VarName
3+ using DynamicPPL:
4+ DynamicPPL, ADTypes, LogDensityProblems, Model, VarInfo, SimpleVarInfo, VarName
45using BenchmarkTools: BenchmarkGroup, @benchmarkable
5- using DynamicPPL: DynamicPPL
6- using ADTypes: ADTypes
7- using LogDensityProblems: LogDensityProblems
8-
9- using ForwardDiff: ForwardDiff
10- using Mooncake: Mooncake
11- using ReverseDiff: ReverseDiff
12- using StableRNGs: StableRNG
13-
14- include (" ./Models.jl" )
15- using . Models: Models
16-
17- export Models, make_suite, model_dimension
6+ using Random: Random
187
198"""
20- model_dimension(model, islinked)
21-
22- Return the dimension of `model`, accounting for linking, if any.
23- """
24- function model_dimension (model, islinked)
25- vi = VarInfo ()
26- model (vi)
27- if islinked
28- vi = DynamicPPL. link (vi, model)
29- end
30- return length (vi[:])
31- end
32-
33- # Utility functions for representing AD backends using symbols.
34- # Copied from TuringBenchmarking.jl.
35- const SYMBOL_TO_BACKEND = Dict (
36- :forwarddiff => ADTypes. AutoForwardDiff (),
37- :reversediff => ADTypes. AutoReverseDiff (; compile= false ),
38- :reversediff_compiled => ADTypes. AutoReverseDiff (; compile= true ),
39- :mooncake => ADTypes. AutoMooncake (; config= nothing ),
40- )
41-
42- to_backend (x) = error (" Unknown backend: $x " )
43- to_backend (x:: ADTypes.AbstractADType ) = x
44- function to_backend (x:: Union{AbstractString,Symbol} )
45- k = Symbol (lowercase (string (x)))
46- haskey (SYMBOL_TO_BACKEND, k) || error (" Unknown backend: $x " )
47- return SYMBOL_TO_BACKEND[k]
48- end
49-
50- """
51- make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
9+ make_benchmark_suite(
10+ [rng::Random.AbstractRNG,]
11+ model::Model,
12+ varinfo_choice::Symbol,
13+ adtype::ADTypes.AbstractADType,
14+ islinked::Bool
15+ )
5216
5317Create a benchmark suite for `model` using the selected varinfo type and AD backend.
5418Available varinfo choices:
@@ -57,13 +21,15 @@ Available varinfo choices:
5721 • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
5822 • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
5923
60- The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
61-
6224`islinked` determines whether to link the VarInfo for evaluation.
6325"""
64- function make_suite (model, varinfo_choice:: Symbol , adbackend:: Symbol , islinked:: Bool )
65- rng = StableRNG (23 )
66-
26+ function make_benchmark_suite (
27+ rng:: Random.AbstractRNG ,
28+ model:: Model ,
29+ varinfo_choice:: Symbol ,
30+ adtype:: ADTypes.AbstractADType ,
31+ islinked:: Bool ,
32+ )
6733 suite = BenchmarkGroup ()
6834
6935 vi = if varinfo_choice == :untyped
@@ -82,14 +48,13 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8248 error (" Unknown varinfo choice: $varinfo_choice " )
8349 end
8450
85- adbackend = to_backend (adbackend)
8651 context = DynamicPPL. DefaultContext ()
8752
8853 if islinked
8954 vi = DynamicPPL. link (vi, model)
9055 end
9156
92- f = DynamicPPL. LogDensityFunction (model, vi, context; adtype= adbackend )
57+ f = DynamicPPL. LogDensityFunction (model, vi, context; adtype= adtype )
9358 # The parameters at which we evaluate f.
9459 θ = vi[:]
9560
@@ -102,5 +67,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
10267
10368 return suite
10469end
70+ function make_benchmark_suite (
71+ model:: Model , varinfo_choice:: Symbol , adtype:: Symbol , islinked:: Bool
72+ )
73+ return make_benchmark_suite (
74+ Random. default_rng (), model, varinfo_choice, adtype, islinked
75+ )
76+ end
10577
106- end # module
78+ end
0 commit comments