@penelopeysm penelopeysm commented Nov 15, 2025

The approach used in FastLDF potentially suffers from type stability issues.

One of the issues is just me being stupid: I implemented fast_evaluate!! quite poorly (one branch would return TSVI, the other branch would return OAVI). This PR fixes that.

But there is also a separate, more subtle issue with using views. For example, it is responsible for the failing type stability tests on #1115, which implements the naive solution of adding @view throughout the DefaultContext code. It's also (partly) responsible for the Enzyme failures on #1139.

The crux of the issue is that if you cannot tell whether a parameter is linked or unlinked, then you have to do something like this:

transform = if is_linked_param
    from_linked_vec_transform(dist)
else
    from_vec_transform(dist)
end
x, logjac = with_logabsdet_jacobian(transform, y)

Now, consider dist = product_distribution([Beta(2, 2), Beta(2, 2)]):

julia> using DynamicPPL, Distributions

julia> DynamicPPL.from_linked_vec_transform(dist)
Bijectors.Inverse{Bijectors.TruncatedBijector{Float64, Float64}}(Bijectors.TruncatedBijector{Float64, Float64}(0.0, 1.0)) ∘ identity

julia> DynamicPPL.from_vec_transform(dist)
identity (generic function with 1 method)

and the effects of this transformation on a view:

julia> x = @view [0.5, 0.5][1:2]
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 0.5
 0.5

julia> DynamicPPL.from_linked_vec_transform(dist)(x)
2-element Vector{Float64}:
 0.6224593312018546
 0.6224593312018546

julia> DynamicPPL.from_vec_transform(dist)(x)
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 0.5
 0.5

So, in general, if you can't tell ahead of time whether the parameter is linked, executing this code gives you a Union return type. Running this in plain Julia doesn't hurt performance much, because Julia can handle small Unions via union splitting. However, a test like @inferred in #1115, or Enzyme's type analysis, requires stricter type stability.
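The instability can be reproduced in miniature without DynamicPPL at all. This sketch (the names are illustrative stand-ins, not DynamicPPL's actual code) branches on a runtime flag the way the snippet above branches on `is_linked_param`:

```julia
# The "linked" branch materialises a fresh Vector, while the "unlinked"
# branch (an identity-like transform) passes the view straight through.
maybe_transform(is_linked::Bool, y) = is_linked ? exp.(y) : y

y = @view [0.5, 0.5][1:2]
maybe_transform(true, y)   # Vector{Float64}
maybe_transform(false, y)  # SubArray — the view is preserved

# Inference can only conclude a Union of the two return types,
# which is exactly what `@inferred` rejects:
Base.return_types(maybe_transform, (Bool, typeof(y)))
```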

This PR therefore implements special cases for what are by far the two most common use cases, where either all the parameters are linked, or all the parameters are unlinked. This is determined at LogDensityFunction construction time, and passed all the way down into init via a type parameter.
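As a sketch of why moving the flag into the type domain helps (again with illustrative stand-ins rather than DynamicPPL's actual API): once linked-ness is a type parameter, the branch is resolved at dispatch time, so each specialisation has a single concrete return type and the Union never arises.

```julia
transform_param(::Val{true}, y) = exp.(y)   # stand-in for the linked transform
transform_param(::Val{false}, y) = y        # stand-in for the identity transform

y = @view [0.5, 0.5][1:2]
# Each method infers to one concrete type; there is no runtime branch left:
only(Base.return_types(transform_param, (Val{true}, typeof(y))))   # Vector{Float64}
only(Base.return_types(transform_param, (Val{false}, typeof(y))))  # the SubArray type
```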

I am still quite unsure whether there is a real scenario where mixed linked and unlinked variables occur. I think this had something to do with Gibbs, but if some samplers need linking (e.g. HMC), then surely we can just force all variables to be linked. This would only be impossible if some sampler needed variables to be unlinked, and I'm genuinely not sure whether any sampler has that property.

However, Gibbs doesn't use LDF, so I am not sure that this is an important consideration for this PR. Even so, there should be no regression in performance for the mixed linked/unlinked case: this PR should just be a strict improvement for the all-linked or all-unlinked case.

Why can't we just store the transform in the LDF?

The transform has to be constructed on the fly from dist; it can't be stored ahead of time, because the distribution (and hence its transform) can depend on the values of other parameters:

x ~ Normal()
y ~ truncated(Normal(); lower=x)
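Concretely, the support of y's distribution is (x, ∞), so the standard link transform for a lower-truncated distribution, z = log(y - x), depends on the sampled value of x. A minimal sketch (not DynamicPPL's actual code):

```julia
# Standard bijection between (lo, ∞) and ℝ; `lo` is only known at
# evaluation time, so the transform cannot be stored in the LDF.
link(y, lo) = log(y - lo)       # maps (lo, ∞) → ℝ
invlink(z, lo) = exp(z) + lo    # inverse map

invlink(link(3.0, 1.0), 1.0)    # round-trips (≈ 3.0)
```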

Benchmarks (unlinked)

For most of the models that were benchmarked previously, the only real difference is that this PR makes Enzyme quite a bit faster. Still, it's good to verify that for those models, this PR does not cause any regressions.

Here 'before this PR' = #1139, 'after this PR' = this branch, 'v0.38.9' is current main.

# trivial         before this PR                        after this PR                          v0.38.9
eval      ----    10.945 ns                             10.649 ns                              158.896 ns (6 allocs: 192 bytes)
grad (FD) ----    38.357 ns (3 allocs: 96 bytes)        39.345 ns (3 allocs: 96 bytes)         301.449 ns (13 allocs: 496 bytes)
grad (RD) ----    2.610 μs (46 allocs: 1.562 KiB)       2.629 μs (44 allocs: 1.500 KiB)        4.174 μs (82 allocs: 3.062 KiB)
grad (MC) ----    242.294 ns (4 allocs: 192 bytes)      269.660 ns (4 allocs: 192 bytes)       1.173 μs (25 allocs: 1.219 KiB)
grad (EN) ----    127.706 ns (2 allocs: 64 bytes)       100.539 ns (2 allocs: 64 bytes)        434.441 ns (16 allocs: 560 bytes)

# eight-schools   before this PR                        after this PR                          v0.38.9
eval      ----    168.374 ns (4 allocs: 256 bytes)      170.667 ns (4 allocs: 256 bytes)       851.706 ns (21 allocs: 1.344 KiB)
grad (FD) ----    821.500 ns (11 allocs: 2.594 KiB)     775.211 ns (11 allocs: 2.594 KiB)      1.528 μs (28 allocs: 5.484 KiB)
grad (RD) ----    35.167 μs (562 allocs: 20.562 KiB)    35.083 μs (555 allocs: 20.297 KiB)     40.250 μs (616 allocs: 25.766 KiB)
grad (MC) ----    1.248 μs (12 allocs: 784 bytes)       1.264 μs (12 allocs: 784 bytes)        4.479 μs (64 allocs: 4.016 KiB)
grad (EN) ----    728.659 ns (13 allocs: 832 bytes)     630.319 ns (13 allocs: 832 bytes)      1.826 μs (44 allocs: 2.609 KiB)

# badvarnames     before this PR                        after this PR                          v0.38.9
eval      ----    359.756 ns (2 allocs: 224 bytes)      325.720 ns (2 allocs: 224 bytes)       1.438 μs (46 allocs: 1.906 KiB)
grad (FD) ----    2.804 μs (11 allocs: 4.281 KiB)       3.000 μs (11 allocs: 4.281 KiB)        4.575 μs (103 allocs: 14.266 KiB)
grad (RD) ----    44.167 μs (773 allocs: 27.438 KiB)    45.333 μs (753 allocs: 26.812 KiB)     59.792 μs (1076 allocs: 38.828 KiB)
grad (MC) ----    2.080 μs (28 allocs: 1.094 KiB)       2.048 μs (28 allocs: 1.094 KiB)        6.854 μs (160 allocs: 7.000 KiB)
grad (EN) ----    1.649 μs (5 allocs: 2.188 KiB)        1.102 μs (5 allocs: 1.578 KiB)         3.264 μs (64 allocs: 6.141 KiB)

# submodel        before this PR                        after this PR                          v0.38.9
eval      ----    105.855 ns                            112.548 ns                             763.889 ns (20 allocs: 1.234 KiB)
grad (FD) ----    210.734 ns (3 allocs: 112 bytes)      214.715 ns (3 allocs: 112 bytes)       1.083 μs (27 allocs: 2.219 KiB)
grad (RD) ----    10.104 μs (148 allocs: 5.188 KiB)     9.750 μs (132 allocs: 4.641 KiB)       13.792 μs (221 allocs: 9.266 KiB)
grad (MC) ----    537.873 ns (6 allocs: 240 bytes)      540.094 ns (6 allocs: 240 bytes)       5.492 μs (72 allocs: 3.312 KiB)
grad (EN) ----    342.058 ns (2 allocs: 80 bytes)       229.326 ns (2 allocs: 80 bytes)        2.375 μs (52 allocs: 2.500 KiB)

The 'problem' with these benchmarks is that those models didn't trigger this type stability issue. For a model where the type instability actually kicks in (demo3 here is DynamicPPL.TestUtils.DEMO_MODELS[3], see definition here), this PR makes a huge difference.

# demo3           before this PR                        after this PR                          v0.38.9
eval      ----    616.848 ns (24 allocs: 1.344 KiB)     241.597 ns (8 allocs: 352 bytes)       713.667 ns (20 allocs: 1008 bytes)
grad (FD) ----    809.028 ns (30 allocs: 2.078 KiB)     472.000 ns (13 allocs: 752 bytes)      948.903 ns (29 allocs: 2.078 KiB)
grad (RD) ----    15.250 μs (252 allocs: 9.797 KiB)     15.959 μs (230 allocs: 8.578 KiB)      16.500 μs (288 allocs: 11.891 KiB)
grad (MC) ----    12.542 μs (158 allocs: 8.469 KiB)     1.431 μs (20 allocs: 928 bytes)        3.312 μs (57 allocs: 3.047 KiB)
grad (EN) ----    errors                                1.215 μs (23 allocs: 1.078 KiB)        1.833 μs (45 allocs: 2.062 KiB)

Benchmarks (linked)

Here are the same benchmarks but run with linked parameters instead. This is arguably the more important case because HMC/NUTS use this.

# trivial        before this PR                      after this PR                         v0.38.9
eval      ----   36.364 ns (1 allocs: 32 bytes)      14.664 ns (1 allocs: 32 bytes)        163.035 ns (7 allocs: 224 bytes)    
grad (FD) ----   69.378 ns (4 allocs: 144 bytes)     43.595 ns (4 allocs: 144 bytes)       324.728 ns (14 allocs: 544 bytes)   
grad (RD) ----   3.194 μs (53 allocs: 1.812 KiB)     2.718 μs (52 allocs: 1.781 KiB)       4.042 μs (82 allocs: 3.125 KiB)     
grad (MC) ----   310.136 ns (6 allocs: 256 bytes)    319.596 ns (6 allocs: 256 bytes)      1.194 μs (27 allocs: 1.281 KiB)     
grad (EN) ----   171.965 ns (5 allocs: 160 bytes)    172.721 ns (6 allocs: 208 bytes)      483.607 ns (20 allocs: 688 bytes)   
                                                                                                                                 
# eight-schools  before this PR                      after this PR                         v0.38.9
eval      ----   276.206 ns (7 allocs: 352 bytes)    241.115 ns (7 allocs: 352 bytes)      760.417 ns (22 allocs: 1.094 KiB)   
grad (FD) ----   965.333 ns (13 allocs: 2.812 KiB)   886.719 ns (13 allocs: 2.812 KiB)     1.476 μs (28 allocs: 4.828 KiB)     
grad (RD) ----   40.750 μs (595 allocs: 21.734 KiB)  38.250 μs (593 allocs: 21.641 KiB)    43.000 μs (639 allocs: 26.359 KiB)  
grad (MC) ----   1.460 μs (18 allocs: 976 bytes)     1.511 μs (18 allocs: 976 bytes)       4.910 μs (68 allocs: 3.859 KiB)     
grad (EN) ----   991.379 ns (31 allocs: 1.375 KiB)   998.600 ns (33 allocs: 1.469 KiB)     1.942 μs (59 allocs: 2.797 KiB)     
                                                                                                                                 
# badvarnames    before this PR                      after this PR                         v0.38.9
eval      ----   608.521 ns (22 allocs: 864 bytes)   611.104 ns (22 allocs: 864 bytes)     1.635 μs (66 allocs: 2.531 KiB)     
grad (FD) ----   3.922 μs (51 allocs: 8.656 KiB)     3.530 μs (51 allocs: 8.656 KiB)       5.475 μs (143 allocs: 18.641 KiB)   
grad (RD) ----   52.167 μs (913 allocs: 32.438 KiB)  47.541 μs (893 allocs: 31.812 KiB)    57.167 μs (1076 allocs: 40.078 KiB) 
grad (MC) ----   2.983 μs (68 allocs: 2.344 KiB)     2.842 μs (68 allocs: 2.344 KiB)       8.319 μs (200 allocs: 8.250 KiB)    
grad (EN) ----   2.515 μs (65 allocs: 4.062 KiB)     2.633 μs (85 allocs: 4.469 KiB)       3.589 μs (144 allocs: 8.641 KiB)    
                                                                                                                                 
# submodel       before this PR                      after this PR                         v0.38.9
eval      ----   134.341 ns (3 allocs: 96 bytes)     124.658 ns (3 allocs: 96 bytes)       613.000 ns (19 allocs: 848 bytes)   
grad (FD) ----   260.697 ns (6 allocs: 304 bytes)    229.321 ns (6 allocs: 304 bytes)      920.966 ns (26 allocs: 1.594 KiB)   
grad (RD) ----   12.459 μs (181 allocs: 6.328 KiB)   10.667 μs (165 allocs: 5.781 KiB)     16.375 μs (235 allocs: 9.547 KiB)   
grad (MC) ----   736.450 ns (12 allocs: 432 bytes)   729.158 ns (12 allocs: 432 bytes)     5.667 μs (74 allocs: 3.000 KiB)     
grad (EN) ----   612.854 ns (23 allocs: 816 bytes)   661.111 ns (25 allocs: 928 bytes)     2.420 μs (70 allocs: 2.750 KiB)     
                                                                                                                                 
# demo3          before this PR                      after this PR                         v0.38.9
eval      ----   758.763 ns (27 allocs: 1.281 KiB)   306.122 ns (12 allocs: 528 bytes)     835.794 ns (23 allocs: 1.031 KiB)  
grad (FD) ----   914.063 ns (33 allocs: 2.172 KiB)   555.288 ns (17 allocs: 1.094 KiB)     1.141 μs (32 allocs: 2.219 KiB)    
grad (RD) ----   16.917 μs (269 allocs: 10.047 KiB)  15.917 μs (253 allocs: 9.297 KiB)     20.292 μs (315 allocs: 12.641 KiB) 
grad (MC) ----   12.750 μs (169 allocs: 8.344 KiB)   1.839 μs (28 allocs: 1.250 KiB)       3.810 μs (64 allocs: 3.266 KiB)    
grad (EN) ----   errors                              1.378 μs (32 allocs: 1.406 KiB)       1.922 μs (52 allocs: 2.281 KiB)   

Benchmark code

using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra
using ADTypes, ForwardDiff, ReverseDiff, Mooncake, Enzyme

const adtypes = [
    ("FD", AutoForwardDiff()),
    ("RD", AutoReverseDiff()),
    ("MC", AutoMooncake()),
    ("EN", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const))
]

function benchmark_ldf(model; skip=Union{})
    vi = VarInfo(model)
    vi = DynamicPPL.link!!(vi, model) # comment out to use unlinked
    x = vi[:]
    ldf_no = DynamicPPL.LogDensityFunction(model, getlogjoint, vi)
    m = median(@be LogDensityProblems.logdensity(ldf_no, x))
    print("eval      ----  ")
    display(m)
    for name_adtype in adtypes
        name, adtype = name_adtype
        adtype isa skip && continue
        ldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        m = median(@be LogDensityProblems.logdensity_and_gradient(ldf, x))
        print("grad ($name) ----  ")
        display(m)
    end
end

@model f() = x ~ Normal()
benchmark_ldf(f())

y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
benchmark_ldf(eight_schools(y, sigma))

@model function badvarnames()
    N = 20
    x = Vector{Float64}(undef, N)
    for i in 1:N
        x[i] ~ Normal()
    end
end
benchmark_ldf(badvarnames())

@model function inner()
    m ~ Normal(0, 1)
    s ~ Exponential()
    return (m=m, s=s)
end
@model function withsubmodel()
    params ~ to_submodel(inner())
    y ~ Normal(params.m, params.s)
    1.0 ~ Normal(y)
end
benchmark_ldf(withsubmodel())

benchmark_ldf(DynamicPPL.TestUtils.DEMO_MODELS[3])

github-actions bot commented Nov 15, 2025

Benchmark Report

  • this PR's head: fa0022b9ebfefeaf7ae1b413b1bc8863f065fdde
  • base branch: 6849bc2f6e3826072d6d3a6d0888eddcb48a25dd

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────────────┬────────┬────────────────────────────────┬────────────────────────────┐
│                       │       │             │                │        │        t(eval) / t(ref)        │     t(grad) / t(eval)      │
│                       │       │             │                │        │ ──────────┬──────────┬──────── │ ───────┬─────────┬──────── │
│                 Model │   Dim │  AD Backend │        VarInfo │ Linked │      base │  this PR │ speedup │   base │ this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────────────┼────────┼───────────┼──────────┼─────────┼────────┼─────────┼─────────┤
│               Dynamic │    10 │    mooncake │          typed │   true │    349.42 │   349.29 │    1.00 │  13.61 │    9.52 │    1.43 │
│                   LDA │    12 │ reversediff │          typed │   true │   2991.85 │  2513.10 │    1.19 │   4.41 │    5.23 │    0.84 │
│   Loop univariate 10k │ 10000 │    mooncake │          typed │   true │ 108952.13 │ 97435.23 │    1.12 │   4.01 │    3.91 │    1.03 │
│    Loop univariate 1k │  1000 │    mooncake │          typed │   true │   8549.77 │  7597.47 │    1.13 │   4.45 │    5.13 │    0.87 │
│      Multivariate 10k │ 10000 │    mooncake │          typed │   true │  73912.24 │ 29324.50 │    2.52 │   6.64 │   10.67 │    0.62 │
│       Multivariate 1k │  1000 │    mooncake │          typed │   true │   3460.81 │  4404.13 │    0.79 │   9.36 │    7.47 │    1.25 │
│ Simple assume observe │     1 │ forwarddiff │          typed │  false │      3.34 │     3.57 │    0.93 │   2.91 │    2.97 │    0.98 │
│           Smorgasbord │   201 │      enzyme │          typed │   true │       err │  1600.41 │     err │    err │    4.71 │     err │
│           Smorgasbord │   201 │ forwarddiff │          typed │  false │   1232.60 │  1163.37 │    1.06 │ 117.21 │  121.36 │    0.97 │
│           Smorgasbord │   201 │ forwarddiff │   typed_vector │   true │   1663.01 │  1585.38 │    1.05 │  58.10 │   54.14 │    1.07 │
│           Smorgasbord │   201 │ forwarddiff │        untyped │   true │   1676.47 │  1583.89 │    1.06 │  58.80 │   57.53 │    1.02 │
│           Smorgasbord │   201 │ forwarddiff │ untyped_vector │   true │   1660.76 │  1600.41 │    1.04 │  59.12 │   54.82 │    1.08 │
│           Smorgasbord │   201 │    mooncake │          typed │   true │   2057.98 │  1583.96 │    1.30 │   4.72 │    5.43 │    0.87 │
│           Smorgasbord │   201 │ reversediff │          typed │   true │   1693.00 │  1602.66 │    1.06 │  84.66 │   88.27 │    0.96 │
│              Submodel │     1 │    mooncake │          typed │   true │      7.50 │     8.05 │    0.93 │   5.73 │    4.52 │    1.27 │
└───────────────────────┴───────┴─────────────┴────────────────┴────────┴───────────┴──────────┴─────────┴────────┴─────────┴─────────┘


codecov bot commented Nov 15, 2025

Codecov Report

❌ Patch coverage is 94.44444% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.05%. Comparing base (6849bc2) to head (fa0022b).

Files with missing lines Patch % Lines
src/contexts/init.jl 91.66% 1 Missing ⚠️
src/fasteval.jl 95.83% 1 Missing ⚠️
Additional details and impacted files
@@                   Coverage Diff                   @@
##           py/not-experimental    #1141      +/-   ##
=======================================================
- Coverage                77.25%   77.05%   -0.20%     
=======================================================
  Files                       40       40              
  Lines                     3706     3731      +25     
=======================================================
+ Hits                      2863     2875      +12     
- Misses                     843      856      +13     

☔ View full report in Codecov by Sentry.

@github-actions

DynamicPPL.jl documentation for PR #1141 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1141/

penelopeysm marked this pull request as draft on November 15, 2025 at 20:31
penelopeysm force-pushed the py/not-experimental branch 2 times, most recently from 177656b to 9310ec0 on November 15, 2025 at 20:44
@mhauru

mhauru commented Nov 17, 2025

Would something bad happen if we just wrapped all the arrays returned by Bijectors in trivial SubArrays? I did some very crude benchmarks locally and at least getindex and setindex! seem to have essentially zero overhead from the wrap.
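For reference, a sketch of that suggestion (a hypothetical helper, not existing DynamicPPL/Bijectors code): wrapping via `view(a, 1:length(a))` yields the same concrete SubArray type as the `@view x[1:2]` slices above, so both branches could share one return type.

```julia
# Hypothetical wrapper: turn any Vector into a trivial view of itself.
trivial_view(a::AbstractVector) = view(a, 1:length(a))

a = [0.5, 0.5]
w = trivial_view(a)
w == a                              # same elements
typeof(w) == typeof(@view a[1:2])   # same concrete SubArray type
w[1] = 1.0                          # setindex! forwards to the parent array
```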

penelopeysm force-pushed the py/not-experimental branch 3 times, most recently from 992cea9 to 759bf8a on November 18, 2025 at 11:07