@@ -3,6 +3,7 @@ using Pkg
3
3
Pkg. develop (; path= joinpath (@__DIR__ , " .." ))
4
4
5
5
using DynamicPPL: DynamicPPL, make_benchmark_suite, VarInfo
6
+ using ADTypes
6
7
using BenchmarkTools: @benchmark , median, run
7
8
using PrettyTables: PrettyTables, ft_printf
8
9
using ForwardDiff: ForwardDiff
@@ -48,29 +49,44 @@ lda_instance = begin
48
49
Models. lda (2 , d, w)
49
50
end
50
51
52
+ # AD types setup
53
+ fd = AutoForwardDiff ()
54
+ rd = AutoReverseDiff ()
55
+ mc = AutoMooncake (; config= nothing )
56
+ """
57
+ get_adtype_shortname(adtype::ADTypes.AbstractADType)
58
+
59
+ Get the package name that corresponds to the the AD backend `adtype`. Only used
60
+ for pretty-printing.
61
+ """
62
+ get_adtype_shortname (:: AutoMooncake ) = " Mooncake"
63
+ get_adtype_shortname (:: AutoForwardDiff ) = " ForwardDiff"
64
+ get_adtype_shortname (:: AutoReverseDiff{false} ) = " ReverseDiff"
65
+ get_adtype_shortname (:: AutoReverseDiff{true} ) = " ReverseDiff:Compiled"
66
+
51
67
# Specify the combinations to test:
52
68
# (Model Name, model instance, VarInfo choice, AD backend, linked)
53
69
chosen_combinations = [
54
70
(
55
71
" Simple assume observe" ,
56
72
Models. simple_assume_observe (randn (StableRNG (23 ))),
57
73
:typed ,
58
- :forwarddiff ,
74
+ fd ,
59
75
false ,
60
76
),
61
- (" Smorgasbord" , smorgasbord_instance, :typed , :forwarddiff , false ),
62
- (" Smorgasbord" , smorgasbord_instance, :simple_namedtuple , :forwarddiff , true ),
63
- (" Smorgasbord" , smorgasbord_instance, :untyped , :forwarddiff , true ),
64
- (" Smorgasbord" , smorgasbord_instance, :simple_dict , :forwarddiff , true ),
65
- (" Smorgasbord" , smorgasbord_instance, :typed , :reversediff , true ),
66
- (" Smorgasbord" , smorgasbord_instance, :typed , :mooncake , true ),
67
- (" Loop univariate 1k" , loop_univariate1k, :typed , :mooncake , true ),
68
- (" Multivariate 1k" , multivariate1k, :typed , :mooncake , true ),
69
- (" Loop univariate 10k" , loop_univariate10k, :typed , :mooncake , true ),
70
- (" Multivariate 10k" , multivariate10k, :typed , :mooncake , true ),
71
- (" Dynamic" , Models. dynamic (), :typed , :mooncake , true ),
72
- (" Submodel" , Models. parent (randn (StableRNG (23 ))), :typed , :mooncake , true ),
73
- (" LDA" , lda_instance, :typed , :reversediff , true ),
77
+ (" Smorgasbord" , smorgasbord_instance, :typed , fd , false ),
78
+ (" Smorgasbord" , smorgasbord_instance, :simple_namedtuple , fd , true ),
79
+ (" Smorgasbord" , smorgasbord_instance, :untyped , fd , true ),
80
+ (" Smorgasbord" , smorgasbord_instance, :simple_dict , fd , true ),
81
+ (" Smorgasbord" , smorgasbord_instance, :typed , rd , true ),
82
+ (" Smorgasbord" , smorgasbord_instance, :typed , mc , true ),
83
+ (" Loop univariate 1k" , loop_univariate1k, :typed , mc , true ),
84
+ (" Multivariate 1k" , multivariate1k, :typed , mc , true ),
85
+ (" Loop univariate 10k" , loop_univariate10k, :typed , mc , true ),
86
+ (" Multivariate 10k" , multivariate10k, :typed , mc , true ),
87
+ (" Dynamic" , Models. dynamic (), :typed , mc , true ),
88
+ (" Submodel" , Models. parent (randn (StableRNG (23 ))), :typed , mc , true ),
89
+ (" LDA" , lda_instance, :typed , rd , true ),
74
90
]
75
91
76
92
# Time running a model-like function that does not use DynamicPPL, as a reference point.
83
99
results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
84
100
85
101
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
86
- @info " Running benchmark for $model_name "
102
+ @info " Running benchmark for $model_name / $varinfo_choice / $( get_adtype_shortname (adbackend)) "
87
103
suite = make_benchmark_suite (StableRNG (23 ), model, varinfo_choice, adbackend, islinked)
88
104
results = run (suite)
89
105
eval_time = median (results[" evaluation" ]). time
@@ -95,7 +111,7 @@ for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinati
95
111
(
96
112
model_name,
97
113
model_dimension (model, islinked),
98
- string (adbackend),
114
+ get_adtype_shortname (adbackend),
99
115
string (varinfo_choice),
100
116
islinked,
101
117
relative_eval_time,
0 commit comments