Collaborator

@sunxd3 sunxd3 commented Jul 4, 2025

ref #552

@sunxd3
Collaborator Author

sunxd3 commented Jul 4, 2025

Results with the ADTests models:

| Model | Current PR | Release | Ratio (C/R) |
|---|---|---|---|
| assume_beta | 4.486810 | 6.528063 | 0.687311 |
| assume_dirichlet | 5.028416 | 6.503771 | 0.773154 |
| assume_lkjcholu | 4.169331 | 4.713440 | 0.884562 |
| assume_mvnormal | 7.584657 | 9.605529 | 0.789614 |
| assume_normal | 4.587607 | 7.555556 | 0.607183 |
| assume_submodel | 3.035702 | 3.936000 | 0.771266 |
| assume_wishart | 7.341956 | 9.142857 | 0.803026 |
| broadcast_macro | 3.552181 | 6.819299 | 0.520901 |
| control_flow | 3.767179 | 7.860805 | 0.479236 |
| demo_assume_dot_observe | 3.822346 | 5.625724 | 0.679441 |
| demo_assume_dot_observe_literal | 3.110979 | 6.891854 | 0.451399 |
| demo_assume_index_observe | 3.540136 | 5.826788 | 0.607562 |
| demo_assume_matrix_observe_matrix_index | 4.456230 | 5.599493 | 0.795827 |
| demo_assume_multivariate_observe | 4.135187 | 7.349003 | 0.562687 |
| demo_assume_multivariate_observe_literal | 4.092533 | 6.768094 | 0.604680 |
| demo_assume_observe_literal | 3.496984 | 6.271389 | 0.557609 |
| demo_assume_submodel_observe_index_literal | 3.079295 | 4.747754 | 0.648579 |
| demo_dot_assume_observe | 4.045807 | 5.916732 | 0.683791 |
| demo_dot_assume_observe_index | 3.737512 | 6.001495 | 0.622764 |
| demo_dot_assume_observe_index_literal | 3.563587 | 6.228103 | 0.572179 |
| demo_dot_assume_observe_matrix_index | 3.861691 | 6.480148 | 0.595926 |
| demo_dot_assume_observe_submodel | 3.258644 | 5.056410 | 0.644458 |
| dot_assume | 3.822570 | 5.309253 | 0.719983 |
| dot_observe | 4.396199 | 8.126444 | 0.540974 |
| dynamic_constraint | 3.199087 | 5.866546 | 0.545310 |
| multiple_constraints_same_var | 7.969536 | 8.591111 | 0.927649 |
| multithreaded | 3.893714 | 4.530427 | 0.859459 |
| n010 | 3.605450 | 3.854293 | 0.935437 |
| n050 | 3.498392 | 3.673095 | 0.952437 |
| n100 | 3.509174 | 3.450000 | 1.017152 |
| n500 | 4.415281 | 7.765084 | 0.568607 |
| observe_bernoulli | 5.329560 | 7.040417 | 0.756995 |
| observe_categorical | 5.303571 | 9.444518 | 0.561550 |
| observe_index | 4.329834 | 7.583224 | 0.570975 |
| observe_literal | 4.903116 | 6.428380 | 0.762730 |
| observe_multivariate | 4.058949 | 8.550705 | 0.474692 |
| observe_submodel | 3.884157 | 5.014414 | 0.774598 |
| pdb_eight_schools_centered | 3.702677 | 5.093955 | 0.726877 |
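
For readers checking the last column: it is the "Current PR" value divided by the "Release" value. The snippet below is only a sketch that recomputes the ratio for two rows copied from the table. The helper name `ratio` is made up for illustration. The comment does not state the units, so reading a ratio below 1 as "the PR is faster" assumes the two columns are costs where lower is better.

```julia
# Two rows copied from the table above: (model, current PR, release).
rows = [
    ("assume_beta", 4.486810, 6.528063),
    ("n100", 3.509174, 3.450000),
]

# Hypothetical helper: Ratio (C/R) is the current-PR value over the release value.
ratio(current, release) = current / release

# Map each model name to its recomputed ratio.
ratios = Dict(name => ratio(c, r) for (name, c, r) in rows)

for (name, c, r) in rows
    println(rpad(name, 12), round(ratio(c, r); digits=6))
end
```

Of the rows listed, only `n100` has a ratio above 1.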

codecov bot commented Jul 4, 2025

Codecov Report

Attention: Patch coverage is 76.59574% with 11 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ext/MooncakeDynamicPPLExt.jl | 76.59% | 11 Missing ⚠️ |

Contributor

github-actions bot commented Jul 4, 2025

Mooncake.jl documentation for PR #644 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR644/

Contributor

github-actions bot commented Jul 4, 2025

Performance Ratio:
Ratio of the time to compute the gradient to the time to compute the function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │   Primal │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │ 100.0 ns │      1.8 │     1.1 │        5.51 │    8.21 │
│                  _sum_1000 │ 941.0 ns │     6.68 │  1720.0 │        33.6 │    1.09 │
│               sum_sin_1000 │  6.55 μs │     2.24 │    1.72 │        10.6 │    2.24 │
│              _sum_sin_1000 │  5.33 μs │     2.63 │   281.0 │        13.1 │    2.42 │
│                   kron_sum │ 352.0 μs │     37.7 │    4.45 │       183.0 │    10.8 │
│              kron_view_sum │ 344.0 μs │     40.6 │    10.4 │       192.0 │     6.5 │
│      naive_map_sin_cos_exp │  2.14 μs │     2.15 │ missing │        7.23 │    2.34 │
│            map_sin_cos_exp │  2.11 μs │     2.43 │     1.5 │        6.18 │    2.88 │
│      broadcast_sin_cos_exp │  2.27 μs │     2.25 │    2.33 │        1.46 │    2.26 │
│                 simple_mlp │ 419.0 μs │      4.6 │     1.6 │        6.95 │    3.25 │
│                     gp_lml │ 566.0 μs │     4.46 │    2.28 │     missing │    2.71 │
│ turing_broadcast_benchmark │  1.98 ms │     3.58 │ missing │        27.7 │ missing │
│         large_single_block │ 380.0 ns │     4.51 │  4350.0 │        31.9 │    2.24 │
└────────────────────────────┴──────────┴──────────┴─────────┴─────────────┴─────────┘
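
To make the definition of the ratio concrete, here is a rough sketch of how such a number could be measured by hand. It is not the benchmark script behind the table above: the function `f`, the hand-written gradient `grad_f`, and the helper `best_time` are all assumptions made for illustration, and no AD package is involved.

```julia
# Primal function and a hand-written gradient (stand-ins for an AD backend).
f(x) = sum(sin, x)
grad_f(x) = cos.(x)

# Hypothetical helper: smallest wall-clock time over `reps` runs, after a warm-up call.
function best_time(g, x; reps=100)
    g(x)  # warm-up so compilation is not timed
    return minimum(@elapsed(g(x)) for _ in 1:reps)
end

x = rand(10_000)
t_primal = best_time(f, x)
t_grad = best_time(grad_f, x)

# Performance ratio as defined above: gradient time over function time.
perf_ratio = t_grad / t_primal
println("primal: ", t_primal, " s, gradient: ", t_grad, " s, ratio: ", perf_ratio)
```

As the bot comment warns, numbers obtained this way are noisy and vary from run to run.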

@sunxd3
Collaborator Author

sunxd3 commented Jul 4, 2025

As Penny pointed out in TuringLang/DynamicPPL.jl#856, when a model is declared inside a function, it becomes a closure.

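julia> using DynamicPPL, Distributions, Random; using Mooncake: zero_tangent  # imports assumed; not shown in the original session

julia> function create_closure()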
           @model function closure_model(x)
               s ~ InverseGamma(2, 3)
               return x ~ Normal(0, s)
           end
           return closure_model
       end
create_closure (generic function with 1 method)

julia> closure_fn = create_closure()
(::var"#closure_model#17") (generic function with 2 methods)

julia> model = closure_fn([1.0, 2.0])
Model{var"#closure_model#17", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext}(var"#closure_model#17"(Core.Box(var"#closure_model#17"(#= circular reference @-2 =#))), (x = [1.0, 2.0],), NamedTuple(), DefaultContext())

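julia> vi = DynamicPPL.VarInfo(Random.default_rng(), model);

julia> ldf = DynamicPPL.LogDensityFunction(model, vi, DynamicPPL.DefaultContext());

julia> tangent = zero_tangent(ldf);

julia> model_f_tangent = tangent.fields.model.fields.f;  # tangent of the model function; this definition is inferred, not shown in the original session

julia> tangent.fields.model.fields.f.fields.closure_model.fields.contents.tangent === model_f_tangent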
true

This is a special case in which the tangent can refer to itself, and it is the reason for the following check:

elseif model_f_tangent isa Tangent && hasfield(typeof(model_f_tangent), :fields)
    # Check if any field is a MutableTangent with PossiblyUninitTangent{Any}
    for (_, fval) in pairs(model_f_tangent.fields)
        if fval isa MutableTangent &&
            hasfield(typeof(fval), :fields) &&
            hasfield(typeof(fval.fields), :contents) &&
            fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any}
            is_closure = true
            break
        end
    end
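
The same self-reference can be seen on the primal side without DynamicPPL or Mooncake. The sketch below is not the PR's code. `captures_itself`, `make_recursive`, and `make_plain` are hypothetical names. The assumption is that a local function that calls itself is captured through a `Core.Box`, which matches the `Core.Box(... circular reference ...)` printout above but may depend on the Julia version.

```julia
# Hypothetical helper: true when some field of `f` is a `Core.Box`
# whose contents is `f` itself (the primal analogue of the tangent check).
function captures_itself(f)
    for name in fieldnames(typeof(f))
        val = getfield(f, name)
        if val isa Core.Box && isdefined(val, :contents) && val.contents === f
            return true
        end
    end
    return false
end

# A local function that refers to itself (assumed to be boxed, see above).
function make_recursive()
    countdown(n) = n <= 0 ? 0 : countdown(n - 1)
    return countdown
end

# An ordinary closure that captures only a value.
make_plain(offset) = x -> x + offset

println("recursive closure captures itself: ", captures_itself(make_recursive()))
println("plain closure captures itself:     ", captures_itself(make_plain(1)))
```

Checking the primal this way needs the runtime value, which is the same limitation discussed for the tangent-based check.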

@sunxd3 sunxd3 requested review from yebai and penelopeysm July 4, 2025 12:44
Member

@yebai yebai left a comment

The code looks sensible. A quick comment on the interface: I suggest we introduce a function

Mooncake._build_iddict(::Type{Tangent}) = IdDict{Any,Bool}()

Then, DynamicPPL can overload this function for various DynamicPPL-related tangent types, e.g. LDF, Model, VarInfo:

Mooncake._build_iddict(::Type{Tangent{VarInfo}}) = NoCache()

EDIT: We could also have a two-argument version to pass runtime information

Mooncake._build_iddict(::Type{Tangent{VarInfo}}, extra) = ...
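
A self-contained sketch of this dispatch pattern follows, using stand-in types so that it runs without Mooncake or DynamicPPL. Everything here is hypothetical: `build_cache` stands in for the proposed `_build_iddict`, `NoCacheSketch` for `NoCache`, and the two tangent structs for `Tangent` and `Tangent{VarInfo}`. The `has_self_reference` flag is one guess at what the `extra` argument could carry.

```julia
# Stand-in for the proposed `NoCache` marker (hypothetical).
struct NoCacheSketch end

# Stand-ins for tangent types (hypothetical).
abstract type AbstractTangentSketch end
struct GenericTangentSketch <: AbstractTangentSketch end
struct VarInfoTangentSketch <: AbstractTangentSketch end

# Default: any tangent type gets an identity-keyed cache.
build_cache(::Type{<:AbstractTangentSketch}) = IdDict{Any,Bool}()

# Overload for a type known to need no cache.
build_cache(::Type{VarInfoTangentSketch}) = NoCacheSketch()

# Two-argument version: runtime information can force the cache back on.
function build_cache(::Type{VarInfoTangentSketch}, has_self_reference::Bool)
    return has_self_reference ? IdDict{Any,Bool}() : NoCacheSketch()
end

println(typeof(build_cache(GenericTangentSketch)))
println(typeof(build_cache(VarInfoTangentSketch)))
println(typeof(build_cache(VarInfoTangentSketch, true)))
```

The one-argument methods can be resolved from the type alone, while the two-argument method is where a runtime check such as closure detection would plug in.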

@sunxd3
Collaborator Author

sunxd3 commented Jul 6, 2025

The interface is fine, but achieving feature parity with the current code is difficult to do in a clean way. The main reason is that closure detection requires runtime information. We can go for a hybrid approach, which somewhat defeats the purpose of introducing an interface function. Or we can opt for cleanliness and leave some performance gains on the table.
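
To illustrate the trade-off, here is a toy traversal written under the assumption that the `IdDict` cache exists to guard against revisiting the same object, which is what a self-referential closure would trigger. `Node`, `zero_cached!`, and `zero_uncached!` are hypothetical names and do not reflect Mooncake's internals.

```julia
# Toy mutable structure that may or may not contain a cycle.
mutable struct Node
    value::Float64
    next::Union{Node,Nothing}
end

# Cached traversal: records visited nodes by identity, so it terminates on cycles.
function zero_cached!(node::Node, seen::IdDict{Any,Bool}=IdDict{Any,Bool}())
    haskey(seen, node) && return node
    seen[node] = true
    node.value = 0.0
    node.next === nothing || zero_cached!(node.next, seen)
    return node
end

# Uncached traversal: no dictionary work, but only valid on acyclic data.
function zero_uncached!(node::Node)
    node.value = 0.0
    node.next === nothing || zero_uncached!(node.next)
    return node
end

# Acyclic chain: skipping the cache is safe.
chain = Node(1.0, Node(2.0, nothing))
zero_uncached!(chain)

# Self-referential node: only the cached version is safe here.
cycle = Node(3.0, nothing)
cycle.next = cycle
zero_cached!(cycle)
```

Choosing between the two needs to know whether a cycle is present, and for closures that is only visible at runtime.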

@yebai
Member

yebai commented Jul 6, 2025

> Or we can opt for cleanliness and leave some performance gains on the table.

Yes, that is okay if we do it robustly, excluding invalid cases like closures.

@penelopeysm
Collaborator

I haven't yet looked at the code, but I wanted to point out that if Mooncake has a DynamicPPLExt (and thus a compat entry for DynamicPPL), then breaking changes of DynamicPPL cannot be tested with Mooncake.

That isn't a massive problem in and of itself; we could separate the Mooncake part out from the test suite (in exactly the same way Enzyme is).

However, because breaking changes in DPPL are much more common than breaking changes in Mooncake, I think it may be easier to implement this code inside DynamicPPLMooncakeExt rather than the other way around.

(see: TuringLang/DynamicPPL.jl#740 where this issue was discussed)

Collaborator

@penelopeysm penelopeysm left a comment

I had a look at the code and I'm generally happy. However, I think it would make sense to move this to DynamicPPL so that we can also test this with the existing Mooncake tests there. If we do that then I can review it there in more detail if needed?

sunxd3 and others added 3 commits July 8, 2025 07:55
Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
Signed-off-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
Signed-off-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
@yebai
Member

yebai commented Jul 9, 2025

Closed in favour of TuringLang/DynamicPPL.jl#975

@yebai yebai closed this Jul 9, 2025
@yebai yebai deleted the sunxd/dynamicppl_perf_improve branch July 9, 2025 15:48