10 changes: 8 additions & 2 deletions doreisa/_async_dict.py
@@ -10,13 +10,16 @@ def __init__(self) -> None:
self._data: dict[K, V] = {}
self._new_key_event = asyncio.Event()

def __getitem__(self, key: K) -> V:
return self._data[key]

def __setitem__(self, key: K, value: V):
self._data[key] = value
self._new_key_event.set()
self._new_key_event.clear()

def __getitem__(self, key: K) -> V:
return self._data[key]
def __delitem__(self, key: K):
del self._data[key]

async def wait_for_key(self, key: K) -> V:
while key not in self._data:
@@ -28,3 +31,6 @@ def __contains__(self, key: K) -> bool:

def __len__(self) -> int:
return len(self._data)

def clear(self):
self._data.clear()
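
For context, the helper above is a small awaitable dictionary: a consumer coroutine can await a key that a producer fills in later, and the newly added `__delitem__` and `clear` let callers drop entries once they are consumed. A condensed, self-contained sketch of that pattern (assumed behaviour, not part of the diff):

```python
# Condensed sketch of the AsyncDict pattern; assumed behaviour, not the diff itself.
import asyncio


class AsyncDict[K, V]:
    def __init__(self) -> None:
        self._data: dict[K, V] = {}
        self._new_key_event = asyncio.Event()

    def __setitem__(self, key: K, value: V) -> None:
        self._data[key] = value
        # Wake every current waiter; each one re-checks whether its own key arrived.
        self._new_key_event.set()
        self._new_key_event.clear()

    async def wait_for_key(self, key: K) -> V:
        while key not in self._data:
            await self._new_key_event.wait()
        return self._data[key]

    def __delitem__(self, key: K) -> None:
        del self._data[key]


async def main() -> None:
    d: AsyncDict[str, int] = AsyncDict()

    async def producer() -> None:
        await asyncio.sleep(0.1)
        d["answer"] = 42

    asyncio.ensure_future(producer())
    print(await d.wait_for_key("answer"))  # 42
    del d["answer"]  # drop the entry once consumed


asyncio.run(main())
```

The `set()`/`clear()` pair wakes every waiter that is currently blocked exactly once; waiters whose key has not arrived yet simply go back to waiting.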
14 changes: 14 additions & 0 deletions doreisa/_scheduler.py
@@ -4,6 +4,7 @@
from typing import Callable

import ray
import ray.actor
from dask.core import get_dependencies

from doreisa._scheduling_actor import ChunkRef, ScheduledByOtherActor
@@ -111,6 +112,10 @@ def log(message: str, debug_logs_path: str | None) -> None:

res_ref = scheduling_actors[partition[key]].get_value.remote(graph_id, key)

clear_graph.remote(
[actor for id, actor in enumerate(scheduling_actors) if partitioned_graphs[id]], res_ref, graph_id
)

if kwargs.get("ray_persist"):
if isinstance(keys[0], list):
return [[res_ref]]
@@ -123,3 +128,12 @@ def log(message: str, debug_logs_path: str | None) -> None:
if isinstance(keys[0], list):
return [[res]]
return [res]


@ray.remote(max_retries=0, num_cpus=0)
def clear_graph(scheduling_actors: list[ray.actor.ActorHandle], res: ray.ObjectRef, graph_id: int) -> None:
# Wait until the result is ready
ray.wait([res], fetch_local=False)

# Clear the graph
ray.get([actor.clear_graph.remote(graph_id) for actor in scheduling_actors])
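
The new `clear_graph` task relies on a Ray pattern worth spelling out: a zero-CPU, no-retry task waits on the result reference with `fetch_local=False`, so the result is never copied to the waiting worker, and only then fans out the cleanup calls to the actors. A standalone sketch of that pattern follows; `StateHolder`, `store` and `compute` are placeholders, and in this sketch the result ref is wrapped in a list so Ray passes the reference through without resolving it before the task starts.

```python
# Illustrative sketch of the wait-then-clean-up pattern; placeholder names throughout.
import ray

ray.init()


@ray.remote
class StateHolder:
    """Stands in for a scheduling actor holding per-graph state."""

    def __init__(self) -> None:
        self._graphs: dict[int, str] = {}

    def store(self, graph_id: int, payload: str) -> None:
        self._graphs[graph_id] = payload

    def clear_graph(self, graph_id: int) -> None:
        self._graphs.pop(graph_id, None)


@ray.remote(max_retries=0, num_cpus=0)
def cleanup_when_done(actors: list, result: list[ray.ObjectRef], graph_id: int) -> None:
    # Block until the result exists somewhere in the cluster, without copying it here.
    ray.wait(result, fetch_local=False)
    # Only then drop the per-graph state on every actor.
    ray.get([a.clear_graph.remote(graph_id) for a in actors])


@ray.remote
def compute() -> int:
    return 42


actors = [StateHolder.remote() for _ in range(2)]
ray.get([a.store.remote(0, "graph state") for a in actors])

result_ref = compute.remote()
# The ref is wrapped in a list so it travels as a reference, not an eagerly resolved value.
cleanup_when_done.remote(actors, [result_ref], graph_id=0)
print(ray.get(result_ref))  # 42
```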
52 changes: 34 additions & 18 deletions doreisa/_scheduling_actor.py
@@ -1,5 +1,4 @@
import asyncio
import pickle
from dataclasses import dataclass

import numpy as np
@@ -26,8 +25,10 @@ class ChunkRef:
timestep: Timestep
position: tuple[int, ...]

# Set for one chunk only.
_all_chunks: ray.ObjectRef | None = None
# An ObjectRef containing a dictionary {position: chunk ObjectRef}. The dictionary
# contains all the chunks of the array for this timestep that are owned by this actor.
# None if the array is a preparation array.
all_chunks: ray.ObjectRef | None


@dataclass
@@ -99,7 +100,8 @@ def __init__(self):
self.chunks_ready_event: asyncio.Event = asyncio.Event()

# {position: chunk}
self.local_chunks: AsyncDict[tuple[int, ...], ray.ObjectRef | bytes] = AsyncDict()
# The chunk is represented by an ObjectRef that directly contains the data.
self.local_chunks: AsyncDict[tuple[int, ...], ray.ObjectRef] = AsyncDict()


class _Array:
@@ -166,7 +168,7 @@ async def add_chunk(
array_timestep = array.timesteps[timestep]

assert chunk_position not in array_timestep.local_chunks
array_timestep.local_chunks[chunk_position] = self.actor_handle._pack_object_ref.remote(chunk)
array_timestep.local_chunks[chunk_position] = chunk[0]

array.owned_chunks.add((chunk_position, chunk_shape))

@@ -182,17 +184,14 @@ async def add_chunk(
)
array.is_registered = True

chunks = []
chunks: dict[tuple[int, ...], ray.ObjectRef] = {}
for position, size in array.owned_chunks:
c = array_timestep.local_chunks[position]
assert isinstance(c, ray.ObjectRef)
chunks.append(c)
array_timestep.local_chunks[position] = pickle.dumps(c)
chunks[position] = array_timestep.local_chunks[position]

all_chunks_ref = ray.put(chunks)

await self.head.chunks_ready.options(enable_task_events=False).remote(
array_name, timestep, [all_chunks_ref]
self.actor_id, array_name, timestep, [all_chunks_ref]
)

array_timestep.chunks_ready_event.set()
@@ -218,14 +217,18 @@ async def schedule_graph(self, graph_id: int, dsk: dict) -> None:
if isinstance(val, ChunkRef):
assert val.actor_id == self.actor_id

array = await self.arrays.wait_for_key(val.array_name)
array_timestep = await array.timesteps.wait_for_key(val.timestep)
ref = await array_timestep.local_chunks.wait_for_key(val.position)

if isinstance(ref, bytes): # This may not be the case depending on the asyncio scheduling order
ref = pickle.loads(ref)
if val.all_chunks is None:
ref = self.actor_handle.get_local_chunk.remote(val.array_name, val.timestep, val.position)
else:
ref = pickle.loads(pickle.dumps(ref)) # To free the memory automatically
# TODO get the dictionary only once
local_refs = await val.all_chunks

ref = self.actor_handle._pack_object_ref.remote([local_refs[val.position]])

# TODO do we still need to have asyncdicts?
# array = await self.arrays.wait_for_key(val.array_name)
# array_timestep = await array.timesteps.wait_for_key(val.timestep)
# ref = await array_timestep.local_chunks.wait_for_key(val.position)

dsk[key] = ref

@@ -239,8 +242,21 @@ async def schedule_graph(self, graph_id: int, dsk: dict) -> None:

info.scheduled_event.set()

async def get_local_chunk(self, array_name: str, timestep: Timestep, position: tuple[int, ...]) -> ray.ObjectRef:
array = await self.arrays.wait_for_key(array_name)
array_timestep = await array.timesteps.wait_for_key(timestep)
res = await array_timestep.local_chunks.wait_for_key(position)
return res

async def get_value(self, graph_id: int, key: str):
graph_info = await self.graph_infos.wait_for_key(graph_id)

await graph_info.scheduled_event.wait()
return await graph_info.refs[key]

def clear_graph(self, graph_id: int):
# return # TODO if we add this return, everything works
del self.graph_infos[graph_id]

def clear_array(self, name: str, timestep: int):
del self.arrays[name].timesteps[timestep]
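
Much of this file leans on one Ray idiom: an ObjectRef wrapped in a list (or returned from an actor method) travels as a reference instead of being resolved, so a scheduling actor can hand out chunks it owns without copying the data through itself. Below is a standalone sketch of that idiom; the `ChunkOwner` actor is a placeholder, though its method names deliberately mirror `add_chunk` and `get_local_chunk` above.

```python
# Sketch of the ObjectRef-indirection idiom; placeholder actor, assumed semantics.
import ray

ray.init()


@ray.remote
class ChunkOwner:
    def __init__(self) -> None:
        # {position: ObjectRef that directly contains the chunk data}
        self._chunks: dict[tuple[int, ...], ray.ObjectRef] = {}

    def add_chunk(self, position: tuple[int, ...], chunk: list[ray.ObjectRef]) -> None:
        # The caller wrapped the ref in a list, so we receive the ref, not the data.
        self._chunks[position] = chunk[0]

    def get_local_chunk(self, position: tuple[int, ...]) -> ray.ObjectRef:
        # Returning an ObjectRef from an actor method hands the caller a nested
        # reference: ray.get on the outer ref yields the inner ref, not the data.
        return self._chunks[position]


owner = ChunkOwner.remote()
data_ref = ray.put([1, 2, 3])

# Wrap the ref in a list so the actor receives the reference itself.
ray.get(owner.add_chunk.remote((0,), [data_ref]))

inner_ref = ray.get(owner.get_local_chunk.remote((0,)))
print(ray.get(inner_ref))  # [1, 2, 3]
```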
33 changes: 13 additions & 20 deletions doreisa/head_node.py
@@ -61,13 +61,9 @@ def __init__(self, definition: ArrayDefinition) -> None:
# Number of scheduling actors owning chunks of this array.
self.nb_scheduling_actors: int | None = None

# Each reference comes from one scheduling actor. The reference is a list of
# ObjectRefs, each ObjectRef corresponding to a chunk. These references
# shouldn't be used directly. They exist only to release the memory
# automatically.
# When the array is built, these references are put in the object store, and the
# global reference is added to the Dask graph. Then, the list is cleared.
self.chunk_refs: dict[Timestep, list[ray.ObjectRef]] = {}
# Each reference comes from one scheduling actor. The reference is a dictionary of
# ObjectRefs, each ObjectRef corresponding to a chunk.
self.chunk_refs: dict[Timestep, dict[int, ray.ObjectRef]] = {}

def set_chunk_owner(
self,
@@ -99,14 +95,14 @@ def set_chunk_owner(
else:
assert self.chunks_size[d][position[d]] == size[d]

def add_chunk_ref(self, chunk_ref: ray.ObjectRef, timestep: Timestep) -> bool:
def add_chunk_ref(self, scheduling_actor_id: int, chunk_ref: ray.ObjectRef, timestep: Timestep) -> bool:
"""
Add a reference sent by a scheduling actor.

Return:
True if all the chunks for this timestep are ready, False otherwise.
"""
self.chunk_refs[timestep].append(chunk_ref)
self.chunk_refs[timestep][scheduling_actor_id] = chunk_ref

# We don't know all the owners yet
if len(self.scheduling_actors_id) != self.nb_chunks:
@@ -129,12 +125,6 @@ def get_full_array(self, timestep: Timestep, *, is_preparation: bool = False) ->
assert len(self.scheduling_actors_id) == self.nb_chunks
assert self.nb_chunks is not None and self.nb_chunks_per_dim is not None

if is_preparation:
all_chunks = None
else:
all_chunks = ray.put(self.chunk_refs[timestep])
del self.chunk_refs[timestep]

# We need to add the timestep since the same name can be used several times for different
# timesteps
dask_name = f"{self.definition.name}_{timestep}"
@@ -147,9 +137,9 @@ def get_full_array(self, timestep: Timestep, *, is_preparation: bool = False) ->
self.definition.name,
timestep,
position,
_all_chunks=all_chunks if it == 0 else None,
all_chunks=None if is_preparation else self.chunk_refs[timestep][actor_id],
)
for it, (position, actor_id) in enumerate(self.scheduling_actors_id.items())
for position, actor_id in self.scheduling_actors_id.items()
}

dsk = HighLevelGraph.from_collections(dask_name, graph, dependencies=())
@@ -268,7 +258,9 @@ def set_owned_chunks(
for position, size in chunks:
array.set_chunk_owner(nb_chunks_per_dim, dtype, position, size, scheduling_actor_id)

async def chunks_ready(self, array_name: str, timestep: Timestep, all_chunks_ref: list[ray.ObjectRef]) -> None:
async def chunks_ready(
self, scheduling_actor_id: int, array_name: str, timestep: Timestep, all_chunks_ref: list[ray.ObjectRef]
) -> None:
"""
Called by the scheduling actors to inform the head actor that the chunks are ready.
The chunks are not sent.
@@ -293,12 +285,12 @@ async def chunks_ready(self, array_name: str, timestep: Timestep, all_chunks_ref
# The array was already created by another scheduling actor
self.new_pending_array_semaphore.release()
else:
array.chunk_refs[timestep] = []
array.chunk_refs[timestep] = {}

self.new_array_created.set()
self.new_array_created.clear()

is_ready = array.add_chunk_ref(all_chunks_ref[0], timestep)
is_ready = array.add_chunk_ref(scheduling_actor_id, all_chunks_ref[0], timestep)

if is_ready:
self.arrays_ready.put_nowait(
@@ -309,6 +301,7 @@ async def chunks_ready(self, array_name: str, timestep: Timestep, all_chunks_ref
)
)
array.fully_defined.set()
del array.chunk_refs[timestep]

async def get_next_array(self) -> tuple[str, Timestep, da.Array]:
array = await self.arrays_ready.get()
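
For reference, `get_full_array` assembles the Dask collection by hand from a key-to-value mapping rather than through the usual `dask.array` constructors. Here is a minimal standalone sketch of that construction, with plain NumPy blocks standing in for the `ChunkRef` placeholders that the scheduling actors resolve later; all names are illustrative.

```python
# Sketch of building a dask array from a hand-built graph; illustrative names only.
import dask.array as da
import numpy as np
from dask.highlevelgraph import HighLevelGraph

name = "myarray_0"  # array name + timestep, as in get_full_array
nb_chunks_per_dim = (2, 2)
chunk_shape = (3, 3)

# One graph entry per chunk, keyed by (name, i, j). In the PR the values are
# ChunkRef objects; here they are concrete blocks so the default scheduler works.
graph = {
    (name, i, j): np.full(chunk_shape, i * 10.0 + j)
    for i in range(nb_chunks_per_dim[0])
    for j in range(nb_chunks_per_dim[1])
}

dsk = HighLevelGraph.from_collections(name, graph, dependencies=())
arr = da.Array(
    dsk,
    name,
    chunks=(
        (chunk_shape[0],) * nb_chunks_per_dim[0],
        (chunk_shape[1],) * nb_chunks_per_dim[1],
    ),
    dtype=np.float64,
)
print(arr.sum().compute())  # 198.0
```

In the PR the array is only materialised by the custom doreisa scheduler, which turns each `ChunkRef` into the actual chunk on the owning scheduling actor.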
7 changes: 7 additions & 0 deletions doreisa/window_api.py
@@ -5,6 +5,7 @@
import dask
import dask.array as da
import ray
import ray.actor

from doreisa._scheduler import doreisa_get
from doreisa.head_node import ArrayDefinition as HeadArrayDefinition
@@ -110,6 +111,12 @@ def run_simulation(
if older_timestep >= 0:
del arrays_by_iteration[older_timestep][description.name]

# ray.get(head.clear_array.remote(description.name, older_timestep))

# TODO not here, only once when the actors are ready
scheduling_actors: list[ray.actor.ActorHandle] = ray.get(head.list_scheduling_actors.remote())
ray.get([actor.clear_array.remote(description.name, older_timestep) for actor in scheduling_actors])

if not arrays_by_iteration[older_timestep]:
del arrays_by_iteration[older_timestep]

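
The change above evicts stale timesteps on the scheduling actors once an iteration leaves the sliding window. A self-contained sketch of that eviction pattern follows; `ChunkStore` and `add` are placeholders, and only `clear_array` mirrors the call added here.

```python
# Sketch of sliding-window eviction across actors; placeholder actor and data.
import ray

ray.init()


@ray.remote
class ChunkStore:
    def __init__(self) -> None:
        self._chunks: dict[tuple[str, int], bytes] = {}

    def add(self, name: str, timestep: int, data: bytes) -> None:
        self._chunks[(name, timestep)] = data

    def clear_array(self, name: str, timestep: int) -> None:
        self._chunks.pop((name, timestep), None)


WINDOW = 2
actors = [ChunkStore.remote() for _ in range(2)]

for timestep in range(5):
    ray.get([a.add.remote("temperature", timestep, b"chunk bytes") for a in actors])

    older = timestep - WINDOW
    if older >= 0:
        # Evict the timestep that just left the window on every actor.
        ray.get([a.clear_array.remote("temperature", older) for a in actors])
```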
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ requires-python = ">=3.12"
readme = "README.md"
dependencies = [
"dask[dataframe] (==2024.6.0)",
"ray[default] (>=2.46.0,<3.0.0)",
"ray[default] (==2.47.0,<3.0.0)",
"numpy (==1.26.4)", # TODO this was pinned for PDI, remove the pinning?
]
