@@ -74,6 +74,7 @@ def get_time_stamp():

 def parse_args():
     parser = colossalai.get_default_parser()
+    parser.add_argument("-s", "--synthetic", action="store_true")
     parser.add_argument(
         "--dataset_name",
         type=str,
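The new -s/--synthetic switch is a plain store_true flag: it defaults to False, so the script keeps its original dataset-driven behaviour unless the flag is passed explicitly. A minimal standalone sketch of that argparse pattern (the flag names come from the diff; everything else is illustrative):

import argparse

# Sketch of the flag added above; store_true means the default is False.
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--synthetic", action="store_true")

print(parser.parse_args([]).synthetic)      # False -> normal dataset pipeline
print(parser.parse_args(["-s"]).synthetic)  # True  -> synthetic-data path added below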
@@ -231,15 +232,16 @@ def parse_args():
     args = parser.parse_args()

     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
-    else:
-        if args.train_file is not None:
-            extension = args.train_file.split(".")[-1]
-            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
-        if args.validation_file is not None:
-            extension = args.validation_file.split(".")[-1]
-            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
+    if not args.synthetic:
+        if args.dataset_name is None and args.train_file is None and args.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if args.train_file is not None:
+                extension = args.train_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
+            if args.validation_file is not None:
+                extension = args.validation_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."

     if args.push_to_hub:
         assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
@@ -255,6 +257,34 @@ def colo_memory_cap(size_in_GB):
     print("Using {} GB of GPU memory".format(size_in_GB))


+class DummyDataloader:
+
+    def __init__(self, length, batch_size, seq_len, vocab_size):
+        self.length = length
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.vocab_size = vocab_size
+
+    def generate(self):
+        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
+        attention_mask = torch.ones_like(input_ids)
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
+
+
 def main():
     args = parse_args()
     disable_existing_loggers()
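The DummyDataloader added above is a plain Python iterator: each of its `length` steps builds a batch of random token ids on the current device, reuses those ids as labels, and pairs them with an all-ones attention mask, so no real dataset or tokenizer is needed. A CPU-only sketch of the same behaviour (sizes are made up; the real class places tensors via Colossal-AI's get_current_device()):

import torch

# Illustrative sizes only; the script derives them from the model config and CLI args.
batch_size, seq_len, vocab_size, length = 2, 8, 50257, 3

def generate():
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids}

for _ in range(length):  # DummyDataloader raises StopIteration after `length` batches
    batch = generate()
    assert batch["input_ids"].shape == (batch_size, seq_len)
    assert torch.equal(batch["labels"], batch["input_ids"])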
@@ -292,46 +322,47 @@ def main():
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
     logger.info("Start preparing dataset", ranks=[0])
-    if args.dataset_name is not None:
-        # Downloading and loading a dataset from the hub.
-        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
-        if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                split=f"train[:{args.validation_split_percentage}%]",
-            )
-            raw_datasets["train"] = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                split=f"train[{args.validation_split_percentage}%:]",
-            )
-    else:
-        data_files = {}
-        dataset_args = {}
-        if args.train_file is not None:
-            data_files["train"] = args.train_file
-        if args.validation_file is not None:
-            data_files["validation"] = args.validation_file
-        extension = args.train_file.split(".")[-1]
-        if extension == "txt":
-            extension = "text"
-            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
-        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
-        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
-        if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[:{args.validation_split_percentage}%]",
-                **dataset_args,
-            )
-            raw_datasets["train"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[{args.validation_split_percentage}%:]",
-                **dataset_args,
-            )
+    if not args.synthetic:
+        if args.dataset_name is not None:
+            # Downloading and loading a dataset from the hub.
+            raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+            if "validation" not in raw_datasets.keys():
+                raw_datasets["validation"] = load_dataset(
+                    args.dataset_name,
+                    args.dataset_config_name,
+                    split=f"train[:{args.validation_split_percentage}%]",
+                )
+                raw_datasets["train"] = load_dataset(
+                    args.dataset_name,
+                    args.dataset_config_name,
+                    split=f"train[{args.validation_split_percentage}%:]",
+                )
+        else:
+            data_files = {}
+            dataset_args = {}
+            if args.train_file is not None:
+                data_files["train"] = args.train_file
+            if args.validation_file is not None:
+                data_files["validation"] = args.validation_file
+            extension = args.train_file.split(".")[-1]
+            if extension == "txt":
+                extension = "text"
+                dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
+            raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
+            # If no validation data is there, validation_split_percentage will be used to divide the dataset.
+            if "validation" not in raw_datasets.keys():
+                raw_datasets["validation"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[:{args.validation_split_percentage}%]",
+                    **dataset_args,
+                )
+                raw_datasets["train"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[{args.validation_split_percentage}%:]",
+                    **dataset_args,
+                )
     logger.info("Dataset is prepared", ranks=[0])

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
@@ -399,23 +430,24 @@ def main():

     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])

-    # Preprocessing the datasets.
-    # First we tokenize all the texts.
-    column_names = raw_datasets["train"].column_names
-    text_column_name = "text" if "text" in column_names else column_names[0]
-
-    def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
-
-    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
-        tokenized_datasets = raw_datasets.map(
-            tokenize_function,
-            batched=True,
-            num_proc=args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not args.overwrite_cache,
-            desc="Running tokenizer on dataset",
-        )
+    if not args.synthetic:
+        # Preprocessing the datasets.
+        # First we tokenize all the texts.
+        column_names = raw_datasets["train"].column_names
+        text_column_name = "text" if "text" in column_names else column_names[0]
+
+        def tokenize_function(examples):
+            return tokenizer(examples[text_column_name])
+
+        with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                num_proc=args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not args.overwrite_cache,
+                desc="Running tokenizer on dataset",
+            )

     if args.block_size is None:
         block_size = tokenizer.model_max_length
@@ -447,38 +479,44 @@ def group_texts(examples):
         result["labels"] = result["input_ids"].copy()
         return result

-    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
-    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
-    # to preprocess.
-    #
-    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
-    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
-
-    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
-        lm_datasets = tokenized_datasets.map(
-            group_texts,
-            batched=True,
-            num_proc=args.preprocessing_num_workers,
-            load_from_cache_file=not args.overwrite_cache,
-            desc=f"Grouping texts in chunks of {block_size}",
-        )
-
-    train_dataset = lm_datasets["train"]
-    eval_dataset = lm_datasets["validation"]
-
-    # Log a few random samples from the training set:
-    # for index in random.sample(range(len(train_dataset)), 3):
-    # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
-
-    # DataLoaders creation:
-    train_dataloader = get_dataloader(train_dataset,
-                                      shuffle=True,
-                                      add_sampler=True,
-                                      collate_fn=default_data_collator,
-                                      batch_size=args.per_device_train_batch_size)
-    eval_dataloader = DataLoader(eval_dataset,
-                                 collate_fn=default_data_collator,
-                                 batch_size=args.per_device_eval_batch_size)
+    if not args.synthetic:
+        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+        # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+        # to preprocess.
+        #
+        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+        with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
+            lm_datasets = tokenized_datasets.map(
+                group_texts,
+                batched=True,
+                num_proc=args.preprocessing_num_workers,
+                load_from_cache_file=not args.overwrite_cache,
+                desc=f"Grouping texts in chunks of {block_size}",
+            )
+
+        train_dataset = lm_datasets["train"]
+        eval_dataset = lm_datasets["validation"]
+
+        # Log a few random samples from the training set:
+        # for index in random.sample(range(len(train_dataset)), 3):
+        # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+        # DataLoaders creation:
+        train_dataloader = get_dataloader(train_dataset,
+                                          shuffle=True,
+                                          add_sampler=True,
+                                          collate_fn=default_data_collator,
+                                          batch_size=args.per_device_train_batch_size)
+        eval_dataloader = DataLoader(eval_dataset,
+                                     collate_fn=default_data_collator,
+                                     batch_size=args.per_device_eval_batch_size)
+    else:
+        train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
+                                           config.vocab_size)
+        eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
+                                          config.vocab_size)
     logger.info("Dataloaders have been created", ranks=[0])

     # Optimizer
@@ -521,9 +559,11 @@ def group_texts(examples):

     # Train!
     total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+    num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size
+    num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size

     logger.info("***** Running training *****", ranks=[0])
-    logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
+    logger.info(f"  Num examples = {num_train_samples}", ranks=[0])
     logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
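The two counters added above only matter in synthetic mode, where the dataloaders are hard-coded to 30 training steps and 10 evaluation steps; multiplying by the global data-parallel batch size turns those step counts into sample counts for the log lines. A worked example with assumed values (per-device batch size 4, data-parallel world size 8):

# Assumed values, for illustration only.
per_device_train_batch_size = 4
data_parallel_world_size = 8

total_batch_size = per_device_train_batch_size * data_parallel_world_size  # 32
num_train_samples = 30 * total_batch_size                                  # 960
num_eval_samples = 10 * total_batch_size                                   # 320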
@@ -572,7 +612,7 @@ def group_texts(examples):
             losses.append(loss)

         losses = torch.cat(losses)
-        losses = losses[:len(eval_dataset)]
+        losses = losses[:num_eval_samples]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)
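Evaluation is unchanged apart from the slice bound: the gathered per-sample losses are truncated to num_eval_samples (instead of len(eval_dataset), which does not exist in synthetic mode) and perplexity is the exponential of their mean. A small self-contained sketch of that final computation with made-up loss values:

import math
import torch

# Made-up per-sample losses standing in for the gathered eval losses.
losses = torch.tensor([2.1, 1.9, 2.3, 2.0])
num_eval_samples = 3

losses = losses[:num_eval_samples]  # drop entries beyond the expected sample count
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)    # exp(2.1) ≈ 8.17 here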