Commit 5341540

Refactor Dataset.map to reuse cache files mapped with different num_proc
Fixes huggingface#7433

This refactor unifies the num_proc is None / num_proc == 1 path with the num_proc > 1 path. Previously the two were handled completely separately: one used a list of kwargs and shards, the other a single set of kwargs and self. By wrapping the single-process case in a list as well, the only remaining difference is whether a process pool is used, and either case can load the other's cache files simply by changing num_shards. In particular, num_proc == 1 can sequentially load the shards of a dataset that was mapped with num_shards > 1 and sequentially map any shards that are missing.

Beyond the structural refactor, the main contribution of this PR is get_existing_cache_file_map, which uses a regex built from cache_file_name and suffix_template to find existing cache files and group them by their num_shards. With this data structure, num_shards can be reset to match an existing set of cache files, which are then loaded accordingly.
1 parent 14233c0 commit 5341540
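
To illustrate the idea behind get_existing_cache_file_map, here is a minimal standalone sketch. It is not the code from this commit: the helper name is made up, and the suffix template is assumed to be the default "_{rank:05d}_of_{num_proc:05d}". The approach is to turn the suffix template into a regex with named groups, scan for sibling files of cache_file_name, and group the matches by the num_proc recorded in their file names.

```python
import glob
import os
import re
import string
from collections import defaultdict

# Assumed default template; the real code would use whatever suffix_template map received.
SUFFIX_TEMPLATE = "_{rank:05d}_of_{num_proc:05d}"


def find_cache_files_by_num_proc(cache_file_name: str) -> dict[int, list[str]]:
    """Group existing map cache files by the num_proc they were written with."""
    by_num_proc: dict[int, list[str]] = defaultdict(list)
    if os.path.exists(cache_file_name):
        # A suffix-less file is the single-process cache.
        by_num_proc[1].append(cache_file_name)

    # Build a regex from the suffix template: literals are escaped, fields become named digit groups.
    parts: list[str] = []
    for literal, field, _format_spec, _conversion in string.Formatter().parse(SUFFIX_TEMPLATE):
        parts.append(re.escape(literal))
        if field:
            parts.append(rf"(?P<{field}>\d+)")
    prefix, ext = os.path.splitext(cache_file_name)
    pattern = re.compile("^" + re.escape(prefix) + "".join(parts) + re.escape(ext) + "$")

    # e.g. "cache-doubled_00002_of_00004.arrow" is grouped under num_proc=4.
    for candidate in glob.iglob(f"{glob.escape(prefix)}*{ext}"):
        if m := pattern.match(candidate):
            by_num_proc[int(m.group("num_proc"))].append(candidate)
    return by_num_proc
```

Given such a grouping, map can pick the shard count whose cache files are most complete (breaking ties by closeness to the requested num_proc) and load those shards instead of recomputing from scratch.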

File tree

2 files changed: +270 -107 lines

src/datasets/arrow_dataset.py

+185 -107
@@ -19,6 +19,7 @@
 import contextlib
 import copy
 import fnmatch
+import glob
 import inspect
 import itertools
 import json
@@ -27,12 +28,13 @@
 import posixpath
 import re
 import shutil
+import string
 import sys
 import tempfile
 import time
 import warnings
 import weakref
-from collections import Counter
+from collections import Counter, defaultdict
 from collections.abc import Mapping
 from copy import deepcopy
 from functools import partial, wraps
@@ -2964,6 +2966,11 @@ def map(
         if num_proc is not None and num_proc <= 0:
             raise ValueError("num_proc must be an integer > 0.")
 
+        string_formatter = string.Formatter()
+        fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name}
+        if fields != {"rank", "num_proc"}:
+            raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")
+
         # If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway)
         if len(self) == 0:
             if self._indices is not None:  # empty indices mapping
@@ -3045,7 +3052,7 @@ def map(
                 cache_file_name = self._get_cache_file_path(new_fingerprint)
         dataset_kwargs["cache_file_name"] = cache_file_name
 
-        def load_processed_shard_from_cache(shard_kwargs):
+        def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
             """Load a processed shard from cache if it exists, otherwise throw an error."""
             shard = shard_kwargs["shard"]
             # Check if we've already cached this computation (indexed by a hash)
@@ -3056,64 +3063,98 @@ def load_processed_shard_from_cache(shard_kwargs):
                     return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
             raise NonExistentDatasetError
 
-        num_shards = num_proc if num_proc is not None else 1
-        if batched and drop_last_batch:
-            pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size
-        else:
-            pbar_total = len(self)
+        def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
+            total = len(self)
+            if len(existing_cache_files) < num_shards:
+                total -= len(existing_cache_files) * total // num_shards
+            if batched and drop_last_batch:
+                batch_size = batch_size or 1
+                return total // num_shards // batch_size * num_shards * batch_size
+            return total
+
+        def get_existing_cache_file_map(
+            cache_file_name: Optional[str],
+        ) -> dict[int, list[str]]:
+            cache_files_by_num_proc: dict[int, list[str]] = defaultdict(list)
+            if cache_file_name is None:
+                return cache_files_by_num_proc
+            if os.path.exists(cache_file_name):
+                cache_files_by_num_proc[1] = [cache_file_name]
+
+            suffix_pattern_parts: list[str] = []
+            for literal_text, field_name, format_spec, _ in string_formatter.parse(suffix_template):
+                suffix_pattern_parts.append(re.escape(literal_text))
+                if field_name:
+                    # TODO: we may want to place restrictions on acceptable format_spec or we will fail to match
+                    # someone's hexidecimal or scientific notation format 😵
+                    suffix_pattern_parts.append(f"(?P<{field_name}>\\d+)")
+            suffix_pattern = "".join(suffix_pattern_parts)
+
+            cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
+            if not cache_file_ext:
+                raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
+
+            cache_file_pattern = "^" + re.escape(cache_file_prefix) + suffix_pattern + re.escape(cache_file_ext) + "$"
+            cache_file_regex = re.compile(cache_file_pattern)
+
+            for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
+                if m := cache_file_regex.match(cache_file):
+                    file_num_proc = int(m.group("num_proc"))
+                    cache_files_by_num_proc[file_num_proc].append(cache_file)
+
+            return cache_files_by_num_proc
+
+        existing_cache_file_map = get_existing_cache_file_map(cache_file_name)
+
+        num_shards = num_proc or 1
+        if existing_cache_file_map:
+            # to avoid remapping when a different num_proc is given than when originally cached, update num_shards to
+            # what was used originally
+
+            def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
+                percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
+                num_shards_diff = abs(mapped_num_proc - num_shards)
+                return (
+                    percent_missing,  # choose the most complete set of existing cache files
+                    num_shards_diff,  # then choose the mapped_num_proc closest to the current num_proc
+                    mapped_num_proc,  # finally, choose whichever mapped_num_proc is lower
+                )
 
-        shards_done = 0
-        if num_proc is None or num_proc == 1:
-            transformed_dataset = None
-            try:
-                transformed_dataset = load_processed_shard_from_cache(dataset_kwargs)
-                logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}")
-            except NonExistentDatasetError:
-                pass
-            if transformed_dataset is None:
-                with hf_tqdm(
-                    unit=" examples",
-                    total=pbar_total,
-                    desc=desc or "Map",
-                ) as pbar:
-                    for rank, done, content in Dataset._map_single(**dataset_kwargs):
-                        if done:
-                            shards_done += 1
-                            logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                            transformed_dataset = content
-                        else:
-                            pbar.update(content)
-            assert transformed_dataset is not None, "Failed to retrieve the result from map"
-            # update fingerprint if the dataset changed
-            if transformed_dataset._fingerprint != self._fingerprint:
-                transformed_dataset._fingerprint = new_fingerprint
-            return transformed_dataset
-        else:
+            num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
 
-            def format_cache_file_name(
-                cache_file_name: Optional[str],
-                rank: Union[int, Literal["*"]],  # noqa: F722
-            ) -> Optional[str]:
-                if not cache_file_name:
-                    return cache_file_name
-                sep = cache_file_name.rindex(".")
-                base_name, extension = cache_file_name[:sep], cache_file_name[sep:]
-                if isinstance(rank, int):
-                    cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension
-                    logger.info(f"Process #{rank} will write at {cache_file_name}")
-                else:
-                    cache_file_name = (
-                        base_name
-                        + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc)
-                        + extension
-                    )
+        existing_cache_files = existing_cache_file_map.get(num_shards, [])
+
+        def format_cache_file_name(
+            cache_file_name: Optional[str],
+            rank: Union[int, Literal["*"]],  # noqa: F722
+        ) -> Optional[str]:
+            if not cache_file_name:
                 return cache_file_name
 
-            def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
-                new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc)
-                validate_fingerprint(new_fingerprint)
-                return new_fingerprint
+            cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
+            if not cache_file_ext:
+                raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
+
+            if isinstance(rank, int):
+                cache_file_name = (
+                    cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext
+                )
+                logger.info(f"Process #{rank} will write at {cache_file_name}")
+            else:
+                # TODO: this assumes the format_spec of rank in suffix_template
+                cache_file_name = (
+                    cache_file_prefix
+                    + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards)
+                    + cache_file_ext
+                )
+            return cache_file_name
+
+        def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
+            new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards)
+            validate_fingerprint(new_fingerprint)
+            return new_fingerprint
 
+        if num_proc is not None and num_proc > 1:
            prev_env = deepcopy(os.environ)
            # check if parallelism if off
            # from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22
@@ -3128,9 +3169,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             ):
                 logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.")
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            prev_env = os.environ
+
+        kwargs_per_job: list[Optional[dict[str, Any]]]
+        if num_shards == 1:
+            shards = [self]
+            kwargs_per_job = [dataset_kwargs]
+        else:
             shards = [
-                self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
-                for rank in range(num_proc)
+                self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
+                for rank in range(num_shards)
             ]
             kwargs_per_job = [
                 {
@@ -3144,60 +3193,89 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                 for rank in range(num_shards)
             ]
 
-            transformed_shards = [None] * num_shards
-            for rank in range(num_shards):
-                try:
-                    transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank])
-                    kwargs_per_job[rank] = None
-                except NonExistentDatasetError:
-                    pass
-
-            kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None]
-
-            # We try to create a pool with as many workers as dataset not yet cached.
-            if kwargs_per_job:
-                if len(kwargs_per_job) < num_shards:
-                    logger.info(
-                        f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
-                    )
-                with Pool(len(kwargs_per_job)) as pool:
-                    os.environ = prev_env
-                    logger.info(f"Spawning {num_proc} processes")
-                    with hf_tqdm(
-                        unit=" examples",
-                        total=pbar_total,
-                        desc=(desc or "Map") + f" (num_proc={num_proc})",
-                    ) as pbar:
+        transformed_shards: list[Optional[Dataset]] = [None] * num_shards
+        for rank in range(num_shards):
+            try:
+                job_kwargs = kwargs_per_job[rank]
+                assert job_kwargs is not None
+                transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs)
+                kwargs_per_job[rank] = None
+            except NonExistentDatasetError:
+                pass
+
+        if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]:
+            if len(unprocessed_kwargs_per_job) < num_shards:
+                logger.info(
+                    f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were "
+                    " missing from the cache."
+                )
+
+            with hf_tqdm(
+                unit=" examples",
+                total=pbar_total(num_shards, batch_size),
+                desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
+            ) as pbar:
+                shards_done = 0
+
+                def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None:
+                    nonlocal shards_done
+                    if done:
+                        shards_done += 1
+                        logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
+                        assert isinstance(content, Dataset)
+                        transformed_shards[rank or 0] = content
+                    else:
+                        assert isinstance(content, int)
+                        pbar.update(content)
+
+                if num_proc is not None and num_proc > 1:
+                    with Pool(num_proc) as pool:
+                        os.environ = prev_env
+                        logger.info(f"Spawning {num_proc} processes")
+
                         for rank, done, content in iflatmap_unordered(
-                            pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
+                            pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job
                         ):
-                            if done:
-                                shards_done += 1
-                                logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                                transformed_shards[rank] = content
-                            else:
-                                pbar.update(content)
-                pool.close()
-                pool.join()
-                # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
-                for kwargs in kwargs_per_job:
-                    del kwargs["shard"]
-            else:
-                logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
-            assert None not in transformed_shards, (
-                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
+                            check_if_shard_done(rank, done, content)
+
+                        pool.close()
+                        pool.join()
+                else:
+                    for unprocessed_kwargs in unprocessed_kwargs_per_job:
+                        for rank, done, content in Dataset._map_single(**unprocessed_kwargs):
+                            check_if_shard_done(rank, done, content)
+
+            # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
+            for job_kwargs in unprocessed_kwargs_per_job:
+                if "shard" in job_kwargs:
+                    del job_kwargs["shard"]
+        else:
+            logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
+
+        all_transformed_shards = [shard for shard in transformed_shards if shard is not None]
+        if len(transformed_shards) != len(all_transformed_shards):
+            raise ValueError(
+                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - "
+                "at least one worker failed to return its results"
             )
-            logger.info(f"Concatenating {num_proc} shards")
-            result = _concatenate_map_style_datasets(transformed_shards)
-            # update fingerprint if the dataset changed
+
+        if num_shards == 1:
+            result = all_transformed_shards[0]
+        else:
+            logger.info(f"Concatenating {num_shards} shards")
+            result = _concatenate_map_style_datasets(all_transformed_shards)
+
+        # update fingerprint if the dataset changed
+        result._fingerprint = (
+            new_fingerprint
             if any(
                 transformed_shard._fingerprint != shard._fingerprint
-                for transformed_shard, shard in zip(transformed_shards, shards)
-            ):
-                result._fingerprint = new_fingerprint
-            else:
-                result._fingerprint = self._fingerprint
-            return result
+                for transformed_shard, shard in zip(all_transformed_shards, shards)
+            )
+            else self._fingerprint
+        )
+
+        return result
 
     @staticmethod
     def _map_single(
@@ -3219,7 +3297,7 @@ def _map_single(
         new_fingerprint: Optional[str] = None,
         rank: Optional[int] = None,
         offset: int = 0,
-    ) -> Iterable[Tuple[int, bool, Union[int, "Dataset"]]]:
+    ) -> Iterable[Tuple[Optional[int], bool, Union[int, "Dataset"]]]:
         """Apply a function to all the elements in the table (individually or in batches)
         and update the table (if function does update examples).
 
0 commit comments

Comments
 (0)