diff --git a/.github/workflows/Benchmarking.yml b/.github/workflows/Benchmarking.yml
new file mode 100644
index 000000000..96817e2a9
--- /dev/null
+++ b/.github/workflows/Benchmarking.yml
@@ -0,0 +1,76 @@
+name: Benchmarking
+
+on:
+  pull_request:
+
+jobs:
+  benchmarks:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event.pull_request.head.sha }}
+
+      - name: Set up Julia
+        uses: julia-actions/setup-julia@v2
+        with:
+          version: '1'
+
+      - name: Install Dependencies
+        run: julia --project=benchmarks/ -e 'using Pkg; Pkg.instantiate()'
+
+      - name: Run Benchmarks
+        id: run_benchmarks
+        run: |
+          # Capture version info into a variable, print it, and set it as an env var for later steps
+          version_info=$(julia -e 'using InteractiveUtils; versioninfo()')
+          echo "$version_info"
+          echo "VERSION_INFO<<EOF" >> $GITHUB_ENV
+          echo "$version_info" >> $GITHUB_ENV
+          echo "EOF" >> $GITHUB_ENV
+
+          # Capture benchmark output into a variable
+          echo "Running Benchmarks..."
+          benchmark_output=$(julia --project=benchmarks benchmarks/benchmarks.jl)
+
+          # Print benchmark results directly to the workflow log
+          echo "Benchmark Results:"
+          echo "$benchmark_output"
+
+          # Set the benchmark output as an env var for later steps
+          echo "BENCHMARK_OUTPUT<<EOF" >> $GITHUB_ENV
+          echo "$benchmark_output" >> $GITHUB_ENV
+          echo "EOF" >> $GITHUB_ENV
+
+          # Get the current commit SHA of DynamicPPL
+          DPPL_COMMIT_SHA=$(git rev-parse HEAD)
+          echo "DPPL_COMMIT_SHA=$DPPL_COMMIT_SHA" >> $GITHUB_ENV
+
+          COMMIT_URL="https://github.com/${{ github.repository }}/commit/$DPPL_COMMIT_SHA"
+          echo "DPPL_COMMIT_URL=$COMMIT_URL" >> $GITHUB_ENV
+
+      - name: Find Existing Comment
+        uses: peter-evans/find-comment@v3
+        id: find_comment
+        with:
+          issue-number: ${{ github.event.pull_request.number }}
+          comment-author: github-actions[bot]
+
+      - name: Post Benchmark Results as PR Comment
+        uses: peter-evans/create-or-update-comment@v4
+        with:
+          issue-number: ${{ github.event.pull_request.number }}
+          body: |
+            ## Benchmark Report for Commit [`${{ env.DPPL_COMMIT_SHA }}`](${{ env.DPPL_COMMIT_URL }})
+            ### Computer Information
+            ```
+            ${{ env.VERSION_INFO }}
+            ```
+            ### Benchmark Results
+            ```
+            ${{ env.BENCHMARK_OUTPUT }}
+            ```
+          comment-id: ${{ steps.find_comment.outputs.comment-id }}
+          edit-mode: replace
\ No newline at end of file
diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml
index a8e8f09a2..88386f243 100644
--- a/benchmarks/Project.toml
+++ b/benchmarks/Project.toml
@@ -3,11 +3,23 @@ uuid = "d94a1522-c11e-44a7-981a-42bf5dc1a001"
 version = "0.1.0"
 
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
-DiffUtils = "8294860b-85a6-42f8-8c35-d911f667b5f6"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
-Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
-Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+
+[compat]
+ADTypes = "1.14.0"
+BenchmarkTools = "1.6.0"
+Distributions = "0.25.117"
+ForwardDiff = "0.10.38"
+LogDensityProblems = "2.1.2"
+PrettyTables = "2.4.0"
+ReverseDiff = "1.15.3"
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 9ddebc4bd..68287ade4 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -1,27 +1,5 @@
-To run the benchmarks, simply do:
+To run the benchmarks, run this from the root directory of the repository:
 
 ```sh
-julia --project -e 'using DynamicPPLBenchmarks; weave_benchmarks();'
-```
-
-```julia
-julia> @doc weave_benchmarks
-  weave_benchmarks(input="benchmarks.jmd"; kwargs...)
-
-  Weave benchmarks present in benchmarks.jmd into a single file.
-
-  Keyword arguments
-  ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡
-
-  • benchmarkbody: JMD-file to be rendered for each model.
-
-  • include_commit_id=false: specify whether to include commit-id in the default name.
-
-  • name: the name of directory in results/ to use as output directory.
-
-  • name_old=nothing: if specified, comparisons of current run vs. the run pinted to by name_old will be included in the generated document.
-
-  • include_typed_code=false: if true, output of code_typed for the evaluator of the model will be included in the weaved document.
-
-  • Rest of the passed kwargs will be passed on to Weave.weave.
-```
+julia --project=benchmarks benchmarks/benchmarks.jl
+```
\ No newline at end of file
diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd
deleted file mode 100644
index ba02957e7..000000000
--- a/benchmarks/benchmark_body.jmd
+++ /dev/null
@@ -1,49 +0,0 @@
-```julia
-@time model_def(data)();
-```
-
-```julia
-m = time_model_def(model_def, data);
-```
-
-```julia
-suite = make_suite(m);
-results = run(suite);
-```
-
-```julia
-results["evaluation_untyped"]
-```
-
-```julia
-results["evaluation_typed"]
-```
-
-```julia; echo=false; results="hidden";
-BenchmarkTools.save(
-    joinpath("results", WEAVE_ARGS[:name], "$(nameof(m))_benchmarks.json"), results
-)
-```
-
-```julia; wrap=false
-if WEAVE_ARGS[:include_typed_code]
-    typed = typed_code(m)
-end
-```
-
-```julia; echo=false; results="hidden"
-if WEAVE_ARGS[:include_typed_code]
-    # Serialize the output of `typed_code` so we can compare later.
-    haskey(WEAVE_ARGS, :name) &&
-        serialize(joinpath("results", WEAVE_ARGS[:name], "$(nameof(m)).jls"), string(typed))
-end
-```
-
-```julia; wrap=false; echo=false;
-if haskey(WEAVE_ARGS, :name_old)
-    # We want to compare the generated code to the previous version.
-    using DiffUtils: DiffUtils
-    typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(nameof(m)).jls"))
-    DiffUtils.diff(typed_old, string(typed); width=130)
-end
-```
diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl
new file mode 100644
index 000000000..89b65d2de
--- /dev/null
+++ b/benchmarks/benchmarks.jl
@@ -0,0 +1,103 @@
+using Pkg
+# To ensure we benchmark the local version of DynamicPPL, dev the folder above.
+Pkg.develop(; path=joinpath(@__DIR__, ".."))
+
+using DynamicPPLBenchmarks: Models, make_suite, model_dimension
+using BenchmarkTools: @benchmark, median, run
+using PrettyTables: PrettyTables, ft_printf
+using StableRNGs: StableRNG
+
+rng = StableRNG(23)
+
+# Create DynamicPPL.Model instances to run benchmarks on.
+smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))
+loop_univariate1k, multivariate1k = begin
+    data_1k = randn(rng, 1_000)
+    loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k)
+    multi = Models.multivariate(length(data_1k)) | (; o=data_1k)
+    loop, multi
+end
+loop_univariate10k, multivariate10k = begin
+    data_10k = randn(rng, 10_000)
+    loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k)
+    multi = Models.multivariate(length(data_10k)) | (; o=data_10k)
+    loop, multi
+end
+lda_instance = begin
+    w = [1, 2, 3, 2, 1, 1]
+    d = [1, 1, 1, 2, 2, 2]
+    Models.lda(2, d, w)
+end
+
+# Specify the combinations to test:
+# (Model Name, model instance, VarInfo choice, AD backend, linked)
+chosen_combinations = [
+    (
+        "Simple assume observe",
+        Models.simple_assume_observe(randn(rng)),
+        :typed,
+        :forwarddiff,
+        false,
+    ),
+    ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
+    ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
+    ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
+    ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
+    ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
+    ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
+    ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
+    ("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
+    ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
+    ("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
+    ("Dynamic", Models.dynamic(), :typed, :mooncake, true),
+    ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true),
+    ("LDA", lda_instance, :typed, :reversediff, true),
+]
+
+# Time running a model-like function that does not use DynamicPPL, as a reference point.
+# Eval timings will be relative to this.
+reference_time = begin
+    obs = randn(rng)
+    median(@benchmark Models.simple_assume_observe_non_model(obs)).time
+end
+
+results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
+
+for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
+    @info "Running benchmark for $model_name"
+    suite = make_suite(model, varinfo_choice, adbackend, islinked)
+    results = run(suite)
+    eval_time = median(results["evaluation"]).time
+    relative_eval_time = eval_time / reference_time
+    ad_eval_time = median(results["gradient"]).time
+    relative_ad_eval_time = ad_eval_time / eval_time
+    push!(
+        results_table,
+        (
+            model_name,
+            model_dimension(model, islinked),
+            string(adbackend),
+            string(varinfo_choice),
+            islinked,
+            relative_eval_time,
+            relative_ad_eval_time,
+        ),
+    )
+end
+
+table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
+header = [
+    "Model",
+    "Dimension",
+    "AD Backend",
+    "VarInfo Type",
+    "Linked",
+    "Eval Time / Ref Time",
+    "AD Time / Eval Time",
+]
+PrettyTables.pretty_table(
+    table_matrix;
+    header=header,
+    tf=PrettyTables.tf_markdown,
+    formatters=ft_printf("%.1f", [6, 7]),
+)
diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd
deleted file mode 100644
index 8021e4883..000000000
--- a/benchmarks/benchmarks.jmd
+++ /dev/null
@@ -1,130 +0,0 @@
-# Benchmarks
-
-## Setup
-
-```julia
-using BenchmarkTools, DynamicPPL, Distributions, Serialization
-```
-
-```julia
-import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child
-```
-
-## Models
-
-### `demo1`
-
-```julia
-@model function demo1(x)
-    m ~ Normal()
-    x ~ Normal(m, 1)
-
-    return (m=m, x=x)
-end
-
-model_def = demo1;
-data = 1.0;
-```
-
-```julia; results="markup"; echo=false
-weave_child(WEAVE_ARGS[:benchmarkbody]; mod=@__MODULE__, args=WEAVE_ARGS)
-```
-
-### `demo2`
-
-```julia
-@model function demo2(y)
-    # Our prior belief about the probability of heads in a coin.
-    p ~ Beta(1, 1)
-
-    # The number of observations.
-    N = length(y)
-    for n in 1:N
-        # Heads or tails of a coin are drawn from a Bernoulli distribution.
-        y[n] ~ Bernoulli(p)
-    end
-end
-
-model_def = demo2;
-data = rand(0:1, 10);
-```
-
-```julia; results="markup"; echo=false
-weave_child(WEAVE_ARGS[:benchmarkbody]; mod=@__MODULE__, args=WEAVE_ARGS)
-```
-
-### `demo3`
-
-```julia
-@model function demo3(x)
-    D, N = size(x)
-
-    # Draw the parameters for cluster 1.
-    μ1 ~ Normal()
-
-    # Draw the parameters for cluster 2.
-    μ2 ~ Normal()
-
-    μ = [μ1, μ2]
-
-    # Comment out this line if you instead want to draw the weights.
-    w = [0.5, 0.5]
-
-    # Draw assignments for each datum and generate it from a multivariate normal.
-    k = Vector{Int}(undef, N)
-    for i in 1:N
-        k[i] ~ Categorical(w)
-        x[:, i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.0)
-    end
-    return k
-end
-
-model_def = demo3
-
-# Construct 30 data points for each cluster.
-N = 30
-
-# Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example.
-μs = [-3.5, 0.0]
-
-# Construct the data points.
-data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.0), N), hcat, 1:2);
-```
-
-```julia; echo=false
-weave_child(WEAVE_ARGS[:benchmarkbody]; mod=@__MODULE__, args=WEAVE_ARGS)
-```
-
-### `demo4`: loads of indexing
-
-```julia
-@model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV}
-    m ~ Normal()
-    x = TV(undef, n)
-    for i in eachindex(x)
-        x[i] ~ Normal(m, 1.0)
-    end
-end
-
-model_def = demo4
-data = (100_000,);
-```
-
-```julia; echo=false
-weave_child(WEAVE_ARGS[:benchmarkbody]; mod=@__MODULE__, args=WEAVE_ARGS)
-```
-
-```julia
-@model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV}
-    m ~ Normal()
-    x = TV(undef, n)
-    return x .~ Normal(m, 1.0)
-end
-
-model_def = demo4_dotted
-data = (100_000,);
-```
-
-```julia; echo=false
-weave_child(WEAVE_ARGS[:benchmarkbody]; mod=@__MODULE__, args=WEAVE_ARGS)
-```
diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl
index d9af7bf50..b67f2ce06 100644
--- a/benchmarks/src/DynamicPPLBenchmarks.jl
+++ b/benchmarks/src/DynamicPPLBenchmarks.jl
@@ -1,172 +1,103 @@
 module DynamicPPLBenchmarks
 
-using DynamicPPL
-using BenchmarkTools
+using DynamicPPL: VarInfo, SimpleVarInfo, VarName
+using BenchmarkTools: BenchmarkGroup, @benchmarkable
+using DynamicPPL: DynamicPPL
+using ADTypes: ADTypes
+using LogDensityProblems: LogDensityProblems
 
-using Weave: Weave
-using Markdown: Markdown
+using ForwardDiff: ForwardDiff
+using Mooncake: Mooncake
+using ReverseDiff: ReverseDiff
 
-using LibGit2: LibGit2
-using Pkg: Pkg
-using Random: Random
+include("./Models.jl")
+using .Models: Models
 
-export weave_benchmarks
+export Models, make_suite, model_dimension
 
-function time_model_def(model_def, args...)
-    return @time model_def(args...)
-end
+"""
+    model_dimension(model, islinked)
 
-function benchmark_untyped_varinfo!(suite, m)
+Return the dimension of `model`, accounting for linking, if any.
+"""
+function model_dimension(model, islinked)
     vi = VarInfo()
-    # Populate.
-    m(vi)
-    # Evaluate.
-    suite["evaluation_untyped"] = @benchmarkable $m($vi, $(DefaultContext()))
-    return suite
-end
-
-function benchmark_typed_varinfo!(suite, m)
-    # Populate.
-    vi = VarInfo(m)
-    # Evaluate.
-    suite["evaluation_typed"] = @benchmarkable $m($vi, $(DefaultContext()))
-    return suite
-end
-
-function typed_code(m, vi=VarInfo(m))
-    rng = Random.MersenneTwister(42)
-    spl = SampleFromPrior()
-    ctx = SamplingContext(rng, spl, DefaultContext())
-
-    results = code_typed(m.f, Base.typesof(m, vi, ctx, m.args...))
-    return first(results)
+    model(vi)
+    if islinked
+        vi = DynamicPPL.link(vi, model)
+    end
+    return length(vi[:])
 end
 
-"""
-    make_suite(model)
-
-Create default benchmark suite for `model`.
-"""
-function make_suite(model)
-    suite = BenchmarkGroup()
-    benchmark_untyped_varinfo!(suite, model)
-    benchmark_typed_varinfo!(suite, model)
+# Utility functions for representing AD backends using symbols.
+# Copied from TuringBenchmarking.jl.
+const SYMBOL_TO_BACKEND = Dict(
+    :forwarddiff => ADTypes.AutoForwardDiff(),
+    :reversediff => ADTypes.AutoReverseDiff(; compile=false),
+    :reversediff_compiled => ADTypes.AutoReverseDiff(; compile=true),
+    :mooncake => ADTypes.AutoMooncake(; config=nothing),
+)
 
-    return suite
+to_backend(x) = error("Unknown backend: $x")
+to_backend(x::ADTypes.AbstractADType) = x
+function to_backend(x::Union{AbstractString,Symbol})
+    k = Symbol(lowercase(string(x)))
+    haskey(SYMBOL_TO_BACKEND, k) || error("Unknown backend: $x")
+    return SYMBOL_TO_BACKEND[k]
 end
 
 """
-    weave_child(indoc; mod, args, kwargs...)
+    make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
 
-Weave `indoc` with scope of `mod` into markdown.
+Create a benchmark suite for `model` using the selected varinfo type and AD backend.
+Available varinfo choices:
+  • `:untyped` → uses `VarInfo()`
+  • `:typed` → uses `VarInfo(model)`
+  • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
+  • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
 
-Useful for weaving within weaving, e.g.
-```julia
-weave_child(child_jmd_path, mod = @__MODULE__, args = WEAVE_ARGS)
-```
-together with `results="markup"` and `echo=false` will simply insert
-the weaved version of `indoc`.
+The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:mooncake`).
 
