Commit b0b7a78

[tutorial] add synthetic dataset for opt (#1924)
1 parent 0486048 commit b0b7a78

3 files changed: 169 additions, 100 deletions

examples/tutorial/opt/opt/README.md

Lines changed: 8 additions & 0 deletions
@@ -39,6 +39,14 @@ bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
 the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
 - gpu-num: the number of GPUs to use, default is 1.
 
+It uses the `wikitext` dataset.
+
+To use a synthetic dataset instead:
+
+```bash
+bash ./run_clm_synthetic.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
+```
+
 ## Remarkable Performance
 On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
 Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.

examples/tutorial/opt/opt/run_clm.py

Lines changed: 140 additions & 100 deletions
@@ -74,6 +74,7 @@ def get_time_stamp():
 
 def parse_args():
     parser = colossalai.get_default_parser()
+    parser.add_argument("-s", "--synthetic", action="store_true")
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -231,15 +232,16 @@ def parse_args():
     args = parser.parse_args()
 
     # Sanity checks
-    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
-        raise ValueError("Need either a dataset name or a training/validation file.")
-    else:
-        if args.train_file is not None:
-            extension = args.train_file.split(".")[-1]
-            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
-        if args.validation_file is not None:
-            extension = args.validation_file.split(".")[-1]
-            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
+    if not args.synthetic:
+        if args.dataset_name is None and args.train_file is None and args.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if args.train_file is not None:
+                extension = args.train_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
+            if args.validation_file is not None:
+                extension = args.validation_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
 
     if args.push_to_hub:
         assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
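
For illustration, the new flag's effect on argument validation as a minimal sketch; plain argparse stands in for colossalai.get_default_parser(), and only the flags visible in this hunk are modelled:

```python
# Minimal sketch of the sanity-check logic above; plain argparse stands in for
# colossalai.get_default_parser(), and only flags from this hunk are modelled.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--synthetic", action="store_true")
parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--train_file", type=str, default=None)
parser.add_argument("--validation_file", type=str, default=None)
args = parser.parse_args(["-s"])    # a synthetic run passes no dataset arguments

if not args.synthetic:
    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a dataset name or a training/validation file.")

print("sanity checks passed")       # reached because --synthetic skips the dataset checks
```
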
@@ -255,6 +257,34 @@ def colo_memory_cap(size_in_GB):
     print("Using {} GB of GPU memory".format(size_in_GB))
 
 
+class DummyDataloader:
+
+    def __init__(self, length, batch_size, seq_len, vocab_size):
+        self.length = length
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.vocab_size = vocab_size
+
+    def generate(self):
+        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
+        attention_mask = torch.ones_like(input_ids)
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}
+
+    def __iter__(self):
+        self.step = 0
+        return self
+
+    def __next__(self):
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
+
+
 def main():
     args = parse_args()
     disable_existing_loggers()
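
For reference, a standalone sketch of what the new DummyDataloader yields each step. Device handling is simplified to CPU here (run_clm.py allocates the tensors on the current CUDA device via get_current_device()), and the sizes passed at the bottom are arbitrary placeholders:

```python
# CPU-only sketch of the DummyDataloader added above; shows the keys and shape
# of each synthetic batch. The real class draws its tensors on the GPU.
import torch


class DummyDataloader:

    def __init__(self, length, batch_size, seq_len, vocab_size):
        self.length = length              # number of batches served per pass
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def generate(self):
        input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
        attention_mask = torch.ones_like(input_ids)
        # labels == input_ids: the causal-LM loss shifts them internally
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        raise StopIteration

    def __len__(self):
        return self.length


loader = DummyDataloader(length=3, batch_size=2, seq_len=8, vocab_size=50000)  # placeholder sizes
for batch in loader:
    print(batch["input_ids"].shape)    # torch.Size([2, 8])
```
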
@@ -292,46 +322,47 @@ def main():
     # In distributed training, the load_dataset function guarantee that only one local process can concurrently
     # download the dataset.
     logger.info("Start preparing dataset", ranks=[0])
-    if args.dataset_name is not None:
-        # Downloading and loading a dataset from the hub.
-        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
-        if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                split=f"train[:{args.validation_split_percentage}%]",
-            )
-            raw_datasets["train"] = load_dataset(
-                args.dataset_name,
-                args.dataset_config_name,
-                split=f"train[{args.validation_split_percentage}%:]",
-            )
-    else:
-        data_files = {}
-        dataset_args = {}
-        if args.train_file is not None:
-            data_files["train"] = args.train_file
-        if args.validation_file is not None:
-            data_files["validation"] = args.validation_file
-        extension = args.train_file.split(".")[-1]
-        if extension == "txt":
-            extension = "text"
-            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
-        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
-        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
-        if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[:{args.validation_split_percentage}%]",
-                **dataset_args,
-            )
-            raw_datasets["train"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[{args.validation_split_percentage}%:]",
-                **dataset_args,
-            )
+    if not args.synthetic:
+        if args.dataset_name is not None:
+            # Downloading and loading a dataset from the hub.
+            raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+            if "validation" not in raw_datasets.keys():
+                raw_datasets["validation"] = load_dataset(
+                    args.dataset_name,
+                    args.dataset_config_name,
+                    split=f"train[:{args.validation_split_percentage}%]",
+                )
+                raw_datasets["train"] = load_dataset(
+                    args.dataset_name,
+                    args.dataset_config_name,
+                    split=f"train[{args.validation_split_percentage}%:]",
+                )
+        else:
+            data_files = {}
+            dataset_args = {}
+            if args.train_file is not None:
+                data_files["train"] = args.train_file
+            if args.validation_file is not None:
+                data_files["validation"] = args.validation_file
+            extension = args.train_file.split(".")[-1]
+            if extension == "txt":
+                extension = "text"
+                dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
+            raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
+            # If no validation data is there, validation_split_percentage will be used to divide the dataset.
+            if "validation" not in raw_datasets.keys():
+                raw_datasets["validation"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[:{args.validation_split_percentage}%]",
+                    **dataset_args,
+                )
+                raw_datasets["train"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[{args.validation_split_percentage}%:]",
+                    **dataset_args,
+                )
     logger.info("Dataset is prepared", ranks=[0])
 
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
@@ -399,23 +430,24 @@ def main():
 
     logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
 
-    # Preprocessing the datasets.
-    # First we tokenize all the texts.
-    column_names = raw_datasets["train"].column_names
-    text_column_name = "text" if "text" in column_names else column_names[0]
-
-    def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
-
-    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
-        tokenized_datasets = raw_datasets.map(
-            tokenize_function,
-            batched=True,
-            num_proc=args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not args.overwrite_cache,
-            desc="Running tokenizer on dataset",
-        )
+    if not args.synthetic:
+        # Preprocessing the datasets.
+        # First we tokenize all the texts.
+        column_names = raw_datasets["train"].column_names
+        text_column_name = "text" if "text" in column_names else column_names[0]
+
+        def tokenize_function(examples):
+            return tokenizer(examples[text_column_name])
+
+        with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
+            tokenized_datasets = raw_datasets.map(
+                tokenize_function,
+                batched=True,
+                num_proc=args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not args.overwrite_cache,
+                desc="Running tokenizer on dataset",
+            )
 
     if args.block_size is None:
         block_size = tokenizer.model_max_length
@@ -447,38 +479,44 @@ def group_texts(examples):
         result["labels"] = result["input_ids"].copy()
         return result
 
-    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
-    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
-    # to preprocess.
-    #
-    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
-    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
-
-    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
-        lm_datasets = tokenized_datasets.map(
-            group_texts,
-            batched=True,
-            num_proc=args.preprocessing_num_workers,
-            load_from_cache_file=not args.overwrite_cache,
-            desc=f"Grouping texts in chunks of {block_size}",
-        )
-
-    train_dataset = lm_datasets["train"]
-    eval_dataset = lm_datasets["validation"]
-
-    # Log a few random samples from the training set:
-    # for index in random.sample(range(len(train_dataset)), 3):
-    #     logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
-
-    # DataLoaders creation:
-    train_dataloader = get_dataloader(train_dataset,
-                                      shuffle=True,
-                                      add_sampler=True,
-                                      collate_fn=default_data_collator,
-                                      batch_size=args.per_device_train_batch_size)
-    eval_dataloader = DataLoader(eval_dataset,
-                                 collate_fn=default_data_collator,
-                                 batch_size=args.per_device_eval_batch_size)
+    if not args.synthetic:
+        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+        # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+        # to preprocess.
+        #
+        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+        with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
+            lm_datasets = tokenized_datasets.map(
+                group_texts,
+                batched=True,
+                num_proc=args.preprocessing_num_workers,
+                load_from_cache_file=not args.overwrite_cache,
+                desc=f"Grouping texts in chunks of {block_size}",
+            )
+
+        train_dataset = lm_datasets["train"]
+        eval_dataset = lm_datasets["validation"]
+
+        # Log a few random samples from the training set:
+        # for index in random.sample(range(len(train_dataset)), 3):
+        #     logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+        # DataLoaders creation:
+        train_dataloader = get_dataloader(train_dataset,
+                                          shuffle=True,
+                                          add_sampler=True,
+                                          collate_fn=default_data_collator,
+                                          batch_size=args.per_device_train_batch_size)
+        eval_dataloader = DataLoader(eval_dataset,
+                                     collate_fn=default_data_collator,
+                                     batch_size=args.per_device_eval_batch_size)
+    else:
+        train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings,
+                                           config.vocab_size)
+        eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings,
+                                          config.vocab_size)
 
     logger.info("Dataloaders have been created", ranks=[0])
 
     # Optimizer
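
The synthetic branch sizes its random batches from the model config: config.max_position_embeddings becomes the sequence length and config.vocab_size bounds the random token ids. A quick way to inspect those two values for the default facebook/opt-125m checkpoint, assuming the transformers package and a downloadable or cached config:

```python
# Inspect the two config fields the synthetic branch passes to DummyDataloader.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/opt-125m")
print(config.max_position_embeddings)  # sequence length of every synthetic batch
print(config.vocab_size)               # exclusive upper bound for the random token ids
```
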
@@ -521,9 +559,11 @@ def group_texts(examples):
 
     # Train!
     total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+    num_train_samples = len(train_dataset) if not args.synthetic else 30 * total_batch_size
+    num_eval_samples = len(eval_dataset) if not args.synthetic else 10 * total_batch_size
 
     logger.info("***** Running training *****", ranks=[0])
-    logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
+    logger.info(f"  Num examples = {num_train_samples}", ranks=[0])
     logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
@@ -572,7 +612,7 @@ def group_texts(examples):
             losses.append(loss)
 
         losses = torch.cat(losses)
-        losses = losses[:len(eval_dataset)]
+        losses = losses[:num_eval_samples]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)
examples/tutorial/opt/opt/run_clm_synthetic.sh

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+set -x
+export BS=${1:-16}
+export MEMCAP=${2:-0}
+export MODEL=${3:-"125m"}
+export GPUNUM=${4:-1}
+
+# make directory for logs
+mkdir -p ./logs
+
+export MODEL_PATH="facebook/opt-${MODEL}"
+
+# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
+torchrun \
+  --nproc_per_node ${GPUNUM} \
+  --master_port 19198 \
+  run_clm.py \
+  -s \
+  --output_dir $PWD \
+  --mem_cap ${MEMCAP} \
+  --model_name_or_path ${MODEL_PATH} \
+  --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
