Skip to content

Commit f00879b

Browse files
committed
Remove Dictionaries with Any key type
1 parent a19c9a6 commit f00879b

File tree

9 files changed

+26
-21
lines changed

9 files changed

+26
-21
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using DocStringExtensions
2323
using Random: Random
2424

2525
# For extending
26-
import AbstractPPL: predict
26+
import AbstractPPL: predict, hasvalue, getvalue
2727

2828
# TODO: Remove these when it's possible.
2929
import Bijectors: link, invlink

src/model.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
981981
Generate a sample of type `T` from the prior distribution of the `model`.
982982
"""
983983
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
984-
x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict())))
984+
x = last(
985+
evaluate_and_sample!!(
986+
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
987+
),
988+
)
985989
return values_as(x, T)
986990
end
987991

@@ -1028,7 +1032,7 @@ julia> logjoint(demo_model([1., 2.]), chain);
10281032
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
10291033
var_info = VarInfo(model) # extract variables info from the model
10301034
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1031-
argvals_dict = OrderedDict(
1035+
argvals_dict = OrderedDict{VarName,Any}(
10321036
vn_parent =>
10331037
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10341038
vn_parent in keys(var_info)
@@ -1082,7 +1086,7 @@ julia> logprior(demo_model([1., 2.]), chain);
10821086
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
10831087
var_info = VarInfo(model) # extract variables info from the model
10841088
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1085-
argvals_dict = OrderedDict(
1089+
argvals_dict = OrderedDict{VarName,Any}(
10861090
vn_parent =>
10871091
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10881092
vn_parent in keys(var_info)
@@ -1136,7 +1140,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain);
11361140
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
11371141
var_info = VarInfo(model) # extract variables info from the model
11381142
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
1139-
argvals_dict = OrderedDict(
1143+
argvals_dict = OrderedDict{VarName,Any}(
11401144
vn_parent =>
11411145
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
11421146
vn_parent in keys(var_info)

src/simple_varinfo.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ ERROR: type NamedTuple has no field x
6262
[...]
6363
6464
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
65-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict()));
65+
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()));
6666
6767
julia> # (✓) Sort of fast, but only possible at runtime.
6868
vi[@varname(x[1])]
@@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
107107
true
108108
109109
julia> # And with `OrderedDict` of course!
110-
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true));
110+
_, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
111111
112112
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
113113
0.6225185067787314
@@ -206,7 +206,7 @@ end
206206
function SimpleVarInfo(values)
207207
return SimpleVarInfo{LogProbType}(values)
208208
end
209-
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict})
209+
function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}})
210210
return if isempty(values)
211211
# Can't infer from values, so we just use default.
212212
SimpleVarInfo{LogProbType}(values)
@@ -258,7 +258,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
258258
end
259259

260260
function untyped_simple_varinfo(model::Model)
261-
varinfo = SimpleVarInfo(OrderedDict())
261+
varinfo = SimpleVarInfo(OrderedDict{VarName,Any}())
262262
return last(evaluate_and_sample!!(model, varinfo))
263263
end
264264

src/test_utils/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function setup_varinfos(
3434

3535
# SimpleVarInfo
3636
svi_typed = SimpleVarInfo(example_values)
37-
svi_untyped = SimpleVarInfo(OrderedDict())
37+
svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}())
3838
svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())
3939

4040
varinfos = map((

src/values_as_in_model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ $(TYPEDFIELDS)
1212
"""
1313
struct ValuesAsInModelAccumulator <: AbstractAccumulator
1414
"values that are extracted from the model"
15-
values::OrderedDict
15+
values::OrderedDict{<:VarName}
1616
"whether to extract variables on the LHS of :="
1717
include_colon_eq::Bool
1818
end
1919
function ValuesAsInModelAccumulator(include_colon_eq)
20-
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
20+
return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq)
2121
end
2222

2323
accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel

src/varnamedvector.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict}
14821482
end
14831483

14841484
# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how
1485-
# they differ from `haskey` and `getindex`. They can be found in src/utils.jl.
1485+
# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl.
14861486

14871487
# TODO(mhauru) This is tricky to implement in the general case, and the below implementation
14881488
# only covers some simple cases. It's probably sufficient in most situations though.

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ using LinearAlgebra # Diagonal
2727

2828
using JET: JET
2929

30+
# need to call this to get the AbstractPPL I think
31+
Pkg.update()
32+
3033
using Combinatorics: combinations
3134
using OrderedCollections: OrderedSet
3235

test/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
DynamicPPL.TestUtils.DEMO_MODELS
9191
values_constrained = DynamicPPL.TestUtils.rand_prior_true(model)
9292
@testset "$(typeof(vi))" for vi in (
93-
SimpleVarInfo(Dict()),
93+
SimpleVarInfo(Dict{VarName,Any}()),
9494
SimpleVarInfo(values_constrained),
9595
SimpleVarInfo(DynamicPPL.VarNamedVector()),
9696
DynamicPPL.typed_varinfo(model),

test/varinfo.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ end
110110
test_base(VarInfo())
111111
test_base(DynamicPPL.typed_varinfo(VarInfo()))
112112
test_base(SimpleVarInfo())
113-
test_base(SimpleVarInfo(Dict()))
113+
test_base(SimpleVarInfo(Dict{VarName,Any}()))
114114
test_base(SimpleVarInfo(DynamicPPL.VarNamedVector()))
115115
end
116116

@@ -604,8 +604,7 @@ end
604604

605605
## `SimpleVarInfo{<:Dict}`
606606
vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true)
607-
# Sample in unconstrained space.
608-
vi = last(DynamicPPL.evaluate_and_sample!!(model, vi))
607+
vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true)
609608
f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
610609
x = f(DynamicPPL.getindex_internal(vi, vn))
611610
@test getlogjoint(vi) Bijectors.logpdf_with_trans(dist, x, true)
@@ -750,11 +749,10 @@ end
750749
model, (; x=1.0), (@varname(x),); include_threadsafe=true
751750
)
752751
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
753-
# Skip the severely inconcrete `SimpleVarInfo` types, since checking for type
752+
# Skip the inconcrete `SimpleVarInfo` types, since checking for type
754753
# stability for them doesn't make much sense anyway.
755-
if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} ||
756-
varinfo isa
757-
DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}}
754+
if varinfo isa SimpleVarInfo{<:AbstractDict} ||
755+
varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}}
758756
continue
759757
end
760758
@inferred DynamicPPL.unflatten(varinfo, varinfo[:])

0 commit comments

Comments
 (0)