Skip to content

Commit fb5600a

Browse files
committed
Call remote ray functions using .remote()
1 parent be47cfb commit fb5600a

File tree

5 files changed

+62
-34
lines changed

5 files changed

+62
-34
lines changed

charm4py/chare.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def array_proxy_elem(proxy, idx): # array proxy [] overload method
718718
assert _slice.start is not None and _slice.stop is not None, 'Must specify start and stop indexes for array slicing'
719719
return charm.split(proxy, 1, slicing=idx)[0]
720720

721-
def array_proxy_method_gen(ep, argcount, argnames, defaults, is_ray): # decorator, generates proxy entry methods
721+
def array_proxy_method_gen(ep, argcount, argnames, defaults): # decorator, generates proxy entry methods
722722
def proxy_entry_method(proxy, *args, **kwargs):
723723
num_args = len(args)
724724
if num_args < argcount and len(kwargs) > 0:
@@ -735,6 +735,7 @@ def proxy_entry_method(proxy, *args, **kwargs):
735735
args.append(defaults[def_idx])
736736

737737
header = {}
738+
is_ray = kwargs.pop('is_ray', False)
738739
header['is_ray'] = is_ray
739740
blockFuture = None
740741
elemIdx = proxy.elemIdx
@@ -771,9 +772,9 @@ def proxy_entry_method(proxy, *args, **kwargs):
771772
proxy_entry_method.ep = ep
772773
return proxy_entry_method
773774

774-
def array_ckNew_gen(C, epIdx, is_ray):
775+
def array_ckNew_gen(C, epIdx):
775776
@classmethod # make ckNew a class (not instance) method of proxy
776-
def array_ckNew(cls, dims=None, ndims=-1, args=[], map=None, useAtSync=False):
777+
def array_ckNew(cls, dims=None, ndims=-1, args=[], map=None, useAtSync=False, is_ray=False):
777778
# if charm.myPe() == 0: print("calling array ckNew for class " + C.__name__ + " cIdx=" + str(C.idx[ARRAY]))
778779
if type(dims) == int: dims = (dims,)
779780

@@ -808,8 +809,8 @@ def array_ckNew(cls, dims=None, ndims=-1, args=[], map=None, useAtSync=False):
808809
return proxy
809810
return array_ckNew
810811

811-
def array_ckInsert_gen(epIdx, is_ray):
812-
def array_ckInsert(proxy, index, args=[], onPE=-1, useAtSync=False, single=False):
812+
def array_ckInsert_gen(epIdx):
813+
def array_ckInsert(proxy, index, args=[], onPE=-1, useAtSync=False, single=False, is_ray=False):
813814
if type(index) == int: index = (index,)
814815
assert len(index) == proxy.ndims, 'Invalid index dimensions passed to ckInsert'
815816
header = {}
@@ -880,19 +881,19 @@ def __getProxyClass__(C, cls, sectionProxy=False):
880881
continue
881882
argcount, argnames, defaults = getEntryMethodInfo(m.C, m.name)
882883
if Options.profiling:
883-
f = profile_send_function(array_proxy_method_gen(m.epIdx, argcount, argnames, defaults, hasattr(cls, 'is_ray')))
884+
f = profile_send_function(array_proxy_method_gen(m.epIdx, argcount, argnames, defaults))
884885
else:
885-
f = array_proxy_method_gen(m.epIdx, argcount, argnames, defaults, hasattr(cls, 'is_ray'))
886+
f = array_proxy_method_gen(m.epIdx, argcount, argnames, defaults)
886887
f.__qualname__ = proxyClassName + '.' + m.name
887888
f.__name__ = m.name
888889
M[m.name] = f
889890
M['__init__'] = array_proxy_ctor
890891
M['__getitem__'] = array_proxy_elem
891892
M['__eq__'] = array_proxy__eq__
892893
M['__hash__'] = array_proxy__hash__
893-
M['ckNew'] = array_ckNew_gen(cls, entryMethods[0].epIdx, hasattr(cls, 'is_ray'))
894+
M['ckNew'] = array_ckNew_gen(cls, entryMethods[0].epIdx)
894895
M['__getsecproxy__'] = array_getsecproxy
895-
M['ckInsert'] = array_ckInsert_gen(entryMethods[0].epIdx, hasattr(cls, 'is_ray'))
896+
M['ckInsert'] = array_ckInsert_gen(entryMethods[0].epIdx)
896897
M['ckDoneInserting'] = array_proxy_doneInserting
897898
if not sectionProxy:
898899
M['ckContribute'] = array_proxy_contribute # function called when target proxy is Array
@@ -904,8 +905,6 @@ def __getProxyClass__(C, cls, sectionProxy=False):
904905
M['__setstate__'] = arraysecproxy__setstate__
905906
proxyCls = type(proxyClassName, (), M) # create and return proxy class
906907
proxyCls.issec = sectionProxy
907-
if hasattr(cls, 'is_ray'):
908-
proxyCls.is_ray = True
909908
return proxyCls
910909

