Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ jobs:
test_group: [
{test_type: 'ext', label: 'differentiation_interface'},
{test_type: 'ext', label: 'dynamic_expressions'},
{test_type: 'ext', label: 'dynamicppl'},
{test_type: 'ext', label: 'flux'},
{test_type: 'ext', label: 'luxlib'},
{test_type: 'ext', label: 'nnlib'},
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Expand All @@ -34,6 +35,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
MooncakeAllocCheckExt = "AllocCheck"
MooncakeCUDAExt = "CUDA"
MooncakeDynamicExpressionsExt = "DynamicExpressions"
MooncakeDynamicPPLExt = "DynamicPPL"
MooncakeFluxExt = "Flux"
MooncakeJETExt = "JET"
MooncakeLuxLibExt = "LuxLib"
Expand All @@ -53,6 +55,7 @@ DiffRules = "1"
DiffTests = "0.1"
DispatchDoctor = "0.4.26"
DynamicExpressions = "2"
DynamicPPL = "0.36"
ExprTools = "0.1"
Flux = "0.16.3"
FunctionWrappers = "1.1.3"
Expand Down
104 changes: 104 additions & 0 deletions ext/MooncakeDynamicPPLExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module MooncakeDynamicPPLExt

__precompile__(false)

using DynamicPPL
using Mooncake
import Mooncake: set_to_zero!!
using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!!

"""
Check if a tangent corresponds to a DynamicPPL.LogDensityFunction
"""
function is_dppl_ldf_tangent(x)
x isa Tangent || return false
hasfield(typeof(x), :fields) || return false

fields = x.fields
propertynames(fields) == (:model, :varinfo, :context, :adtype, :prep) || return false

is_dppl_varinfo_tangent(fields.varinfo) || return false
is_dppl_model_tangent(fields.model) || return false

return true
end

"""
Check if a tangent corresponds to a DynamicPPL.VarInfo
"""
function is_dppl_varinfo_tangent(x)
x isa Tangent || return false
hasfield(typeof(x), :fields) || return false

fields = x.fields
propertynames(fields) == (:metadata, :logp, :num_produce) || return false

# Additional validation could be added here
return true
end

"""
Check if a tangent corresponds to a DynamicPPL.Model
"""
function is_dppl_model_tangent(x)
x isa Tangent || return false
hasfield(typeof(x), :fields) || return false

fields = x.fields
all(f in propertynames(fields) for f in (:f, :args, :defaults, :context)) ||
return false

return true
end

"""
Check if a MutableTangent corresponds to DynamicPPL.Metadata
"""
function is_dppl_metadata_tangent(x)
x isa MutableTangent || return false
hasfield(typeof(x), :fields) || return false

Check warning on line 59 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L57-L59

Added lines #L57 - L59 were not covered by tests

fields = x.fields

Check warning on line 61 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L61

Added line #L61 was not covered by tests
# Check for the expected fields in Metadata
expected_fields = (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
all(f in propertynames(fields) for f in expected_fields) || return false

Check warning on line 64 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L63-L64

Added lines #L63 - L64 were not covered by tests

return true

Check warning on line 66 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L66

Added line #L66 was not covered by tests
end

function Mooncake.set_to_zero!!(x)
if is_dppl_ldf_tangent(x)
model_f_tangent = x.fields.model.fields.f
is_closure = false
if model_f_tangent isa MutableTangent
is_closure = true

Check warning on line 74 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L74

Added line #L74 was not covered by tests
elseif model_f_tangent isa Tangent && hasfield(typeof(model_f_tangent), :fields)
# Check if any field is a MutableTangent with PossiblyUninitTangent{Any}
for (_, fval) in pairs(model_f_tangent.fields)
if fval isa MutableTangent &&
hasfield(typeof(fval), :fields) &&
hasfield(typeof(fval.fields), :contents) &&
fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any}
is_closure = true
break
end
end

Check warning on line 85 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L85

Added line #L85 was not covered by tests
end

if is_closure
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
else
return set_to_zero_internal!!(NoCache(), x)
end
elseif x isa Tangent && (is_dppl_varinfo_tangent(x) || is_dppl_model_tangent(x))
# VarInfo and Model
return set_to_zero_internal!!(NoCache(), x)

Check warning on line 95 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L95

Added line #L95 was not covered by tests
elseif x isa MutableTangent && is_dppl_metadata_tangent(x)
return set_to_zero_internal!!(NoCache(), x)

Check warning on line 97 in ext/MooncakeDynamicPPLExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeDynamicPPLExt.jl#L97

Added line #L97 was not covered by tests
else
# Use the original implementation with IdDict for all other types
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
end
end

end # module
7 changes: 7 additions & 0 deletions test/ext/dynamicppl/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
205 changes: 205 additions & 0 deletions test/ext/dynamicppl/dynamicppl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))

