 import contextlib
 import copy
 import fnmatch
+import glob
 import inspect
 import itertools
 import json
 import posixpath
 import re
 import shutil
+import string
 import sys
 import tempfile
 import time
 import warnings
 import weakref
-from collections import Counter
+from collections import Counter, defaultdict
 from collections.abc import Mapping
 from copy import deepcopy
 from functools import partial, wraps
@@ -2964,6 +2966,11 @@ def map(
         if num_proc is not None and num_proc <= 0:
             raise ValueError("num_proc must be an integer > 0.")
 
+        string_formatter = string.Formatter()
+        fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name}
+        if fields != {"rank", "num_proc"}:
+            raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")
+
         # If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway)
         if len(self) == 0:
             if self._indices is not None:  # empty indices mapping
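
Note: the validation above can be exercised standalone. A minimal sketch, assuming the default `suffix_template` value `"_{rank:05d}_of_{num_proc:05d}"` from the `map()` signature (the signature itself is outside this diff):

```python
import string

# parse() yields (literal_text, field_name, format_spec, conversion) tuples,
# so collecting the non-empty field names recovers the template's fields.
formatter = string.Formatter()
template = "_{rank:05d}_of_{num_proc:05d}"  # assumed default suffix_template
fields = {field_name for _, field_name, _, _ in formatter.parse(template) if field_name}
assert fields == {"rank", "num_proc"}  # passes the new check

# A template that drops num_proc is now rejected before any shard is written:
bad = {field_name for _, field_name, _, _ in formatter.parse("_{rank:05d}") if field_name}
assert bad != {"rank", "num_proc"}
```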
@@ -3045,7 +3052,7 @@ def map(
                 cache_file_name = self._get_cache_file_path(new_fingerprint)
             dataset_kwargs["cache_file_name"] = cache_file_name
 
-        def load_processed_shard_from_cache(shard_kwargs):
+        def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
             """Load a processed shard from cache if it exists, otherwise throw an error."""
             shard = shard_kwargs["shard"]
             # Check if we've already cached this computation (indexed by a hash)
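
For context, `Dataset.from_file` memory-maps an Arrow file from disk, so a cache hit costs no recomputation. A small self-contained sketch (the `data-00000-of-00001.arrow` filename is the layout `save_to_disk` uses in recent releases, not something this diff guarantees):

```python
import os
import tempfile

from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})
with tempfile.TemporaryDirectory() as tmp:
    ds.save_to_disk(tmp)  # writes an Arrow data file plus metadata
    # Reload the raw Arrow file directly, as load_processed_shard_from_cache does:
    reloaded = Dataset.from_file(os.path.join(tmp, "data-00000-of-00001.arrow"))
    assert reloaded[:] == {"x": [1, 2, 3]}
```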
@@ -3056,64 +3063,98 @@ def load_processed_shard_from_cache(shard_kwargs):
                     return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
             raise NonExistentDatasetError
 
-        num_shards = num_proc if num_proc is not None else 1
-        if batched and drop_last_batch:
-            pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size
-        else:
-            pbar_total = len(self)
+        def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
+            total = len(self)
+            if len(existing_cache_files) < num_shards:
+                total -= len(existing_cache_files) * total // num_shards
+            if batched and drop_last_batch:
+                batch_size = batch_size or 1
+                return total // num_shards // batch_size * num_shards * batch_size
+            return total
+
+        def get_existing_cache_file_map(
+            cache_file_name: Optional[str],
+        ) -> dict[int, list[str]]:
+            cache_files_by_num_proc: dict[int, list[str]] = defaultdict(list)
+            if cache_file_name is None:
+                return cache_files_by_num_proc
+            if os.path.exists(cache_file_name):
+                cache_files_by_num_proc[1] = [cache_file_name]
+
+            suffix_pattern_parts: list[str] = []
+            for literal_text, field_name, format_spec, _ in string_formatter.parse(suffix_template):
+                suffix_pattern_parts.append(re.escape(literal_text))
+                if field_name:
+                    # TODO: we may want to place restrictions on acceptable format_spec or we will fail to match
+                    # someone's hexadecimal or scientific notation format 😵
+                    suffix_pattern_parts.append(f"(?P<{field_name}>\\d+)")
+            suffix_pattern = "".join(suffix_pattern_parts)
+
+            cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
+            if not cache_file_ext:
+                raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
+
+            cache_file_pattern = "^" + re.escape(cache_file_prefix) + suffix_pattern + re.escape(cache_file_ext) + "$"
+            cache_file_regex = re.compile(cache_file_pattern)
+
+            for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
+                if m := cache_file_regex.match(cache_file):
+                    file_num_proc = int(m.group("num_proc"))
+                    cache_files_by_num_proc[file_num_proc].append(cache_file)
+
+            return cache_files_by_num_proc
+
+        existing_cache_file_map = get_existing_cache_file_map(cache_file_name)
+
+        num_shards = num_proc or 1
+        if existing_cache_file_map:
+            # to avoid remapping when a different num_proc is given than when originally cached, update num_shards to
+            # what was used originally
+
+            def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
+                percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
+                num_shards_diff = abs(mapped_num_proc - num_shards)
+                return (
+                    percent_missing,  # choose the most complete set of existing cache files
+                    num_shards_diff,  # then choose the mapped_num_proc closest to the current num_proc
+                    mapped_num_proc,  # finally, choose whichever mapped_num_proc is lower
+                )
 
-        shards_done = 0
-        if num_proc is None or num_proc == 1:
-            transformed_dataset = None
-            try:
-                transformed_dataset = load_processed_shard_from_cache(dataset_kwargs)
-                logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}")
-            except NonExistentDatasetError:
-                pass
-            if transformed_dataset is None:
-                with hf_tqdm(
-                    unit=" examples",
-                    total=pbar_total,
-                    desc=desc or "Map",
-                ) as pbar:
-                    for rank, done, content in Dataset._map_single(**dataset_kwargs):
-                        if done:
-                            shards_done += 1
-                            logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                            transformed_dataset = content
-                        else:
-                            pbar.update(content)
-            assert transformed_dataset is not None, "Failed to retrieve the result from map"
-            # update fingerprint if the dataset changed
-            if transformed_dataset._fingerprint != self._fingerprint:
-                transformed_dataset._fingerprint = new_fingerprint
-            return transformed_dataset
-        else:
+            num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
 
-            def format_cache_file_name(
-                cache_file_name: Optional[str],
-                rank: Union[int, Literal["*"]],  # noqa: F722
-            ) -> Optional[str]:
-                if not cache_file_name:
-                    return cache_file_name
-                sep = cache_file_name.rindex(".")
-                base_name, extension = cache_file_name[:sep], cache_file_name[sep:]
-                if isinstance(rank, int):
-                    cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension
-                    logger.info(f"Process #{rank} will write at {cache_file_name}")
-                else:
-                    cache_file_name = (
-                        base_name
-                        + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc)
-                        + extension
-                    )
+        existing_cache_files = existing_cache_file_map.get(num_shards, [])
+
+        def format_cache_file_name(
+            cache_file_name: Optional[str],
+            rank: Union[int, Literal["*"]],  # noqa: F722
+        ) -> Optional[str]:
+            if not cache_file_name:
                 return cache_file_name
 
-            def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
-                new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc)
-                validate_fingerprint(new_fingerprint)
-                return new_fingerprint
+            cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
+            if not cache_file_ext:
+                raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
+
+            if isinstance(rank, int):
+                cache_file_name = (
+                    cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext
+                )
+                logger.info(f"Process #{rank} will write at {cache_file_name}")
+            else:
+                # TODO: this assumes the format_spec of rank in suffix_template
+                cache_file_name = (
+                    cache_file_prefix
+                    + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards)
+                    + cache_file_ext
+                )
+            return cache_file_name
+
+        def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
+            new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards)
+            validate_fingerprint(new_fingerprint)
+            return new_fingerprint
 
+        if num_proc is not None and num_proc > 1:
             prev_env = deepcopy(os.environ)
             # check if parallelism is off
             # from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22
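
To see what `get_existing_cache_file_map` matches, here is the same pattern construction run by hand on a hypothetical cache path and the assumed default suffix template; a shard file written by an earlier 4-process run is recognized and bucketed under `num_proc=4`:

```python
import os
import re
import string

suffix_template = "_{rank:05d}_of_{num_proc:05d}"  # assumed default
cache_file_name = "/tmp/cache-3b7fe62ecfc9d071.arrow"  # hypothetical path

# Same construction as the diff: escape the template's literal text and
# replace each field with a named digit group.
parts = []
for literal_text, field_name, _, _ in string.Formatter().parse(suffix_template):
    parts.append(re.escape(literal_text))
    if field_name:
        parts.append(f"(?P<{field_name}>\\d+)")
prefix, ext = os.path.splitext(cache_file_name)
cache_file_regex = re.compile("^" + re.escape(prefix) + "".join(parts) + re.escape(ext) + "$")

m = cache_file_regex.match("/tmp/cache-3b7fe62ecfc9d071_00002_of_00004.arrow")
assert m is not None and int(m.group("num_proc")) == 4  # rank 2 of a 4-process run
assert cache_file_regex.match("/tmp/cache-3b7fe62ecfc9d071.arrow") is None  # base file handled separately
```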
@@ -3128,9 +3169,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
             ):
                 logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.")
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            prev_env = os.environ
+
+        kwargs_per_job: list[Optional[dict[str, Any]]]
+        if num_shards == 1:
+            shards = [self]
+            kwargs_per_job = [dataset_kwargs]
+        else:
             shards = [
-                self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
-                for rank in range(num_proc)
+                self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
+                for rank in range(num_shards)
             ]
             kwargs_per_job = [
                 {
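
The shard-count selection in the previous hunk is easiest to read as a toy run. With partial caches left by two earlier runs and a current `num_proc=2` (all numbers hypothetical), the 4-shard layout wins because its fraction of missing shards is smallest:

```python
# 3 of 4 shards cached from a num_proc=4 run; 4 of 8 from a num_proc=8 run.
existing_cache_file_map = {4: ["s0", "s1", "s2"], 8: ["s0", "s1", "s2", "s3"]}
num_shards = 2  # the num_proc requested by the current call

def select_existing_cache_files(mapped_num_proc: int) -> tuple:
    percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
    num_shards_diff = abs(mapped_num_proc - num_shards)
    return (percent_missing, num_shards_diff, mapped_num_proc)

# (0.25, 2, 4) < (0.5, 6, 8), so map() reuses the 4-shard cache layout.
assert min(existing_cache_file_map, key=select_existing_cache_files) == 4
```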
@@ -3144,60 +3193,89 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
                 for rank in range(num_shards)
             ]
 
-            transformed_shards = [None] * num_shards
-            for rank in range(num_shards):
-                try:
-                    transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank])
-                    kwargs_per_job[rank] = None
-                except NonExistentDatasetError:
-                    pass
-
-            kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None]
-
-            # We try to create a pool with as many workers as dataset not yet cached.
-            if kwargs_per_job:
-                if len(kwargs_per_job) < num_shards:
-                    logger.info(
-                        f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
-                    )
-                with Pool(len(kwargs_per_job)) as pool:
-                    os.environ = prev_env
-                    logger.info(f"Spawning {num_proc} processes")
-                    with hf_tqdm(
-                        unit=" examples",
-                        total=pbar_total,
-                        desc=(desc or "Map") + f" (num_proc={num_proc})",
-                    ) as pbar:
+        transformed_shards: list[Optional[Dataset]] = [None] * num_shards
+        for rank in range(num_shards):
+            try:
+                job_kwargs = kwargs_per_job[rank]
+                assert job_kwargs is not None
+                transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs)
+                kwargs_per_job[rank] = None
+            except NonExistentDatasetError:
+                pass
+
+        if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]:
+            if len(unprocessed_kwargs_per_job) < num_shards:
+                logger.info(
+                    f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were "
+                    "missing from the cache."
+                )
+
+            with hf_tqdm(
+                unit=" examples",
+                total=pbar_total(num_shards, batch_size),
+                desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
+            ) as pbar:
+                shards_done = 0
+
+                def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None:
+                    nonlocal shards_done
+                    if done:
+                        shards_done += 1
+                        logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
+                        assert isinstance(content, Dataset)
+                        transformed_shards[rank or 0] = content
+                    else:
+                        assert isinstance(content, int)
+                        pbar.update(content)
+
+                if num_proc is not None and num_proc > 1:
+                    with Pool(num_proc) as pool:
+                        os.environ = prev_env
+                        logger.info(f"Spawning {num_proc} processes")
+
                         for rank, done, content in iflatmap_unordered(
-                            pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
+                            pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job
                         ):
-                            if done:
-                                shards_done += 1
-                                logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
-                                transformed_shards[rank] = content
-                            else:
-                                pbar.update(content)
-                    pool.close()
-                    pool.join()
-                # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
-                for kwargs in kwargs_per_job:
-                    del kwargs["shard"]
-            else:
-                logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
-            assert None not in transformed_shards, (
-                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
+                            check_if_shard_done(rank, done, content)
+
+                        pool.close()
+                        pool.join()
+                else:
+                    for unprocessed_kwargs in unprocessed_kwargs_per_job:
+                        for rank, done, content in Dataset._map_single(**unprocessed_kwargs):
+                            check_if_shard_done(rank, done, content)
+
+            # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
+            for job_kwargs in unprocessed_kwargs_per_job:
+                if "shard" in job_kwargs:
+                    del job_kwargs["shard"]
+        else:
+            logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
+
+        all_transformed_shards = [shard for shard in transformed_shards if shard is not None]
+        if len(transformed_shards) != len(all_transformed_shards):
+            raise ValueError(
+                f"Failed to retrieve results from map: result list {transformed_shards} still contains None - "
+                "at least one worker failed to return its results"
             )
-            logger.info(f"Concatenating {num_proc} shards")
-            result = _concatenate_map_style_datasets(transformed_shards)
-            # update fingerprint if the dataset changed
+
+        if num_shards == 1:
+            result = all_transformed_shards[0]
+        else:
+            logger.info(f"Concatenating {num_shards} shards")
+            result = _concatenate_map_style_datasets(all_transformed_shards)
+
+        # update fingerprint if the dataset changed
+        result._fingerprint = (
+            new_fingerprint
             if any(
                 transformed_shard._fingerprint != shard._fingerprint
-                for transformed_shard, shard in zip(transformed_shards, shards)
-            ):
-                result._fingerprint = new_fingerprint
-            else:
-                result._fingerprint = self._fingerprint
-            return result
+                for transformed_shard, shard in zip(all_transformed_shards, shards)
+            )
+            else self._fingerprint
+        )
+
+        return result
 
     @staticmethod
     def _map_single(
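
The progress-bar total now accounts for both cached shards and dropped last batches. Here is the `pbar_total` arithmetic run by hand with hypothetical numbers (1000 rows, 4 shards, one shard already cached, `batch_size=32`, `drop_last_batch=True`):

```python
total = 1000
total -= 1 * total // 4             # rows covered by the one cached shard: 1000 -> 750
total = total // 4 // 32 * 4 * 32   # each shard drops its final partial batch: 750 -> 640
assert total == 640                 # the bar counts only rows actually processed
```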
@@ -3219,7 +3297,7 @@ def _map_single(
         new_fingerprint: Optional[str] = None,
         rank: Optional[int] = None,
         offset: int = 0,
-    ) -> Iterable[Tuple[int, bool, Union[int, "Dataset"]]]:
+    ) -> Iterable[Tuple[Optional[int], bool, Union[int, "Dataset"]]]:
         """Apply a function to all the elements in the table (individually or in batches)
         and update the table (if function does update examples).
 