911910
# ---------------------------------------------------

charm4py/pool.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, tasks, result_dest):
3030

3131
class Job(object):
3232

33-
def __init__(self, id, func, tasks, result, ncores, chunksize):
33+
def __init__(self, id, func, tasks, result, ncores, chunksize, is_ray=False):
3434
self.id = id
3535
self.max_cores = ncores
3636
self.n_avail = ncores
@@ -40,6 +40,7 @@ def __init__(self, id, func, tasks, result, ncores, chunksize):
4040
self.threaded = False
4141
self.failed = False
4242
self.single_task = False
43+
self.is_ray = is_ray
4344
assert chunksize > 0
4445
if func is not None:
4546
self.threaded = hasattr(func, '_ck_coro')
@@ -106,7 +107,7 @@ def __start__(self, func, tasks, result):
106107
print('Initializing charm.pool with', self.num_workers, 'worker PEs. '
107108
'Warning: charm.pool is experimental (API and performance '
108109
'is subject to change)')
109-
self.workers = Group(Worker, args=[self.thisProxy])
110+
self.workers = Array(Worker, charm.numPes(), args=[self.thisProxy])
110111

111112
if len(self.job_id_pool) == 0:
112113
oldSize = len(self.jobs)
@@ -147,7 +148,7 @@ def startSingleTask(self, func, future, *args):
147148
job.remote = self.workers.runTask_star
148149
self.schedule()
149150

150-
def start(self, func, tasks, result, ncores, chunksize):
151+
def start(self, func, tasks, result, ncores, chunksize, is_ray=False):
151152
assert ncores != 0
152153
if ncores < 0:
153154
ncores = self.num_workers
@@ -158,7 +159,7 @@ def start(self, func, tasks, result, ncores, chunksize):
158159

159160
self.__start__(func, tasks, result)
160161

161-
job = Job(self.job_id_pool.pop(), func, tasks, result, ncores, chunksize)
162+
job = Job(self.job_id_pool.pop(), func, tasks, result, ncores, chunksize, is_ray=is_ray)
162163
self.__addJob__(job)
163164

164165
if job.chunked:
@@ -210,12 +211,15 @@ def schedule(self):
210211
func = task.func
211212
# NOTE: this is a non-standard way of using proxies, but is
212213
# faster and allows the scheduler to reuse the same proxy
213-
self.workers.elemIdx = worker_id
214+
if not isinstance(worker_id, tuple):
215+
self.workers.elemIdx = (worker_id,)
216+
else:
217+
self.workers.elemIdx = worker_id
214218

215219
if isinstance(task.data, tuple):
216-
job.remote(func, [task.result_dest], job.id, *task.data)
220+
job.remote(func, [task.result_dest], job.id, *task.data, is_ray=job.is_ray)
217221
else:
218-
job.remote(func, [task.result_dest], job.id, task.data)
222+
job.remote(func, [task.result_dest], job.id, task.data, is_ray=job.is_ray)
219223

220224
if len(job.tasks) == 0:
221225
prev.job_next = job.job_next
@@ -234,6 +238,7 @@ def schedule(self):
234238
job = prev.job_next
235239

