Skip to content

Commit 98fa91d

Browse files
committed
final fixes?
1 parent 6f12082 commit 98fa91d

File tree

5 files changed

+33
-53
lines changed

5 files changed

+33
-53
lines changed

src/debug_utils.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,7 @@ A context used for checking validity of a model.
131131
# Fields
132132
$(FIELDS)
133133
"""
134-
struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext
135-
"model that is being run"
136-
model::M
134+
struct DebugContext{C<:AbstractContext} <: AbstractContext
137135
"context used for running the model"
138136
context::C
139137
"mapping from varnames to the number of times they have been seen"
@@ -149,7 +147,6 @@ struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext
149147
end
150148

151149
function DebugContext(
152-
model::Model,
153150
context::AbstractContext=DefaultContext();
154151
varnames_seen=OrderedDict{VarName,Int}(),
155152
statements=Vector{Stmt}(),
@@ -158,7 +155,6 @@ function DebugContext(
158155
record_varinfo=false,
159156
)
160157
return DebugContext(
161-
model,
162158
context,
163159
varnames_seen,
164160
statements,
@@ -344,7 +340,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int})
344340
end
345341

346342
# A check we run on the model before evaluating it.
347-
function check_model_pre_evaluation(context::DebugContext, model::Model)
343+
function check_model_pre_evaluation(model::Model)
348344
issuccess = true
349345
# If something is in the model arguments, then it should NOT be in `condition`,
350346
# nor should there be any symbol present in `condition` that has the same symbol.
@@ -361,8 +357,8 @@ function check_model_pre_evaluation(context::DebugContext, model::Model)
361357
return issuccess
362358
end
363359

364-
function check_model_post_evaluation(context::DebugContext, model::Model)
365-
return check_varnames_seen(context.varnames_seen)
360+
function check_model_post_evaluation(model::Model)
361+
return check_varnames_seen(model.context.varnames_seen)
366362
end
367363

368364
"""
@@ -443,21 +439,18 @@ function check_model_and_trace(
443439
)
444440
# Execute the model with the debug context.
445441
debug_context = DebugContext(
446-
model,
447-
SamplingContext(rng, model.context);
448-
error_on_failure=error_on_failure,
449-
kwargs...,
442+
SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs...
450443
)
451444
debug_model = DynamicPPL.contextualize(model, debug_context)
452445

453446
# Perform checks before evaluating the model.
454-
issuccess = check_model_pre_evaluation(debug_context, debug_model)
447+
issuccess = check_model_pre_evaluation(debug_model)
455448

456449
# Force single-threaded execution.
457450
DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo)
458451

459452
# Perform checks after evaluating the model.
460-
issuccess &= check_model_post_evaluation(debug_context, debug_model)
453+
issuccess &= check_model_post_evaluation(debug_model)
461454

462455
if !issuccess && error_on_failure
463456
error("model check failed")

test/compiler.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,10 @@ module Issue537 end
193193
varinfo = VarInfo(model)
194194
@test getlogjoint(varinfo) == lp
195195
@test varinfo_ isa AbstractVarInfo
196-
@test model_ === model
196+
# During the model evaluation, its context is wrapped in a
197+
# SamplingContext, so `model_` is not going to be equal to `model`.
198+
# We can still check equality of `f` though.
199+
@test model_.f === model.f
197200
@test model_.context isa SamplingContext
198201
@test model_.context.rng isa Random.AbstractRNG
199202

test/debug_utils.jl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@testset "check_model" begin
22
@testset "context interface" begin
33
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
4-
context = DynamicPPL.DebugUtils.DebugContext(model)
4+
context = DynamicPPL.DebugUtils.DebugContext()
55
DynamicPPL.TestUtils.test_context(context, model)
66
end
77
end
@@ -35,9 +35,7 @@
3535
buggy_model = buggy_demo_model()
3636

3737
@test_logs (:warn,) (:warn,) check_model(buggy_model)
38-
issuccess = check_model(
39-
buggy_model; context=SamplingContext(), record_varinfo=false
40-
)
38+
issuccess = check_model(buggy_model; record_varinfo=false)
4139
@test !issuccess
4240
@test_throws ErrorException check_model(buggy_model; error_on_failure=true)
4341
end
@@ -81,9 +79,7 @@
8179
buggy_model = buggy_subsumes_demo_model()
8280