using BenchmarkTools, Distributions, DynamicPPL, Mooncake, Random, Test
using Mooncake: NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent

# Define models globally to avoid closure issues
@model function test_model1(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
return x .~ Normal(m, sqrt(s))
end

@model function test_model2(x, y)
τ ~ Gamma(1, 1)
σ ~ InverseGamma(2, 3)
μ ~ Normal(0, τ)
x .~ Normal(μ, σ)
return y .~ Normal(μ, σ)
end

@testset "MooncakeDynamicPPLExt" begin
@testset "Validation functions" begin
# Test with a real DynamicPPL model
model = test_model1([1.0, 2.0, 3.0])
vi = DynamicPPL.VarInfo(Random.default_rng(), model)
ldf = DynamicPPL.LogDensityFunction(model, vi, DynamicPPL.DefaultContext())
tangent = zero_tangent(ldf)

# Since we can't access extension functions directly,
# test the behavior indirectly through set_to_zero!!
# If the optimization is working, set_to_zero!! should handle DynamicPPL types efficiently
result = set_to_zero!!(deepcopy(tangent))
@test result isa typeof(tangent)

# Test with metadata - verify structure exists
if hasfield(typeof(tangent.fields.varinfo.fields), :metadata)
metadata = tangent.fields.varinfo.fields.metadata
@test !isnothing(metadata)
end

# Test that non-DPPL tangents still work with set_to_zero!!
dummy_tangent = Mooncake.Tangent(
NamedTuple{(:model, :varinfo, :context, :adtype, :prep)}((
1.0, 2.0, Mooncake.NoTangent(), Mooncake.NoTangent(), Mooncake.NoTangent()
)),
)
# This should use the fallback implementation
result2 = set_to_zero!!(deepcopy(dummy_tangent))
@test result2 isa typeof(dummy_tangent)
end

@testset "NoCache optimization correctness" begin
# Test that set_to_zero!! uses NoCache for DynamicPPL types
model = test_model1([1.0, 2.0, 3.0])
vi = DynamicPPL.VarInfo(Random.default_rng(), model)
ldf = DynamicPPL.LogDensityFunction(model, vi, DynamicPPL.DefaultContext())
tangent = zero_tangent(ldf)

# Modify some values
if hasfield(typeof(tangent.fields.model.fields), :args) &&
hasfield(typeof(tangent.fields.model.fields.args), :x)
x_tangent = tangent.fields.model.fields.args.x
if !isempty(x_tangent)
x_tangent[1] = 5.0
end
end

# Call set_to_zero!! and verify it works
set_to_zero!!(tangent)

# Check that values are zeroed
if hasfield(typeof(tangent.fields.model.fields), :args) &&
hasfield(typeof(tangent.fields.model.fields.args), :x)
x_tangent = tangent.fields.model.fields.args.x
if !isempty(x_tangent)
@test x_tangent[1] == 0.0
end
end
end

@testset "Performance improvement" begin
# Test with DEMO_MODELS if available
if isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) &&
!isempty(DynamicPPL.TestUtils.DEMO_MODELS)
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
else
# Fallback to our test model
model = test_model1([1.0, 2.0, 3.0, 4.0])
end