236240
def taskFinished(self, worker_id, job_id, result=None):
241+
#print('Job finished')
237242
job = self.jobs[job_id]
238243
if job.failed:
239244
return self.taskError(worker_id, job_id, job.exception)
@@ -452,7 +457,7 @@ def Task(self, func, args, ret=False, awaitable=False):
452457
def map(self, func, iterable, chunksize=1, ncores=-1, is_ray=False):
453458
result = Future(store=is_ray)
454459
# TODO shouldn't send task objects to a central place. what if they are large?
455-
self.pool_scheduler.start(func, iterable, result, ncores, chunksize)
460+
self.pool_scheduler.start(func, iterable, result, ncores, chunksize, is_ray=is_ray)
456461
return result.get()
457462

458463
def map_async(self, func, iterable, chunksize=1, ncores=-1, multi_future=False, is_ray=False):
@@ -464,7 +469,7 @@ def map_async(self, func, iterable, chunksize=1, ncores=-1, multi_future=False,
464469
result = [Future(store=is_ray) for _ in range(len(iterable))]
465470
else:
466471
result = Future(store=is_ray)
467-
self.pool_scheduler.start(func, iterable, result, ncores, chunksize)
472+
self.pool_scheduler.start(func, iterable, result, ncores, chunksize, is_ray=is_ray)
468473
return result
469474

470475
# iterable is a sequence of (function, args) tuples

charm4py/ray/api.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,31 @@ def get_object_store():
1515
global object_store
1616
return object_store
1717

18+
class RayProxyFunction(object):
    """Wrapper that forces a function to be invoked through ``.remote()``.

    Calling the wrapper directly raises ``RuntimeError``; calling
    ``.remote(*args, **kwargs)`` forwards the arguments to the wrapped
    object and returns whatever it returns.

    The wrapped object's ``__name__`` and ``__doc__`` are mirrored onto the
    wrapper so that introspection and debugging output identify which
    function is being proxied.
    """

    def __init__(self, func):
        """Store *func*, the object that ``.remote()`` will forward to."""
        self.func = func
        # No callable() validation here on purpose: callers may wrap
        # arbitrary attributes, and construction must never fail.
        self.__name__ = getattr(func, '__name__', type(func).__name__)
        self.__doc__ = getattr(func, '__doc__', None)

    def __repr__(self):
        return '<%s wrapping %s>' % (type(self).__name__, self.__name__)

    def __call__(self, *args, **kwargs):
        """Reject direct calls; remote functions must use ``.remote()``."""
        raise RuntimeError("Cannot call remote function without .remote()")

    def remote(self, *args, **kwargs):
        """Invoke the wrapped function with the given arguments."""
        return self.func(*args, **kwargs)
27+
28+
29+
class RayProxy(object):
    """Ray-style handle around a charm4py chare proxy.

    On construction a chare is created via
    ``Chare(subclass, args=args, onPE=pe)`` and kept in ``self.proxy``.
    Every callable, non-dunder attribute of that proxy is re-exposed on this
    object as a ``RayProxyFunction``, so it must be invoked as
    ``handle.method.remote(...)``.
    """

    def __init__(self, subclass, args, pe):
        """Create the chare and expose its methods behind ``.remote()``.

        subclass -- chare class passed to ``Chare``
        args     -- constructor arguments forwarded as ``args=``
        pe       -- value forwarded as ``onPE=``
        """
        # Imported here rather than at module level, as in the original
        # (presumably to avoid a circular import -- TODO confirm).
        from charm4py import Chare
        self.proxy = Chare(subclass, args=args, onPE=pe)
        for f in dir(self.proxy):
            if f.startswith('__'):
                continue
            # Only wrap callables: wrapping a plain data attribute would
            # yield a "remote function" that fails with an opaque TypeError
            # as soon as .remote() is called on it.
            if not callable(getattr(self.proxy, f)):
                continue
            setattr(self, f, RayProxyFunction(self.remote_function(f)))

    def remote_function(self, f):
        """Return a function that calls proxy attribute *f* with ``is_ray=True``.

        The attribute is looked up once, when this method is called; the
        returned function forwards its positional and keyword arguments and
        adds the ``is_ray=True`` keyword.
        """
        proxy_func = getattr(self.proxy, f)

        def call_remote(*args, **kwargs):
            return proxy_func(*args, **kwargs, is_ray=True)

        # Keep the method name visible for debugging and introspection.
        call_remote.__name__ = f
        return call_remote
42+
1843

1944
def get_ray_class(subclass):
2045
from charm4py import Chare, register, charm
@@ -23,9 +48,9 @@ class RayChare(Chare):
2348
@staticmethod
2449
def remote(*a):
2550
global counter
26-
chare = Chare(subclass, args=a, onPE=counter % charm.numPes())
51+
ray_proxy = RayProxy(subclass, a, counter % charm.numPes())
2752
counter += 1
28-
return chare
53+
return ray_proxy
2954
return RayChare
3055

3156
def get_ray_task(func):
@@ -46,7 +71,6 @@ def remote(*args, **kwargs):
4671
else:
4772
# decorating without any arguments
4873
subclass = type(args[0].__name__, (Chare, args[0]), {"__init__": args[0].__init__})
49-
subclass.is_ray = True
5074
register(subclass)
5175
rayclass = get_ray_class(subclass)
5276
rayclass.__name__ = args[0].__name__

examples/ray/parameter_server.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ def sync_train(args):
136136
ps = ParameterServer.remote(1e-2)
137137
workers = [DataWorker.remote() for i in range(num_workers)]
138138

139-
current_weights = ps.get_weights()
139+
current_weights = ps.get_weights.remote()
140140
for i in range(iterations):
141-
gradients = [worker.compute_gradients(current_weights) for worker in workers]
141+
gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
142142
# Calculate update after all gradients are available.
143-
current_weights = ps.apply_gradients(*gradients)
143+
current_weights = ps.apply_gradients.remote(*gradients)
144144

145145
if i % 10 == 0:
146146
# Evaluate the current model.
@@ -167,19 +167,19 @@ def async_train(args):
167167
ps = ParameterServer.remote(1e-2)
168168
workers = [DataWorker.remote() for i in range(num_workers)]
169169

170-
current_weights = ps.get_weights()
170+
current_weights = ps.get_weights.remote()
171171
gradients = {}
172172
for worker in workers:
173-
gradients[worker.compute_gradients(current_weights)] = worker
173+
gradients[worker.compute_gradients.remote(current_weights)] = worker
174174

175175
for i in range(iterations * num_workers):
176176
ready_gradient_list, _ = ray.wait(list(gradients))
177177
ready_gradient_id = ready_gradient_list[0]
178178
worker = gradients.pop(ready_gradient_id)
179179

180180
# Compute and apply gradients.
181-
current_weights = ps.apply_gradients(*[ready_gradient_id])
182-
gradients[worker.compute_gradients(current_weights)] = worker
181+
current_weights = ps.apply_gradients.remote(*[ready_gradient_id])
182+
gradients[worker.compute_gradients.remote(current_weights)] = worker
183183

184184
if i % 10 == 0:
185185
# Evaluate the current model after every 10 updates.

examples/ray/simple.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def main(args):
2727
# create 3 instances of MyChare, distributed among cores by the runtime
2828
arr = [Compute.remote(i) for i in range(4)]
2929

30-
c = arr[0].add(1, 2) # fut id 0
31-
d = arr[1].add(3, c) # fut id 1
32-
e = arr[2].add(2, d)
33-
f = arr[3].add(c, 4)
30+
c = arr[0].add.remote(1, 2) # fut id 0
31+
d = arr[1].add.remote(3, c) # fut id 1
32+
e = arr[2].add.remote(2, d)
33+
f = arr[3].add.remote(c, 4)
3434
g = add_task.remote(e, f)
3535

3636
not_ready = [c, d, e, f, g]

0 commit comments

Comments
 (0)