Skip to content

Commit e9966ff

Browse files
committed
Update CI benchmark script
1 parent c88e6f1 commit e9966ff

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

benchmarks/benchmarks.jl

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Pkg
33
Pkg.develop(; path=joinpath(@__DIR__, ".."))
44

55
using DynamicPPL: DynamicPPL, make_benchmark_suite, VarInfo
6+
using ADTypes
67
using BenchmarkTools: @benchmark, median, run
78
using PrettyTables: PrettyTables, ft_printf
89
using ForwardDiff: ForwardDiff
@@ -48,29 +49,44 @@ lda_instance = begin
4849
Models.lda(2, d, w)
4950
end
5051

52+
# AD types setup
53+
fd = AutoForwardDiff()
54+
rd = AutoReverseDiff()
55+
mc = AutoMooncake(; config=nothing)
56+
"""
57+
get_adtype_shortname(adtype::ADTypes.AbstractADType)
58+
59+
Get the package name that corresponds to the the AD backend `adtype`. Only used
60+
for pretty-printing.
61+
"""
62+
get_adtype_shortname(::AutoMooncake) = "Mooncake"
63+
get_adtype_shortname(::AutoForwardDiff) = "ForwardDiff"
64+
get_adtype_shortname(::AutoReverseDiff{false}) = "ReverseDiff"
65+
get_adtype_shortname(::AutoReverseDiff{true}) = "ReverseDiff:Compiled"
66+
5167
# Specify the combinations to test:
5268
# (Model Name, model instance, VarInfo choice, AD backend, linked)
5369
chosen_combinations = [
5470
(
5571
"Simple assume observe",
5672
Models.simple_assume_observe(randn(StableRNG(23))),
5773
:typed,
58-
:forwarddiff,
74+
fd,
5975
false,
6076
),
61-
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
62-
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
63-
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
64-
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
65-
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
66-
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
67-
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
68-
("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
69-
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
70-
("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
71-
("Dynamic", Models.dynamic(), :typed, :mooncake, true),
72-
("Submodel", Models.parent(randn(StableRNG(23))), :typed, :mooncake, true),
73-
("LDA", lda_instance, :typed, :reversediff, true),
77+
("Smorgasbord", smorgasbord_instance, :typed, fd, false),
78+
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, fd, true),
79+
("Smorgasbord", smorgasbord_instance, :untyped, fd, true),
80+
("Smorgasbord", smorgasbord_instance, :simple_dict, fd, true),
81+
("Smorgasbord", smorgasbord_instance, :typed, rd, true),
82+
("Smorgasbord", smorgasbord_instance, :typed, mc, true),
83+
("Loop univariate 1k", loop_univariate1k, :typed, mc, true),
84+
("Multivariate 1k", multivariate1k, :typed, mc, true),
85+
("Loop univariate 10k", loop_univariate10k, :typed, mc, true),
86+
("Multivariate 10k", multivariate10k, :typed, mc, true),
87+
("Dynamic", Models.dynamic(), :typed, mc, true),
88+
("Submodel", Models.parent(randn(StableRNG(23))), :typed, mc, true),
89+
("LDA", lda_instance, :typed, rd, true),
7490
]
7591

7692
# Time running a model-like function that does not use DynamicPPL, as a reference point.
@@ -83,7 +99,7 @@ end
8399
results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
84100

85101
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
86-
@info "Running benchmark for $model_name"
102+
@info "Running benchmark for $model_name / $varinfo_choice / $(get_adtype_shortname(adbackend))"
87103
suite = make_benchmark_suite(StableRNG(23), model, varinfo_choice, adbackend, islinked)
88104
results = run(suite)
89105
eval_time = median(results["evaluation"]).time
@@ -95,7 +111,7 @@ for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinati
95111
(
96112
model_name,
97113
model_dimension(model, islinked),
98-
string(adbackend),
114+
get_adtype_shortname(adbackend),
99115
string(varinfo_choice),
100116
islinked,
101117
relative_eval_time,

0 commit comments

Comments
 (0)