-# Notes
-- Currently only supports `doctype == "github"`. Other outputs are "supported"
-  in the sense that it works but you might lose niceties such as syntax highlighting.
+`islinked` determines whether to link the VarInfo for evaluation.
 """
-function weave_child(indoc; mod, args, kwargs...)
-    # FIXME: Make this work for other output formats than just `github`.
-    doc = Weave.WeaveDoc(indoc, nothing)
-    doc = Weave.run_doc(doc; doctype="github", mod=mod, args=args, kwargs...)
-    rendered = Weave.render_doc(doc)
-    return display(Markdown.parse(rendered))
-end
+function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
+    suite = BenchmarkGroup()
 
-"""
-    pkgversion(m::Module)
+    vi = if varinfo_choice == :untyped
+        vi = VarInfo()
+        model(vi)
+        vi
+    elseif varinfo_choice == :typed
+        VarInfo(model)
+    elseif varinfo_choice == :simple_namedtuple
+        SimpleVarInfo{Float64}(model())
+    elseif varinfo_choice == :simple_dict
+        retvals = model()
+        vns = [VarName{k}() for k in keys(retvals)]
+        SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
+    else
+        error("Unknown varinfo choice: $varinfo_choice")
+    end
 
-Return version of module `m` as listed in its Project.toml.
-"""
-function pkgversion(m::Module)
-    projecttoml_path = joinpath(dirname(pathof(m)), "..", "Project.toml")
-    return Pkg.TOML.parsefile(projecttoml_path)["version"]
-end
+    adbackend = to_backend(adbackend)
+    context = DynamicPPL.DefaultContext()
 
-"""
-    default_name(; include_commit_id=false)
+    if islinked
+        vi = DynamicPPL.link(vi, model)
+    end
 
