Skip to content

Commit 460a65e

Browse files
committed
Add clarifying comment
1 parent 70a6347 commit 460a65e

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/threadsafe.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
99
logps::L
1010
end
1111
function ThreadSafeVarInfo(vi::AbstractVarInfo)
12+
# In ThreadSafeVarInfo we use threadid() to index into the array of logp
13+
# fields. This is not good practice --- see
14+
# https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full
15+
# explanation --- but it has worked okay so far.
16+
# The use of nthreads()*2 here ensures that threadid() doesn't exceed
17+
# the length of the logps array. Ideally, we would use maxthreadid(),
18+
# but Mooncake can't differentiate through that. Empirically, nthreads()*2
19+
# seems to provide an upper bound to maxthreadid(), so we use that here.
20+
# See https://github.com/TuringLang/DynamicPPL.jl/pull/936
1221
return ThreadSafeVarInfo(
1322
vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)]
1423
)

0 commit comments

Comments
 (0)