Skip to content

Commit 0fcca13

Browse files
committed
Add a couple of tests being removed from Turing.jl
1 parent 5c89efc commit 0fcca13

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

test/ad.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,73 @@ using DynamicPPL: LogDensityFunction
103103
)
104104
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
105105
end
106+
107+
# Test that various different ways of specifying array types as arguments work with all
108+
# ADTypes.
109+
@testset "Array argument types" begin
110+
reference_adtype = AutoForwardDiff()
111+
test_m = randn(2, 3)
112+
113+
function eval_logp_and_grad(model, m, adtype)
114+
model_instance = model()
115+
vi = VarInfo(model_instance)
116+
ldf = LogDensityFunction(model_instance, vi, DefaultContext(); adtype=adtype)
117+
return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
118+
end
119+
120+
@model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
121+
m = Matrix{T}(undef, 2, 3)
122+
return m ~ filldist(MvNormal(zeros(2), I), 3)
123+
end
124+
125+
scalar_matrix_model_reference = eval_logp_and_grad(
126+
scalar_matrix_model, test_m, reference_adtype
127+
)
128+
129+
@model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
130+
m = T(undef, 2, 3)
131+
return m ~ filldist(MvNormal(zeros(2), I), 3)
132+
end
133+
134+
matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, reference_adtype)
135+
136+
@model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
137+
m = Array{T}(undef, 2, 3)
138+
return m ~ filldist(MvNormal(zeros(2), I), 3)
139+
end
140+
141+
scalar_array_model_reference = eval_logp_and_grad(
142+
scalar_array_model, test_m, reference_adtype
143+
)
144+
145+
@model function array_model(::Type{T}=Array{Float64}) where {T}
146+
m = T(undef, 2, 3)
147+
return m ~ filldist(MvNormal(zeros(2), I), 3)
148+
end
149+
150+
array_model_reference = eval_logp_and_grad(array_model, test_m, reference_adtype)
151+
152+
@testset "$adtype" for adtype in [
153+
AutoReverseDiff(; compile=false),
154+
AutoReverseDiff(; compile=true),
155+
AutoMooncake(; config=nothing),
156+
]
157+
scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
158+
scalar_matrix_model, test_m, adtype
159+
)
160+
@test scalar_matrix_model_logp_and_grad[1] scalar_matrix_model_reference[1]
161+
@test scalar_matrix_model_logp_and_grad[2] scalar_matrix_model_reference[2]
162+
matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
163+
@test matrix_model_logp_and_grad[1] matrix_model_reference[1]
164+
@test matrix_model_logp_and_grad[2] matrix_model_reference[2]
165+
scalar_array_model_logp_and_grad = eval_logp_and_grad(
166+
scalar_array_model, test_m, adtype
167+
)
168+
@test scalar_array_model_logp_and_grad[1] scalar_array_model_reference[1]
169+
@test scalar_array_model_logp_and_grad[2] scalar_array_model_reference[2]
170+
array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
171+
@test array_model_logp_and_grad[1] array_model_reference[1]
172+
@test array_model_logp_and_grad[2] array_model_reference[2]
173+
end
174+
end
106175
end

test/compiler.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,20 @@ module Issue537 end
289289
@test all((isassigned(x, i) for i in eachindex(x)))
290290
end
291291

292+
# Test that that using @. to stop unwanted broadcasting on the RHS works.
293+
@testset "@. ~ with interpolation" begin
294+
@model function at_dot_with_interpolation()
295+
x = Vector{Float64}(undef, 2)
296+
# Without the interpolation the RHS would turn into `Normal.(sum.([1.0, 2.0]))`,
297+
# which would crash.
298+
@. x ~ $(Normal(sum([1.0, 2.0])))
299+
end
300+
301+
# The main check is just that calling at_dot_with_interpolation() doesn't crash,
302+
# the check of the keys is not very important.
303+
@show keys(VarInfo(at_dot_with_interpolation())) == [@varname(x[1]), @varname(x[2])]
304+
end
305+
292306
# A couple of uses of .~ that are no longer valid as of v0.35.
293307
@testset "old .~ syntax" begin
294308
@model function multivariate_dot_tilde()

0 commit comments

Comments
 (0)