8381
@test_logs (:warn,) (:warn,) check_model(buggy_model)
84-
issuccess = check_model(
85-
buggy_model; context=SamplingContext(), record_varinfo=false
86-
)
82+
issuccess = check_model(buggy_model; record_varinfo=false)
8783
@test !issuccess
8884
@test_throws ErrorException check_model(buggy_model; error_on_failure=true)
8985
end
@@ -98,9 +94,7 @@
9894
buggy_model = buggy_subsumes_demo_model()
9995

10096
@test_logs (:warn,) (:warn,) check_model(buggy_model)
101-
issuccess = check_model(
102-
buggy_model; context=SamplingContext(), record_varinfo=false
103-
)
97+
issuccess = check_model(buggy_model; record_varinfo=false)
10498
@test !issuccess
10599
@test_throws ErrorException check_model(buggy_model; error_on_failure=true)
106100
end
@@ -115,9 +109,7 @@
115109
buggy_model = buggy_subsumes_demo_model()
116110

117111
@test_logs (:warn,) (:warn,) check_model(buggy_model)
118-
issuccess = check_model(
119-
buggy_model; context=SamplingContext(), record_varinfo=false
120-
)
112+
issuccess = check_model(buggy_model; record_varinfo=false)
121113
@test !issuccess
122114
@test_throws ErrorException check_model(buggy_model; error_on_failure=true)
123115
end

test/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
591591
xs_train = 1:0.1:10
592592
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
593593
m_lin_reg = linear_reg(xs_train, ys_train)
594-
chain = [last(evaluate!!(m_lin_reg, VarInfo())) for _ in 1:10000]
594+
chain = [last(sample!!(m_lin_reg, VarInfo())) for _ in 1:10000]
595595

596596
# chain is generated from the prior
597597
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) 1.0 atol = 0.1

test/threadsafe.jl

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@
5252
x[i] ~ Normal(x[i - 1], 1)
5353
end
5454
end
55+
model = wthreads(x)
5556

5657
vi = VarInfo()
57-
wthreads(x)(vi)
58+
model(vi)
5859
lp_w_threads = getlogjoint(vi)
5960
if Threads.nthreads() == 1
6061
@test vi_ isa VarInfo
@@ -64,23 +65,19 @@
6465

6566
println("With `@threads`:")
6667
println(" default:")
67-
@time wthreads(x)(vi)
68+
@time model(vi)
6869

6970
# Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements.
70-
DynamicPPL.evaluate_threadsafe!!(
71-
wthreads(x),
72-
vi,
73-
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
74-
)
71+
sampling_model = contextualize(model, SamplingContext(model.context))
72+
DynamicPPL.evaluate_threadsafe!!(sampling_model, vi)
7573
@test getlogjoint(vi) lp_w_threads
74+
# check that it's wrapped during the model evaluation
7675
@test vi_ isa DynamicPPL.ThreadSafeVarInfo
76+
# ensure that it's unwrapped after evaluation finishes
77+
@test vi isa VarInfo
7778

7879
println(" evaluate_threadsafe!!:")
79-
@time DynamicPPL.evaluate_threadsafe!!(
80-
wthreads(x),
81-
vi,
82-
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
83-
)
80+
@time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi)
8481

8582
@model function wothreads(x)
8683
global vi_ = __varinfo__
@@ -89,9 +86,10 @@
8986
x[i] ~ Normal(x[i - 1], 1)
9087
end
9188
end
89+
model = wothreads(x)
9290

9391
vi = VarInfo()
94-
wothreads(x)(vi)
92+
model(vi)
9593
lp_wo_threads = getlogjoint(vi)
9694
if Threads.nthreads() == 1
9795
@test vi_ isa VarInfo
@@ -101,24 +99,18 @@
10199

102100
println("Without `@threads`:")
103101
println(" default:")
104-
@time wothreads(x)(vi)
102+
@time model(vi)
105103

106104
@test lp_w_threads lp_wo_threads
107105

108106
# Ensure that we use `VarInfo`.
109-
DynamicPPL.evaluate_threadunsafe!!(
110-
wothreads(x),
111-
vi,
112-
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
113-
)
107+
sampling_model = contextualize(model, SamplingContext(model.context))
108+
DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi)
114109
@test getlogjoint(vi) lp_w_threads
115110
@test vi_ isa VarInfo
111+
@test vi isa VarInfo
116112

117113
println(" evaluate_threadunsafe!!:")
118-
@time DynamicPPL.evaluate_threadunsafe!!(
119-
wothreads(x),
120-
vi,
121-
SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()),
122-
)
114+
@time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi)
123115
end
124116
end

0 commit comments

Comments
 (0)