-Construct a name from either repo information or package version
-of `DynamicPPL`.
+    f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
+    # The parameters at which we evaluate f.
+    θ = vi[:]
 
-If the path of `DynamicPPL` is a git-repo, return name of current branch,
-joined with the commit id if `include_commit_id` is `true`.
+    # Run once to trigger compilation.
+    LogDensityProblems.logdensity_and_gradient(f, θ)
+    suite["gradient"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ)
 
-If path of `DynamicPPL` is _not_ a git-repo, it is assumed to be a release,
-resulting in a name of the form `release-VERSION`.
-""" -function default_name(; include_commit_id=false) - dppl_path = abspath(joinpath(dirname(pathof(DynamicPPL)), "..")) - - # Extract branch name and commit id - local name - try - githead = LibGit2.head(LibGit2.GitRepo(dppl_path)) - branchname = LibGit2.shortname(githead) - - name = replace(branchname, "/" => "_") - if include_commit_id - gitcommit = LibGit2.peel(LibGit2.GitCommit, githead) - commitid = string(LibGit2.GitHash(gitcommit)) - name *= "-$(commitid)" - end - catch e - if e isa LibGit2.GitError - @info "No git repo found for $(dppl_path); extracting name from package version." - name = "release-$(pkgversion(DynamicPPL))" - else - rethrow(e) - end - end - - return name -end + # Also benchmark just standard model evaluation because why not. + suite["evaluation"] = @benchmarkable $(LogDensityProblems.logdensity)($f, $θ) -""" - weave_benchmarks(input="benchmarks.jmd"; kwargs...) - -Weave benchmarks present in `benchmarks.jmd` into a single file. - -# Keyword arguments -- `benchmarkbody`: JMD-file to be rendered for each model. -- `include_commit_id=false`: specify whether to include commit-id in the default name. -- `name`: the name of directory in `results/` to use as output directory. -- `name_old=nothing`: if specified, comparisons of current run vs. the run pinted to - by `name_old` will be included in the generated document. -- `include_typed_code=false`: if `true`, output of `code_typed` for the evaluator - of the model will be included in the weaved document. -- Rest of the passed `kwargs` will be passed on to `Weave.weave`. -""" -function weave_benchmarks( - input=joinpath(dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmarks.jmd"); - benchmarkbody=joinpath( - dirname(pathof(DynamicPPLBenchmarks)), "..", "benchmark_body.jmd" - ), - include_commit_id=false, - name=default_name(; include_commit_id=include_commit_id), - name_old=nothing, - include_typed_code=false, - doctype="github", - outpath="results/$(name)/", - kwargs..., -) - args = Dict( - :benchmarkbody => benchmarkbody, - :name => name, - :include_typed_code => include_typed_code, - ) - if !isnothing(name_old) - args[:name_old] = name_old - end - @info "Storing output in $(outpath)" - mkpath(outpath) - return Weave.weave(input, doctype; out_path=outpath, args=args, kwargs...) + return suite end end # module diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl new file mode 100644 index 000000000..2c881aa95 --- /dev/null +++ b/benchmarks/src/Models.jl @@ -0,0 +1,156 @@ +""" +Models for benchmarking Turing.jl. + +Each model returns a NamedTuple of all the random variables in the model that are not +observed (this is used for constructing SimpleVarInfos). +""" +module Models + +using Distributions: + Categorical, + Dirichlet, + Exponential, + Gamma, + LKJCholesky, + InverseWishart, + Normal, + logpdf, + product_distribution, + truncated +using DynamicPPL: DynamicPPL, @model, to_submodel +using LinearAlgebra: cholesky + +export simple_assume_observe_non_model, + simple_assume_observe, smorgasbord, loop_univariate, multivariate, parent, dynamic, lda + +# This one is like simple_assume_observe, but explicitly does not use DynamicPPL. +# Other runtimes are normalised by this one's runtime. +function simple_assume_observe_non_model(obs) + x = rand(Normal()) + logp = logpdf(Normal(), x) + logp += logpdf(Normal(x, 1), obs) + return (; logp=logp, x=x) +end + +""" +A simple model that does one scalar assumption and one scalar observation. 
+""" +@model function simple_assume_observe(obs) + x ~ Normal() + obs ~ Normal(x, 1) + return (; x=x) +end + +""" +A short model that tries to cover many DynamicPPL features. + +Includes scalar, vector univariate, and multivariate variables; ~, .~, and loops; allocating +a variable vector; observations passed as arguments, and as literals. +""" +@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV} + @assert length(x) == length(y) + m ~ truncated(Normal(); lower=0) + means ~ product_distribution(fill(Exponential(m), length(x))) + stds = TV(undef, length(x)) + stds .~ Gamma(1, 1) + for i in 1:length(x) + x[i] ~ Normal(means[i], stds[i]) + end + y ~ product_distribution(map((mean, std) -> Normal(mean, std), means, stds)) + 0.0 ~ Normal(sum(y), 1) + return (; m=m, means=means, stds=stds) +end + +""" +A model that loops over two vectors of univariate normals of length `num_dims`. + +The second variable, `o`, is meant to be conditioned on after model instantiation. + +See `multivariate` for a version that uses `product_distribution` rather than loops. +""" +@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV} + a = TV(undef, num_dims) + o = TV(undef, num_dims) + for i in 1:num_dims + a[i] ~ Normal(0, 1) + end + m = sum(a) + for i in 1:num_dims + o[i] ~ Normal(m, 1) + end + return (; a=a) +end + +""" +A model with two multivariate normal distributed variables of dimension `num_dims`. + +The second variable, `o`, is meant to be conditioned on after model instantiation. + +See `loop_univariate` for a version that uses loops rather than `product_distribution`. +""" +@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV} + a = TV(undef, num_dims) + o = TV(undef, num_dims) + a ~ product_distribution(fill(Normal(0, 1), num_dims)) + m = sum(a) + o ~ product_distribution(fill(Normal(m, 1), num_dims)) + return (; a=a) +end + +""" +A submodel for `parent`. Not exported. +""" +@model function sub() + x ~ Normal() + return x +end + +""" +Like simple_assume_observe, but with a submodel for the assumed random variable. +""" +@model function parent(obs) + x ~ to_submodel(sub()) + obs ~ Normal(x, 1) + return (; x=x) +end + +""" +A model with random variables that have changing support under linking, or otherwise +complicated bijectors. +""" +@model function dynamic(::Type{T}=Vector{Float64}) where {T} + eta ~ truncated(Normal(); lower=0.0, upper=0.1) + mat1 ~ LKJCholesky(4, eta) + mat2 ~ InverseWishart(3.2, cholesky([1.0 0.5; 0.5 1.0])) + return (; eta=eta, mat1=mat1, mat2=mat2) +end + +""" +A simple Linear Discriminant Analysis model. +""" +@model function lda(K, d, w) + V = length(unique(w)) + D = length(unique(d)) + N = length(d) + @assert length(w) == N + + ϕ = Vector{Vector{Real}}(undef, K) + for i in 1:K + ϕ[i] ~ Dirichlet(ones(V) / V) + end + + θ = Vector{Vector{Real}}(undef, D) + for i in 1:D + θ[i] ~ Dirichlet(ones(K) / K) + end + + z = zeros(Int, N) + + for i in 1:N + z[i] ~ Categorical(θ[d[i]]) + w[i] ~ Categorical(ϕ[d[i]]) + end + return (; ϕ=ϕ, θ=θ, z=z) +end + +end