@@ -103,4 +103,73 @@ using DynamicPPL: LogDensityFunction
103
103
)
104
104
@test LogDensityProblems. logdensity_and_gradient (ldf, vi[:]) isa Any
105
105
end
106
+
107
+ # Test that various different ways of specifying array types as arguments work with all
108
+ # ADTypes.
109
+ @testset " Array argument types" begin
110
+ reference_adtype = AutoForwardDiff ()
111
+ test_m = randn (2 , 3 )
112
+
113
+ function eval_logp_and_grad (model, m, adtype)
114
+ model_instance = model ()
115
+ vi = VarInfo (model_instance)
116
+ ldf = LogDensityFunction (model_instance, vi, DefaultContext (); adtype= adtype)
117
+ return LogDensityProblems. logdensity_and_gradient (ldf, m[:])
118
+ end
119
+
120
+ @model function scalar_matrix_model (:: Type{T} = Float64) where {T<: Real }
121
+ m = Matrix {T} (undef, 2 , 3 )
122
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
123
+ end
124
+
125
+ scalar_matrix_model_reference = eval_logp_and_grad (
126
+ scalar_matrix_model, test_m, reference_adtype
127
+ )
128
+
129
+ @model function matrix_model (:: Type{T} = Matrix{Float64}) where {T}
130
+ m = T (undef, 2 , 3 )
131
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
132
+ end
133
+
134
+ matrix_model_reference = eval_logp_and_grad (matrix_model, test_m, reference_adtype)
135
+
136
+ @model function scalar_array_model (:: Type{T} = Float64) where {T<: Real }
137
+ m = Array {T} (undef, 2 , 3 )
138
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
139
+ end
140
+
141
+ scalar_array_model_reference = eval_logp_and_grad (
142
+ scalar_array_model, test_m, reference_adtype
143
+ )
144
+
145
+ @model function array_model (:: Type{T} = Array{Float64}) where {T}
146
+ m = T (undef, 2 , 3 )
147
+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
148
+ end
149
+
150
+ array_model_reference = eval_logp_and_grad (array_model, test_m, reference_adtype)
151
+
152
+ @testset " $adtype " for adtype in [
153
+ AutoReverseDiff (; compile= false ),
154
+ AutoReverseDiff (; compile= true ),
155
+ AutoMooncake (; config= nothing ),
156
+ ]
157
+ scalar_matrix_model_logp_and_grad = eval_logp_and_grad (
158
+ scalar_matrix_model, test_m, adtype
159
+ )
160
+ @test scalar_matrix_model_logp_and_grad[1 ] ≈ scalar_matrix_model_reference[1 ]
161
+ @test scalar_matrix_model_logp_and_grad[2 ] ≈ scalar_matrix_model_reference[2 ]
162
+ matrix_model_logp_and_grad = eval_logp_and_grad (matrix_model, test_m, adtype)
163
+ @test matrix_model_logp_and_grad[1 ] ≈ matrix_model_reference[1 ]
164
+ @test matrix_model_logp_and_grad[2 ] ≈ matrix_model_reference[2 ]
165
+ scalar_array_model_logp_and_grad = eval_logp_and_grad (
166
+ scalar_array_model, test_m, adtype
167
+ )
168
+ @test scalar_array_model_logp_and_grad[1 ] ≈ scalar_array_model_reference[1 ]
169
+ @test scalar_array_model_logp_and_grad[2 ] ≈ scalar_array_model_reference[2 ]
170
+ array_model_logp_and_grad = eval_logp_and_grad (array_model, test_m, adtype)
171
+ @test array_model_logp_and_grad[1 ] ≈ array_model_reference[1 ]
172
+ @test array_model_logp_and_grad[2 ] ≈ array_model_reference[2 ]
173
+ end
174
+ end
106
175
end
0 commit comments