vi = DynamicPPL.VarInfo(Random.default_rng(), model)
ldf = DynamicPPL.LogDensityFunction(model, vi, DynamicPPL.DefaultContext())
tangent = zero_tangent(ldf)

# Run benchmarks
result_iddict = @benchmark begin
cache = IdDict{Any,Bool}()
set_to_zero_internal!!(cache, $tangent)
end

result_nocache = @benchmark set_to_zero!!($tangent)

# Extract median times
time_iddict = median(result_iddict).time
time_nocache = median(result_nocache).time

# We expect NoCache to be faster
speedup = time_iddict / time_nocache
@test speedup > 1.5 # Conservative expectation - should be ~4x

println("Performance improvement: $(round(speedup, digits=2))x speedup")
println("IdDict: $(round(time_iddict/1000, digits=2)) μs")
println("NoCache: $(round(time_nocache/1000, digits=2)) μs")
end

@testset "Aliasing safety" begin
# Test with aliased data
shared_data = [1.0, 2.0, 3.0]
model = test_model2(shared_data, shared_data) # x and y are the same array
vi = DynamicPPL.VarInfo(Random.default_rng(), model)
ldf = DynamicPPL.LogDensityFunction(model, vi, DynamicPPL.DefaultContext())
tangent = zero_tangent(ldf)

# Check that aliasing is preserved in tangent
if hasfield(typeof(tangent.fields.model.fields), :args)
args = tangent.fields.model.fields.args
if hasfield(typeof(args), :x) && hasfield(typeof(args), :y)
@test args.x === args.y # Aliasing should be preserved

# Modify via x
if !isempty(args.x)
args.x[1] = 10.0
@test args.y[1] == 10.0 # Should also change y
end

# Zero and check both are zeroed
# Since x and y are aliased, zeroing one zeros both
set_to_zero!!(tangent)
if !isempty(args.x)
@test args.x[1] == 0.0
@test args.y[1] == 0.0
end
end
end
end

@testset "Closure handling" begin
# Test that closure models are correctly handled

# Create closure model (captures environment, has circular references)
function create_closure_model()
local_var = 42
@model function closure_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
return x .~ Normal(m, sqrt(s))
end
return closure_model
end

closure_fn = create_closure_model()
model_closure = closure_fn([1.0, 2.0, 3.0])
vi_closure = DynamicPPL.VarInfo(Random.default_rng(), model_closure)
ldf_closure = DynamicPPL.LogDensityFunction(
model_closure, vi_closure, DynamicPPL.DefaultContext()
)
tangent_closure = zero_tangent(ldf_closure)

# Test that it works without stack overflow
@test_nowarn set_to_zero!!(deepcopy(tangent_closure))

# Compare with global model (no closure)
model_global = test_model1([1.0, 2.0, 3.0])
vi_global = DynamicPPL.VarInfo(Random.default_rng(), model_global)
ldf_global = DynamicPPL.LogDensityFunction(
model_global, vi_global, DynamicPPL.DefaultContext()
)
tangent_global = zero_tangent(ldf_global)

# Verify model.f tangent types differ
f_tangent_closure = tangent_closure.fields.model.fields.f
f_tangent_global = tangent_global.fields.model.fields.f

@test f_tangent_global isa Mooncake.NoTangent # Global function
@test f_tangent_closure isa Mooncake.Tangent # Closure function

# Performance comparison
time_global = @elapsed for _ in 1:100
set_to_zero!!(tangent_global)
end

time_closure = @elapsed for _ in 1:100
set_to_zero!!(tangent_closure)
end

# Global should be faster (uses NoCache)
@test time_global < time_closure

println(
"Closure handling: Global $(round(time_global*1000, digits=2))ms vs Closure $(round(time_closure*1000, digits=2))ms",
)
end
end
Loading