Skip to content

Commit 70a6347

Browse files
committed
Use Threads.nthreads() * 2 in TSVI
1 parent 2a1b650 commit 70a6347

File tree

5 files changed

+14
-5
lines changed

5 files changed

+14
-5
lines changed

HISTORY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.11
4+
5+
Make `ThreadSafeVarInfo` hold a total of `Threads.nthreads() * 2` logp values, instead of just `Threads.nthreads()`.
6+
This fix helps to paper over the cracks in using `threadid()` to index into the `ThreadSafeVarInfo` object.
7+
38
## 0.36.10
49

510
Added compatibility with ForwardDiff 1.0.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.36.10"
3+
version = "0.36.11"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/simple_varinfo.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,10 +648,12 @@ end
648648
# Threadsafe stuff.
649649
# For `SimpleVarInfo` we don't really need `Ref` so let's not use it.
650650
function ThreadSafeVarInfo(vi::SimpleVarInfo)
651-
return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads()))
651+
return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads() * 2))
652652
end
653653
function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref})
654-
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
654+
return ThreadSafeVarInfo(
655+
vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
656+
)
655657
end
656658

657659
has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector

src/threadsafe.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
99
logps::L
1010
end
1111
function ThreadSafeVarInfo(vi::AbstractVarInfo)
12-
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
12+
return ThreadSafeVarInfo(
13+
vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
14+
)
1315
end
1416
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1517

test/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
@test threadsafe_vi.varinfo === vi
77
@test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
8-
@test length(threadsafe_vi.logps) == Threads.nthreads()
8+
@test length(threadsafe_vi.logps) == Threads.nthreads() * 2
99
@test all(iszero(x[]) for x in threadsafe_vi.logps)
1010
end
1111

0 commit comments

Comments
 (0)