Skip to content

Commit e6bd8fb

Browse files
committed
[Distributed] Make worker state variable threadsafe
1 parent 70cc57c commit e6bd8fb

File tree

4 files changed

+113
-14
lines changed

4 files changed

+113
-14
lines changed

stdlib/Distributed/src/cluster.jl

Lines changed: 48 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -99,9 +99,9 @@ mutable struct Worker
9999
add_msgs::Array{Any,1}
100100
gcflag::Bool
101101
state::WorkerState
102-
c_state::Condition # wait for state changes
103-
ct_time::Float64 # creation time
104-
conn_func::Any # used to setup connections lazily
102+
c_state::Threads.Condition # wait for state changes, lock for state
103+
ct_time::Float64 # creation time
104+
conn_func::Any # used to setup connections lazily
105105

106106
r_stream::IO
107107
w_stream::IO
@@ -133,7 +133,7 @@ mutable struct Worker
133133
if haskey(map_pid_wrkr, id)
134134
return map_pid_wrkr[id]
135135
end
136-
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
136+
w=new(id, [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
137137
w.initialized = Event()
138138
register_worker(w)
139139
w
@@ -143,12 +143,16 @@ mutable struct Worker
143143
end
144144

145145
function set_worker_state(w, state)
146-
w.state = state
147-
notify(w.c_state; all=true)
146+
lock(w.c_state) do
147+
w.state = state
148+
notify(w.c_state; all=true)
149+
end
148150
end
149151

150152
function check_worker_state(w::Worker)
153+
lock(w.c_state)
151154
if w.state === W_CREATED
155+
unlock(w.c_state)
152156
if !isclusterlazy()
153157
if PGRP.topology === :all_to_all
154158
# Since higher pids connect with lower pids, the remote worker
@@ -168,6 +172,8 @@ function check_worker_state(w::Worker)
168172
errormonitor(t)
169173
wait_for_conn(w)
170174
end
175+
else
176+
unlock(w.c_state)
171177
end
172178
end
173179

@@ -186,13 +192,25 @@ function exec_conn_func(w::Worker)
186192
end
187193

188194
function wait_for_conn(w)
195+
lock(w.c_state)
189196
if w.state === W_CREATED
197+
unlock(w.c_state)
190198
timeout = worker_timeout() - (time() - w.ct_time)
191199
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")
192200

193-
@async (sleep(timeout); notify(w.c_state; all=true))
194-
wait(w.c_state)
195-
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
201+
T = Threads.@spawn begin
202+
sleep($timeout)
203+
lock(w.c_state) do
204+
notify(w.c_state; all=true)
205+
end
206+
end
207+
errormonitor(T)
208+
lock(w.c_state) do
209+
wait(w.c_state)
210+
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
211+
end
212+
else
213+
unlock(w.c_state)
196214
end
197215
nothing
198216
end
@@ -483,7 +501,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
483501
while true
484502
if isempty(launched)
485503
istaskdone(t_launch) && break
486-
@async (sleep(1); notify(launch_ntfy))
504+
@async begin
505+
sleep(1)
506+
notify(launch_ntfy)
507+
end
487508
wait(launch_ntfy)
488509
end
489510

@@ -636,7 +657,12 @@ function create_worker(manager, wconfig)
636657
# require the value of config.connect_at which is set only upon connection completion
637658
for jw in PGRP.workers
638659
if (jw.id != 1) && (jw.id < w.id)
639-
(jw.state === W_CREATED) && wait(jw.c_state)
660+
# wait for wl to join
661+
lock(jw.c_state) do
662+
if jw.state === W_CREATED
663+
wait(jw.c_state)
664+
end
665+
end
640666
push!(join_list, jw)
641667
end
642668
end
@@ -659,7 +685,12 @@ function create_worker(manager, wconfig)
659685
end
660686

661687
for wl in wlist
662-
(wl.state === W_CREATED) && wait(wl.c_state)
688+
lock(wl.c_state) do
689+
if wl.state === W_CREATED
690+
# wait for wl to join
691+
wait(wl.c_state)
692+
end
693+
end
663694
push!(join_list, wl)
664695
end
665696
end
@@ -676,7 +707,11 @@ function create_worker(manager, wconfig)
676707
@async manage(w.manager, w.id, w.config, :register)
677708
# wait for rr_ntfy_join with timeout
678709
timedout = false
679-
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
710+
@async begin
711+
sleep($timeout)
712+
timedout = true
713+
put!(rr_ntfy_join, 1)
714+
end
680715
wait(rr_ntfy_join)
681716
if timedout
682717
error("worker did not connect within $timeout seconds")

stdlib/Distributed/src/managers.jl

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -163,7 +163,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
163163
# Wait for all launches to complete.
164164
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
165165
let machine=machine, cnt=cnt
166-
@async try
166+
@async try
167167
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
168168
catch e
169169
print(stderr, "exception launching on machine $(machine) : $(e)\n")

stdlib/Distributed/test/distributed_exec.jl

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -1696,4 +1696,5 @@ include("splitrange.jl")
16961696
# Run topology tests last after removing all workers, since a given
16971697
# cluster at any time only supports a single topology.
16981698
rmprocs(workers())
1699+
include("threads.jl")
16991700
include("topology.jl")

stdlib/Distributed/test/threads.jl

Lines changed: 63 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,63 @@
1+
using Test
2+
using Distributed, Base.Threads
3+
using Base.Iterators: product
4+
5+
# Command-line flags for the worker processes these tests spawn via `addprocs`.
# `--threads=2` is required so that `call_on` can pin tasks to a thread other
# than thread 1 on each worker.
exeflags = ("--startup-file=no",
            "--check-bounds=yes",
            "--depwarn=error",
            "--threads=2")
9+
10+
# Run `f` as a task pinned to thread `tid` on worker `wid`.
#
# Returns the remote reference produced by `remotecall`, which resolves to the
# (possibly still running) `Task` object on the remote worker.
#
# NOTE(review): the tid must be set via `jl_set_task_tid` BEFORE the task is
# scheduled — ordering here is load-bearing. The ccall uses a 0-based thread id,
# hence `tid - 1`.
function call_on(f, wid, tid)
    remotecall(wid) do
        t = Task(f)
        # Pin the task to the requested thread (0-based id) before scheduling.
        ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
        schedule(t)
        # Sanity check: the pin took effect.
        @assert threadid(t) == tid
        t
    end
end
19+
20+
# Run function on the process holding the data to only serialize the result of `f`.
# This becomes useful for things that cannot be serialized (e.g. running tasks)
# or that would be unnecessarily big if serialized.
#
# `rr` is a remote reference (e.g. a `Future` from `remotecall`); `f` is applied
# to the fetched value ON the owning process (`rr.where`), and only `f`'s result
# travels back over the wire.
# Fix: the scrape dropped the composition operator — `f fetch` is not valid
# Julia; the intended expression is `f ∘ fetch`.
fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr)
24+
25+
# Query task state on the process that owns the remote task, so the
# (unserializable) running `Task` itself never has to cross the wire.
isdone(rr) = fetch_from_owner(istaskdone, rr)
isfailed(rr) = fetch_from_owner(istaskfailed, rr)
27+
28+
@testset "RemoteChannel allows put!/take! from thread other than 1" begin
    # Exercise every combination of (sender worker, receiver worker) and
    # (sender thread, receiver thread) — 2 workers × 2 threads each.
    ws = ts = product(1:2, 1:2)
    @testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
        @testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
            # We want (the default) laziness, so that we wait for `Worker.c_state`!
            procs_added = addprocs(2; exeflags, lazy=true)
            @everywhere procs_added using Base.Threads

            p1 = procs_added[w1]
            p2 = procs_added[w2]
            # The channel lives on the first added worker; both endpoints go
            # through it, so put!/take! happen off thread 1 on remote workers.
            chan_id = first(procs_added)
            chan = RemoteChannel(chan_id)
            send = call_on(p1, t1) do
                put!(chan, nothing)
            end
            recv = call_on(p2, t2) do
                take!(chan)
            end

            # Wait on the spawned tasks on the owner
            @sync begin
                Threads.@spawn fetch_from_owner(wait, recv)
                Threads.@spawn fetch_from_owner(wait, send)
            end

            # Check the tasks
            @test isdone(send)
            @test isdone(recv)

            @test !isfailed(send)
            @test !isfailed(recv)

            rmprocs(procs_added)
        end
    end
end

0 commit comments

Comments
 (0)