Skip to content

Commit 564be3b

Browse files
tchatonthomas
andauthored
Streaming Dataset: Resolve chunks eviction (#19214)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
1 parent e040ef2 commit 564be3b

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

src/lightning/data/streaming/reader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868

6969
# FIXME: This should be divided by the number of nodes to provide a more granular support with scaling out
7070
self._delete_chunks_when_processed = self._config.num_bytes > max_cache_size if max_cache_size else False
71+
self._has_exited = False
7172

7273
def download(self, chunk_indexes: List[int]) -> None:
7374
"""Receive the list of the chunk indices to download for the current epoch."""
@@ -111,7 +112,7 @@ def _maybe_delete_chunks(self) -> None:
111112

112113
def _can_delete_chunk(self) -> bool:
113114
if self._delete_chunks_when_processed:
114-
return self._pre_download_counter == self._max_pre_download - 1
115+
return self._pre_download_counter >= self._max_pre_download - 1
115116
return self._max_cache_size is not None and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size
116117

117118
def _pre_load_chunk(self, chunk_index: int) -> None:
@@ -120,9 +121,10 @@ def _pre_load_chunk(self, chunk_index: int) -> None:
120121

121122
def run(self) -> None:
122123
while True:
123-
if self._pre_download_counter <= self._max_pre_download:
124+
if self._pre_download_counter < self._max_pre_download:
124125
chunk_index = _get_from_queue(self._to_download_queue)
125126
if chunk_index == _END_TOKEN:
127+
self._has_exited = True
126128
return
127129

128130
if chunk_index is not None:
Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
22
import shutil
3+
from time import sleep
34

45
import numpy as np
6+
from lightning.data.streaming import reader
57
from lightning.data.streaming.cache import Cache
68
from lightning.data.streaming.config import ChunkedIndex
79
from lightning.data.streaming.item_loader import PyTreeLoader
8-
from lightning.data.streaming.reader import PrepareChunksThread, _get_folder_size
10+
from lightning.data.streaming.reader import _END_TOKEN, PrepareChunksThread, _get_folder_size
911
from lightning_cloud.resolver import Dir
1012

1113

@@ -36,40 +38,11 @@ def test_reader_chunk_removal(tmpdir):
3638
shutil.rmtree(cache_dir)
3739
os.makedirs(cache_dir, exist_ok=True)
3840

39-
generated = []
4041
for i in range(25):
41-
generated.append([i, len(os.listdir(cache_dir))])
42+
assert len(os.listdir(cache_dir)) <= 3
4243
index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
4344
assert cache[index] == i
4445

45-
assert generated == [
46-
[0, 0],
47-
[1, 2],
48-
[2, 2],
49-
[3, 3],
50-
[4, 3],
51-
[5, 3],
52-
[6, 3],
53-
[7, 3],
54-
[8, 3],
55-
[9, 3],
56-
[10, 3],
57-
[11, 3],
58-
[12, 3],
59-
[13, 3],
60-
[14, 3],
61-
[15, 3],
62-
[16, 3],
63-
[17, 3],
64-
[18, 3],
65-
[19, 3],
66-
[20, 3],
67-
[21, 3],
68-
[22, 3],
69-
[23, 3],
70-
[24, 3],
71-
]
72-
7346
assert len(os.listdir(cache_dir)) == 3
7447

7548

@@ -82,7 +55,9 @@ def test_get_folder_size(tmpdir):
8255
assert _get_folder_size(tmpdir) == 928 * 2
8356

8457

85-
def test_prepare_chunks_thread(tmpdir):
58+
def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch):
59+
monkeypatch.setattr(reader, "_LONG_DEFAULT_TIMEOUT", 0.1)
60+
8661
cache_dir = os.path.join(tmpdir, "cache_dir")
8762
os.makedirs(cache_dir, exist_ok=True)
8863
cache = Cache(input_dir=cache_dir, chunk_size=2, max_cache_size=28020)
@@ -95,8 +70,32 @@ def test_prepare_chunks_thread(tmpdir):
9570

9671
cache._reader._try_load_config()
9772

98-
thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=1)
99-
assert thread._delete_chunks_when_processed
73+
assert len(os.listdir(cache_dir)) == 14
10074

10175
thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=10000)
10276
assert not thread._delete_chunks_when_processed
77+
78+
thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=1)
79+
assert thread._delete_chunks_when_processed
80+
81+
thread.start()
82+
83+
assert thread._pre_download_counter == 0
84+
85+
thread.download([0, 1, 2, 3, 4, 5, _END_TOKEN])
86+
87+
while thread._pre_download_counter == 0:
88+
sleep(0.01)
89+
90+
assert not thread._has_exited
91+
92+
for i in range(5):
93+
thread.delete([i])
94+
while len(os.listdir(cache_dir)) != 14 - (i + 1):
95+
sleep(0.01)
96+
97+
assert thread._pre_download_counter <= 2
98+
99+
assert len(os.listdir(cache_dir)) == 9
100+
assert thread._has_exited
101+
thread.join()

0 commit comments

Comments
 (0)