From ccf78e0db4c4ede877498f49242af9147a4ae538 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Mon, 6 Mar 2023 11:32:19 -0500 Subject: [PATCH 01/22] add santacoder example script --- examples/pretrain_gpt_1B_santacoder.sh | 61 ++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 examples/pretrain_gpt_1B_santacoder.sh diff --git a/examples/pretrain_gpt_1B_santacoder.sh b/examples/pretrain_gpt_1B_santacoder.sh new file mode 100644 index 0000000000..dfb754429c --- /dev/null +++ b/examples/pretrain_gpt_1B_santacoder.sh @@ -0,0 +1,61 @@ +#! /bin/bash + +set -u # stop on unset variables + +# Runs the SantaCoder 1B model + +GPUS_PER_NODE=8 +MASTER_ADDR=${MASTER_NODE} # Adjust +MASTER_PORT=6000 +NNODES=12 # Adjust +# NODE_RANK=0 # Adjust +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +CHECKPOINT_PATH=/my/experiment/path # Adjust: Directory to store the checkpoints +DATA_PATH=/preprocessed/data/path # Adjust: Prefix of the preprocessed dataset. +TOKENIZER_FILE=/tokenizer/path # Adjust + +GPT_ARGS="\ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --recompute-granularity full \ + --recompute-method uniform \ +--num-layers 24 \ +--hidden-size 2048 \ +--num-attention-heads 16 \ +--attention-head-type multiquery \ +--init-method-std 0.022 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ +--attention-dropout 0.1 \ +--hidden-dropout 0.1 \ + --micro-batch-size 2 \ + --global-batch-size 192 \ +--lr 0.0002 \ +--train-iters 3000 \ +--lr-decay-iters 600000 \ +--lr-decay-style cosine \ +--lr-warmup-iters 175 \ +--weight-decay .1 \ +--adam-beta2 .95 \ +--clip-grad 1.0 \ +--fp16 \ + --log-interval 10 \ + --save-interval 4000 \ + --eval-interval 200 \ + --eval-iters 10 \ +" + +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +torchrun $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFileWithFIM \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + $TENSORBOARD_ARGS \ No newline at end of file From d661d045b4d2bc9594365e04f7ad0a819e5a984f Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Fri, 14 Apr 2023 12:36:38 +0200 Subject: [PATCH 02/22] Add MTF (alike bloomz/mt0) --- finetune_mtf.py | 211 ++++++++ megatron/arguments.py | 5 +- megatron/data/decoder_packed_mtf_dataset.py | 532 ++++++++++++++++++++ megatron/data/mtf_dataset.py | 90 ++++ megatron/utils.py | 84 ++++ 5 files changed, 921 insertions(+), 1 deletion(-) create mode 100644 finetune_mtf.py create mode 100644 megatron/data/decoder_packed_mtf_dataset.py create mode 100644 megatron/data/mtf_dataset.py diff --git a/finetune_mtf.py b/finetune_mtf.py new file mode 100644 index 0000000000..6492185d8d --- /dev/null +++ b/finetune_mtf.py @@ -0,0 +1,211 @@ +"""Multitask Finetuning""" + +import torch +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer + +from megatron import get_args, get_tokenizer, print_rank_0, mpu +from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group +from megatron.model.enums import PositionEmbeddingType +#from megatron.model import GPTModelPipe +from megatron.model import GPTModel, ModelType +from megatron.training import pretrain +from megatron.utils 
import get_ltor_masks_and_position_ids, get_packed_attention_mask +from megatron.utils import average_losses_across_data_parallel_group + +#import deepspeed +#from deepspeed.runtime.utils import see_memory_usage + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + +def fast_normalize(loss_mask: torch.Tensor): + """ + Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] > [0,0,0,0.5,0.5,0,0,1,0,0,0.3,0.3,0.3] + """ + _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True) + counts = torch.gather(dim=0, index=inverse_indices, input=counts) + return loss_mask / counts + +def get_batch(data): + """ + Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion + + data: + decoder_tokens = [[6, 7, 8, 3, 4, 5, 0]] + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] + decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]] + """ + args = get_args() + tokenizer = get_tokenizer() + + # Broadcast data. + if data is not None: + data = next(data) + else: + data = None + + data_b = mpu.broadcast_data(["decoder_token_ids", "decoder_segment_ids"], data, torch.int64) + data_c = mpu.broadcast_data(["decoder_is_inputs"], data, torch.bool) + + # Unpack. + tokens_ = data_b["decoder_token_ids"].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + segment_ids = data_b["decoder_segment_ids"].long()[:, :-1] + decoder_is_inputs = data_c["decoder_is_inputs"][:, :-1] + + # Get the masks and position ids. + causal_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=None, + loss_on_targets_only=False # This is done below + ) + # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding + loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:] + loss_on_non_pad_only = (labels != tokenizer.pad) + loss_mask *= loss_on_targets_only * loss_on_non_pad_only + + attention_mask = get_packed_attention_mask( + # Run non-causal decoder + is_causal=not(args.prefixlm), + causal_mask=~(causal_mask.bool()), # Turn back into tril being ones + decoder_is_inputs=decoder_is_inputs.bool(), + segment_ids=segment_ids.long(), + ) + + if args.norm_target_loss: + loss_mask = loss_mask.view(-1) + loss_mask = fast_normalize(loss_mask) + + if args.position_embedding_type == PositionEmbeddingType.absolute: + # Create position ids from segment_ids + # segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 2, 0]]) (Shape: (batch_size, seq_len)) + # position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 0]]) (Shape: (batch_size, seq_len)) + # I.e. 
they should restart for each new segment from 0 + position_ids = [] + for b in segment_ids: + counts = torch.unique_consecutive(b, return_counts=True, dim=-1)[1] + p = torch.cat([torch.arange(c) for c in counts]) + position_ids.append(p) + position_ids = torch.stack(position_ids) + + + #if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: + # raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") + + return tokens, labels, loss_mask, attention_mask, position_ids + #return (tokens, position_ids, attention_mask), (labels, loss_mask) + +def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + train_ds, valid_ds, test_ds = None, None, None + + tokenizer = get_tokenizer() + + print_rank_0("> building train, validation, and test datasets for MTF ...") + # Option 1 of data loading using --data-path + if args.data_path: + # TODO: Not yet compatible with dataset weights (Will break at prefixes, weights = analyze_data_prefix(args.data_path)) + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + seq_length=args.seq_length + 1, + pad_token=tokenizer.pad, + eos_token=tokenizer.eos, + train_valid_test_num_samples=train_val_test_num_samples, + seed=args.seed, + skip_warmup=(not args.mmap_warmup) + ) + # Option 2 of data loading using --(train|valid|test)-weighted-split-paths + elif args.train_weighted_split_paths: + assigned_train_valid_test = [] + if args.train_weighted_split_paths is not None: + train_ds = [] + assigned_train_valid_test.append("train") + if args.valid_weighted_split_paths is not None: + valid_ds = [] + assigned_train_valid_test.append("valid") + if args.test_weighted_split_paths is not None: + test_ds = [] + assigned_train_valid_test.append("test") + + for s in assigned_train_valid_test: + data_groups = zip(eval(f"args.{s}_weighted_split_paths"), + eval(f"args.{s}_weighted_split_weights"), + eval(f"args.{s}_weighted_split_splits"), + eval(f"args.{s}_weighted_split_names")) + for paths, weights, splits, name in data_groups: + d = build_dataset_group( + dataset_group_name=name, + paths=paths, + weights=weights, + splits=splits, + data_impl=args.data_impl, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length + 1, + pad_token=tokenizer.pad, + eos_token=tokenizer.eos, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_valid_test=s + ) + eval(f"{s}_ds").append(d) + else: + raise NotImplementedError("No dataloading argument passed") + + print_rank_0("> finished creating MTF datasets ...") + return train_ds, valid_ds, test_ds + +if 
__name__ == "__main__": + + pretrain(train_valid_test_datasets_provider, model_provider, + ModelType.encoder_or_decoder, + forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) diff --git a/megatron/arguments.py b/megatron/arguments.py index 73e33f51cb..4e73057509 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1055,7 +1055,10 @@ def __call__(self, parser, args, values, option_string=None): group.add_argument('--fim-spm-rate', type=float, default=0.5, help='Probability that the a FIM sample uses the SPM format over the PSM format. ' 'At 1, exclusively train with SPM. At 0, exclusively train with PSM') - + group.add_argument('--loss-on-targets-only', action='store_true', + help='Mask loss on input sequence.') + group.add_argument('--norm-target-loss', action='store_true', + help='Normalize the loss per target. Used for multi-task finetuning with packing.') return parser diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py new file mode 100644 index 0000000000..0ef812544b --- /dev/null +++ b/megatron/data/decoder_packed_mtf_dataset.py @@ -0,0 +1,532 @@ +import os +import time + +import numpy as np +import torch + +from megatron import print_rank_0, mpu, logging +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \ + get_train_valid_test_split_ +from megatron.data.mtf_dataset import MTFDataset +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + +logger = logging.get_logger(__name__) + +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup +): + """Build train, valid, and test datasets.""" + + # Single dataset. + if len(data_prefix) == 1: + all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets( + data_prefix=data_prefix[0], + data_impl=data_impl, + splits_string=splits_string, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=train_valid_test_num_samples, + seed=seed, + skip_warmup=skip_warmup + ) + # Blending dataset. + else: + + output = get_datasets_weights_and_num_samples(data_prefix=data_prefix, train_valid_test_num_samples=train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. 
+ train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + data_prefix=prefixes[i], + data_impl=data_impl, + splits_string=splits_string, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=datasets_train_valid_test_num_samples[i], + seed=seed, + skip_warmup=skip_warmup + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + all_train_datasets = BlendableDataset(train_datasets, weights) \ + if train_datasets else None + all_valid_datasets = BlendableDataset(valid_datasets, weights) \ + if valid_datasets else None + all_test_datasets = BlendableDataset(test_datasets, weights) \ + if test_datasets else None + + return all_train_datasets, all_valid_datasets, all_test_datasets + + +def build_dataset_group( + dataset_group_name, + paths, + weights, + splits, + data_impl, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup, + train_valid_test +): + ''' + Build a single dataset group corresponding to Option 2 of data loading see arguments.py + a dataset group is passed in the following form + GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 + or alternatively + GIVEN_NAME PATH1 # for a single dataset to be used fully + ''' + + assert train_valid_test in ["train","valid","test"] + + # Single dataset. + if len(paths) == 1: + dataset = _build_single_datasets( + data_prefix=paths[0], + range_string=splits[0], + data_impl=data_impl, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=train_valid_test_num_samples, + seed=seed, + skip_warmup=skip_warmup, + dataset_group_name=dataset_group_name, + train_valid_test=train_valid_test + ) + return dataset + # Blending dataset. + else: + + data_prefix = [] + # data_prefix is of the shape: + # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] + for w,p in zip(weights, paths): + data_prefix += [w,p] + + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + datasets = [] + for i in range(len(prefixes)): + ds = _build_single_datasets( + data_prefix=prefixes[i], + range_string=splits[i], + data_impl=data_impl, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + train_valid_test_num_samples=datasets_train_valid_test_num_samples[i], + seed=seed, + skip_warmup=skip_warmup, + dataset_group_name=dataset_group_name, + train_valid_test=train_valid_test + ) + + datasets.append(ds) + all_datasets = BlendableDataset(datasets, weights) + + return all_datasets + +def _build_single_datasets( + data_prefix, + range_string, + data_impl, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup, + dataset_group_name, + train_valid_test +): + """Build a single dataset""" + + assert train_valid_test in ["train","valid","test"] + index = ["train","valid","test"].index(train_valid_test) + + # Target indexed dataset. 
+ target_indexed_dataset = get_indexed_dataset( + data_prefix=data_prefix, + is_input=False, + data_impl=data_impl, + skip_warmup=skip_warmup + ) + + total_num_of_documents = target_indexed_dataset.sizes.shape[0] + # this corresponds to option2 for data loading on the form + # WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3 + # splits here is an array of size 2 [start_index, end_index] + splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + print_rank_0(' {}:'.format(dataset_group_name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[0], splits[1], + splits[1] - splits[0])) + + def build_dataset(name): + dataset = None + if splits[1] > splits[0]: + documents = np.arange(start=splits[0], stop=splits[1], + step=1, dtype=np.int32) + dataset = DecoderPackedMTFDataset( + name=name, + data_prefix=data_prefix, + data_impl=data_impl, + skip_warmup=skip_warmup, + documents=documents, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + num_samples=train_valid_test_num_samples[index], + seed=seed + ) + return dataset + + dataset = build_dataset(dataset_group_name) + + return dataset + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + seq_length: int, + pad_token: int, + eos_token: int, + train_valid_test_num_samples, + seed, + skip_warmup +): + """Build train, valid, and test datasets.""" + + # Target indexed dataset. + target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup) + + total_num_of_documents = target_indexed_dataset.sizes.shape[0] + # splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + # Print stats about the splits. 
+ print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], + step=1, dtype=np.int32) + dataset = DecoderPackedMTFDataset( + name=name, + data_prefix=data_prefix, + data_impl=data_impl, + skip_warmup=skip_warmup, + documents=documents, + seq_length=seq_length, + pad_token=pad_token, + eos_token=eos_token, + num_samples=train_valid_test_num_samples[index], + seed=seed + ) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +class DecoderPackedMTFDataset(torch.utils.data.Dataset): + + def __init__( + self, + name, + data_prefix, + data_impl, + skip_warmup, + documents, + num_samples, + seq_length: int, + pad_token: int, + eos_token: int, + seed, + ): + self.mtf_dataset = MTFDataset(name=name, data_prefix=data_prefix, data_impl=data_impl, skip_warmup=skip_warmup, documents=documents) + + self.pad_token = pad_token + self.seq_length = seq_length + + self.sample_index, self.shuffle_index = _build_index_mappings(name=name, data_prefix=data_prefix, nb_documents=len(documents), mtf_dataset=self.mtf_dataset, num_samples=num_samples, seq_length=seq_length, seed=seed) + + def __len__(self): + return len(self.sample_index) + + def __getitem__(self, idx): + # Get the shuffled index. + start, end = self.sample_index[idx] + mtf_samples_indices = self.shuffle_index[start: end] + # TODO @thomasw21 build a dataset that generates an entire batch instead of a row (allows for more optimization) + items = [self.mtf_dataset[sample_id] for sample_id in mtf_samples_indices] + + return self.pack_samples(items) + + def pack_samples(self, items): + """ + Greedily packs samples. + + Items: + [ + { + 'input_tokens': array([6, 7]), + 'target_tokens': array([8]) + }, + { + 'input_tokens': array([3, 4]), + 'target_tokens': array([5]) + } + ] + + Output: + decoder_tokens = [[6, 7, 8, 3, 4, 5, ]]: Concatenation of tokens followed with padding tokens. + decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents. + decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts inputs, `0` depicts target. + """ + + decoder_tokens = np.full((self.seq_length,), self.pad_token, dtype=np.int64) + decoder_segment_ids = np.zeros((self.seq_length,), dtype=np.int64) + decoder_is_inputs = np.full((self.seq_length,), False, dtype=bool) + + # `0` is reserved for padding + item_num = 1 + cur_len = 0 + + assert len(items) > 0 + + for token_dict in items: + input_token_len = len(token_dict["input_tokens"]) + target_token_len = len(token_dict["target_tokens"]) + + total_len = input_token_len + target_token_len + + if cur_len + total_len > self.seq_length: + # This should not happen at the indexing should only allow the correct number of items + raise ValueError(f"""Items to be packed do not fit inside a single sample. 
+                    current length: {cur_len}
+                    input tokens length: {input_token_len}
+                    target token length: {target_token_len}
+                    expected sequence length: {self.seq_length}
+                """)
+
+            decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
+            decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
+            decoder_segment_ids[cur_len: cur_len + total_len] = item_num
+            decoder_is_inputs[cur_len: cur_len + input_token_len] = True  # inputs
+            # targets are already 0 at init, no need to update `decoder_is_inputs`
+
+            item_num += 1
+            cur_len += total_len
+            assert cur_len <= self.seq_length
+
+        return {
+            "decoder_token_ids": decoder_tokens,
+            "decoder_segment_ids": decoder_segment_ids,
+            "decoder_is_inputs": decoder_is_inputs,
+        }
+
+
+def _build_index_mappings(
+    name,
+    data_prefix,
+    nb_documents,
+    mtf_dataset,
+    num_samples: int,
+    seq_length: int,
+    seed,
+):
+    """
+    - `shuffle_index` is [num_epoch * len(self.mtf)]
+    - `sample_index` is [num_sample, 2] (storing the start and end of the sample). We query the sample via `self.shuffle_index[start:end]`
+
+    TODO @thomasw21 Instead of loading samples individually, we save the packing once and for all
+    """
+    # rng state
+    np_rng = np.random.RandomState(seed=seed)
+
+    # Filename of the index mappings.
+    _filename = data_prefix
+    _filename += '_{}_indexmap'.format(name)
+    _filename += '_{}ns'.format(num_samples)
+    _filename += '_{}s'.format(seed)
+    sample_idx_filename = _filename + '_decoder_packed_batch_idx.npy'
+    shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'
+
+    # Build the index mappings if they do not exist.
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        if (not os.path.isfile(sample_idx_filename)) or \
+            (not os.path.isfile(shuffle_idx_filename)):
+
+            print_rank_0(' > WARNING: could not find index map files, building '
+                         'the indices on rank 0 ...')
+
+            # iteratively add the entire dataset for every epoch and see if it's enough given current packing strategy
+            start_time = time.time()
+            row_offset = 0
+            old_sample_start = 0
+            epoch = 0
+            shuffle_idx = []
+            sample_idx = []
+            while len(sample_idx) <= num_samples:
+                new_document_ids = _build_shuffle_idx(nb_documents=nb_documents, np_rng=np_rng)
+                # Generate a shuffling of the entire dataset
+                shuffle_idx.append(new_document_ids)
+                # Packs them into a single sample
+                new_samples, row_offset, old_sample_start = _build_sample_idx(
+                    mtf_dataset=mtf_dataset,
+                    document_ids=new_document_ids,
+                    seq_length=seq_length,
+                    row_offset=row_offset,
+                    old_sample_start=old_sample_start,
+                    epoch=epoch
+                )
+                sample_idx.extend(new_samples)
+                epoch += 1
+
+            shuffle_idx = np.concatenate(shuffle_idx, axis=0)
+            sample_idx = np.stack(sample_idx, axis=0)
+
+            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+            print_rank_0(' > elapsed time to build and save shuffle-idx and sample-idx mapping'
+                         ' (seconds): {:4f}'.format(time.time() - start_time))
+
+    if torch.distributed.is_initialized():
+        # This should be a barrier but nccl barrier assumes
+        # device_index=rank which is not the case for model
+        # parallel case
+        counts = torch.cuda.LongTensor([1])
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+        assert counts[0].item() == (
+            torch.distributed.get_world_size() //
+            torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+
+ # Load mappings. + start_time = time.time() + print_rank_0(' > loading doc-idx mapping from {}'.format( + sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading shuffle-idx mapping from {}'.format( + shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + + return sample_idx, shuffle_idx + +def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sample_start, epoch): + """Build start and off index of each `full` batch, return that list of batch + start of the unfinished batch""" + row_length = row_offset + + full_samples = [] + current_sample_start = old_sample_start + epoch_offset = epoch * len(document_ids) + + assert epoch_offset >= current_sample_start + for current_sample_end, document_id in enumerate(document_ids): + current_sample_end = epoch_offset + current_sample_end + sample_sizes = mtf_dataset.size(document_id) + + # TODO @thomasw21 figure out if we add tokens + tok_len = sample_sizes["input_tokens"] + sample_sizes["target_tokens"] + + row_length = row_length + tok_len + if row_length > seq_length: + # current sample can't be added and requires to be added in the next one + if current_sample_end > current_sample_start: + full_samples.append(np.asarray([current_sample_start, current_sample_end])) + current_sample_start = current_sample_end + row_length = tok_len + + if tok_len > seq_length: + # TODO @thomasw21 handle the case where a single sample cannot fit inside a row. We can + # - silently skip that value [currently implemented] + # - truncate to `seq_length`, and keep the right part + logger.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}") + current_sample_start = current_sample_end + 1 # skipping + row_length = 0 + continue + + return full_samples, row_length, current_sample_start + +def _build_shuffle_idx(nb_documents: int, np_rng): + """Build the range [0, dataset_size) and shuffle.""" + dtype_ = np.int64 + + result = np.arange(start=0, stop=nb_documents, step=1, dtype=dtype_) + + # in-place shuffling + np_rng.shuffle(result) + + return result + + +def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool): + if is_input: + field = "inputs" + else: + field = "targets" + + return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup) + + +def get_indexed_dataset_(path, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + start_time = time.time() + indexed_dataset = make_indexed_dataset(path, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset diff --git a/megatron/data/mtf_dataset.py b/megatron/data/mtf_dataset.py new file mode 100644 index 0000000000..57f3a779b0 --- /dev/null +++ b/megatron/data/mtf_dataset.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multitask Finetune style dataset.""" + +import time + +import numpy as np +import torch + +from megatron import print_rank_0 +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + +class MTFDataset(torch.utils.data.Dataset): + + def __init__( + self, + name, + data_prefix, + data_impl, + skip_warmup, + documents, + ): + # Params to store. + self.name = name + + # Dataset. + self.input_indexed_dataset = get_indexed_dataset(data_prefix, is_input=True, data_impl=data_impl, skip_warmup=skip_warmup) + self.target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup) + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < self.input_indexed_dataset.sizes.shape[0] + assert np.max(documents) < self.target_indexed_dataset.sizes.shape[0] + assert self.input_indexed_dataset.sizes.shape[0] == self.target_indexed_dataset.sizes.shape[0] + + def __len__(self): + return len(self.input_indexed_dataset) + + def __getitem__(self, idx): + input_tokens = self.input_indexed_dataset.get(idx) + target_tokens = self.target_indexed_dataset.get(idx) + + assert len(input_tokens) > 0 + assert len(target_tokens) > 0 + + return { + 'input_tokens': input_tokens, + 'target_tokens': target_tokens, + } + + def size(self, index): + return { + 'input_tokens': self.input_indexed_dataset.size(index), + 'target_tokens': self.target_indexed_dataset.size(index), + } + +def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool): + if is_input: + field = "inputs" + else: + field = "targets" + + return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup) + +def get_indexed_dataset_(path, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + start_time = time.time() + indexed_dataset = make_indexed_dataset(path, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset diff --git a/megatron/utils.py b/megatron/utils.py index d115f815a4..6c9d3452c5 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -269,3 +269,87 @@ def get_tflops(batch_size, elapsed_time_per_iteration): tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12)) return tflops + + +def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decoder_is_inputs: torch.Tensor, segment_ids: torch.Tensor): + """ + Inspired by https://github.com/google-research/t5x/blob/7193407f98a8b18100b71a04ff777238be1682ca/t5x/examples/decoder_only/layers.py#L978 + Arguments: + - is_causal: determines if the masking should be causal in the `inputs` part + - causal_mask: torch.BoolTensor [batch_size, sequence_length, sequence_length] + - decoder_is_inputs: torch.BoolTensor [batch_size, sequence_length] + - segment_ids: torch.IntTensor [batch_size, sequence_length] + Returns: + - attention_mask: torch.BoolTensor [batch_size, 
1, sequence_length, sequence_length] + Input example for the mask examples: + att_mask_batch = 1 + seq_length = 7 + decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]]) + segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]]) + causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length) + """ + + """Causal Inputs Mask: + mask = [[[[1, 1, 0, 1, 1, 0, 0], + [1, 1, 0, 1, 1, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1]]]] + """ + assert causal_mask.dtype == torch.bool + assert segment_ids.dtype == torch.long + if is_causal: + causal_inputs_mask = causal_mask + else: + assert decoder_is_inputs.dtype == torch.bool + inputs_mask = decoder_is_inputs[:, None, :, None] * decoder_is_inputs[:, None, None, :] + causal_inputs_mask = causal_mask + inputs_mask + + """Padding Mask: + mask = [[[[1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + """ + padding_mask = (segment_ids != 0)[:, None, :, None] * (segment_ids != 0)[:, None, None, :] + + """Segment Mask: + mask = [[[[1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1]]]] + """ + segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :] + + """Final Mask: + mask = [[[[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + + If is_causal=True: + mask = [[[[1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0]]]] + + """ + + attention_mask = causal_inputs_mask * padding_mask * segment_mask + + # True for places we do not want to attend to + return ~attention_mask \ No newline at end of file From 181593c69e1ff86ab1d63c45ef013af16693ae07 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 10:44:21 +0200 Subject: [PATCH 03/22] Fix logging --- megatron/data/decoder_packed_mtf_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py index 0ef812544b..81ea9bedb4 100644 --- a/megatron/data/decoder_packed_mtf_dataset.py +++ b/megatron/data/decoder_packed_mtf_dataset.py @@ -4,7 +4,7 @@ import numpy as np import torch -from megatron import print_rank_0, mpu, logging +from megatron import print_rank_0, mpu from megatron.data.blendable_dataset import BlendableDataset from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \ get_train_valid_test_split_ From 005dbb0e7824addbdb0123fe33fc9fc1a1c5488a Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 10:45:31 +0200 Subject: [PATCH 04/22] Add --- megatron/data/decoder_packed_mtf_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py index 81ea9bedb4..f12b9e45d9 100644 --- a/megatron/data/decoder_packed_mtf_dataset.py +++ b/megatron/data/decoder_packed_mtf_dataset.py @@ -11,6 +11,7 @@ from megatron.data.mtf_dataset import MTFDataset from megatron.data.indexed_dataset import make_dataset as 
make_indexed_dataset +import logging logger = logging.get_logger(__name__) def build_train_valid_test_datasets( From b64ade1a578a4f26793811a22a53e005a68ff798 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 10:47:08 +0200 Subject: [PATCH 05/22] Fix warn --- megatron/data/decoder_packed_mtf_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/data/decoder_packed_mtf_dataset.py b/megatron/data/decoder_packed_mtf_dataset.py index f12b9e45d9..b13837d719 100644 --- a/megatron/data/decoder_packed_mtf_dataset.py +++ b/megatron/data/decoder_packed_mtf_dataset.py @@ -12,7 +12,6 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset import logging -logger = logging.get_logger(__name__) def build_train_valid_test_datasets( data_prefix, @@ -490,7 +489,7 @@ def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sam # TODO @thomasw21 handle the case where a single sample cannot fit inside a row. We can # - silently skip that value [currently implemented] # - truncate to `seq_length`, and keep the right part - logger.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}") + logging.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}") current_sample_start = current_sample_end + 1 # skipping row_length = 0 continue From 9c66466ed9826f0970960a807331bfb9e86a411e Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 13:41:20 +0200 Subject: [PATCH 06/22] Add PAD --- megatron/tokenizer/tokenizer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 87aeb512f4..3dc17719ea 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -326,6 +326,8 @@ def __init__(self, tokenizer_file, special_tokens=None): self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file, errors='replace', max_len=None) self.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens}) self.eod_id = self.tokenizer.vocab[EOD] + if FIM_PAD in self.tokenizer.vocab: + self.pad_id = self.tokenizer.vocab[FIM_PAD] # Token->id mapping for additional special-tokens self.special_tokens = { tok: self.tokenizer.vocab[tok] for tok in special_tokens @@ -353,3 +355,10 @@ def detokenize(self, token_ids): @property def eod(self): return self.eod_id + + @property + def pad(self): + if hasattr(self, 'pad_id'): + return self.pad_id + else: + raise ValueError('PAD token not found in the vocabulary') From 65c81aecbe06587154a6cf321005b229a766a483 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 13:47:49 +0200 Subject: [PATCH 07/22] eos > eod --- finetune_mtf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune_mtf.py b/finetune_mtf.py index 6492185d8d..b015288653 100644 --- a/finetune_mtf.py +++ b/finetune_mtf.py @@ -159,7 +159,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): splits_string=args.split, seq_length=args.seq_length + 1, pad_token=tokenizer.pad, - eos_token=tokenizer.eos, + eos_token=tokenizer.eod, train_valid_test_num_samples=train_val_test_num_samples, seed=args.seed, skip_warmup=(not args.mmap_warmup) @@ -192,7 +192,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): train_valid_test_num_samples=train_val_test_num_samples, seq_length=args.seq_length + 1, pad_token=tokenizer.pad, - eos_token=tokenizer.eos, + 
eos_token=tokenizer.eod, seed=args.seed, skip_warmup=(not args.mmap_warmup), train_valid_test=s From fb234a3c03292b27223d0638d56b344c5ba37c9d Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 13:52:11 +0200 Subject: [PATCH 08/22] Add size --- megatron/data/indexed_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 2f6e1b845c..4e73896810 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -520,6 +520,9 @@ def get(self, idx, offset=0, length=None): def sizes(self): return self._index.sizes + def size(self, index): + return self._index.sizes[index] + @property def doc_idx(self): return self._index.doc_idx From 9dde9db5a8272b6451db8d9795db924273125a51 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 14:00:04 +0200 Subject: [PATCH 09/22] Fix args --- finetune_mtf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune_mtf.py b/finetune_mtf.py index b015288653..d0f56ab14f 100644 --- a/finetune_mtf.py +++ b/finetune_mtf.py @@ -76,8 +76,8 @@ def get_batch(data): args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss, - prefix_indices=None, - loss_on_targets_only=False # This is done below + #prefix_indices=None, + #loss_on_targets_only=False # This is done below ) # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:] From cbd0beb48c25296c99f7c7ed3364c39bcb103b93 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 14:07:04 +0200 Subject: [PATCH 10/22] Make causal --- finetune_mtf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/finetune_mtf.py b/finetune_mtf.py index d0f56ab14f..347d06c3bd 100644 --- a/finetune_mtf.py +++ b/finetune_mtf.py @@ -85,8 +85,7 @@ def get_batch(data): loss_mask *= loss_on_targets_only * loss_on_non_pad_only attention_mask = get_packed_attention_mask( - # Run non-causal decoder - is_causal=not(args.prefixlm), + is_causal=True, # Always make it causal for now; Could ablate this causal_mask=~(causal_mask.bool()), # Turn back into tril being ones decoder_is_inputs=decoder_is_inputs.bool(), segment_ids=segment_ids.long(), From 2d8e41460d4d8b569c79bf09cd271f448c1a4417 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 14:13:14 +0200 Subject: [PATCH 11/22] Mov pos to dev --- finetune_mtf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetune_mtf.py b/finetune_mtf.py index 347d06c3bd..89600e780b 100644 --- a/finetune_mtf.py +++ b/finetune_mtf.py @@ -105,7 +105,7 @@ def get_batch(data): counts = torch.unique_consecutive(b, return_counts=True, dim=-1)[1] p = torch.cat([torch.arange(c) for c in counts]) position_ids.append(p) - position_ids = torch.stack(position_ids) + position_ids = torch.stack(position_ids).to(tokens.device) #if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: From bd2205731b3a148eb0aa139959a2420719644790 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 14:49:33 +0200 Subject: [PATCH 12/22] Packed MTF mask --- finetune_mtf.py | 5 +-- .../fused_kernels/scaled_masked_softmax.h | 36 +++++++++++++++---- megatron/model/gpt_model.py | 15 ++++---- megatron/model/utils.py | 2 +- 4 files changed, 43 insertions(+), 15 deletions(-) diff --git a/finetune_mtf.py b/finetune_mtf.py index 89600e780b..921e89affb 100644 --- a/finetune_mtf.py +++ b/finetune_mtf.py 
@@ -9,7 +9,7 @@ from megatron import get_args, get_tokenizer, print_rank_0, mpu from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group -from megatron.model.enums import PositionEmbeddingType +from megatron.model.enums import PositionEmbeddingType, AttnMaskType #from megatron.model import GPTModelPipe from megatron.model import GPTModel, ModelType from megatron.training import pretrain @@ -28,7 +28,8 @@ def model_provider(pre_process=True, post_process=True): num_tokentypes=0, parallel_output=True, pre_process=pre_process, - post_process=post_process + post_process=post_process, + attn_mask_type=AttnMaskType.custom, ) return model diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h index f9ca0bbc7e..8abca7e90d 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.h +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -47,6 +47,22 @@ __device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t * template <> __device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; @@ -269,7 +285,7 @@ __global__ void scaled_masked_softmax_warp_forward( if (temp_mask[element] != 1) { elements[i][it + element] = (acc_t)temp_data[element] * scale; } else { - elements[i][it + element] = -10000.0; + elements[i][it + element] = -std::numeric_limits::infinity(); } } } else { @@ -298,7 +314,11 @@ __global__ void scaled_masked_softmax_warp_forward( for (int i = 0; i < WARP_BATCH; ++i) { #pragma unroll for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); + if (elements[i][it] <= -std::numeric_limits::infinity()) { + elements[i][it] = 0.0f; + } else { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + } sum[i] += elements[i][it]; } } @@ -314,11 +334,15 @@ __global__ void scaled_masked_softmax_warp_forward( for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; + if (sum[i] == 0.0f) { + copy_zero_vector(dst + i * element_count + it * WARP_SIZE); + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); } else { break; } diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index b6a1d7b5e9..86899da7b4 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -58,11 +58,14 @@ def post_language_model_processing(lm_output, labels, logit_weights, class 
GPTModel(MegatronModule): """GPT-2 Language model.""" - def __init__(self, - num_tokentypes=0, - parallel_output=True, - pre_process=True, - post_process=True): + def __init__( + self, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True, + attn_mask_type: AttnMaskType = AttnMaskType.causal, + ): super(GPTModel, self).__init__() args = get_args() @@ -74,7 +77,7 @@ def __init__(self, self.language_model, self._language_model_key = get_language_model( num_tokentypes=num_tokentypes, add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, + encoder_attn_mask_type=attn_mask_type, init_method=init_method_normal(args.init_method_std), scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), diff --git a/megatron/model/utils.py b/megatron/model/utils.py index f26b068534..1b85d12833 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -40,7 +40,7 @@ def init_(tensor): def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min) return attention_scores From 168e314283d5bb82715ae541b86b526416d169e0 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 15 Apr 2023 20:59:16 +0200 Subject: [PATCH 13/22] Add reset progress --- megatron/arguments.py | 3 ++- megatron/checkpointing.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 4e73057509..ceef583a05 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -774,7 +774,8 @@ def _add_checkpointing_args(parser): group.add_argument('--finetune-from', type=str, default=None, help='Directory containing a model checkpoint for finetuning.' 'Will be loaded if the `--load` directory contains no checkpoint') - + group.add_argument('--reset-progress', action='store_true', default=None, + help='Reset iteration to 0 & do not load args.') return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index c3359ed18c..f57e3970b8 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -598,7 +598,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri # Check arguments. assert args.consumed_train_samples == 0 assert args.consumed_valid_samples == 0 - if 'args' in model_state_dict: + if 'args' in model_state_dict and not args.reset_progress: checkpoint_args = model_state_dict['args'] check_checkpoint_args(checkpoint_args) if not args.finetune: From 39915200ca567b3197a16cb99fc7d2dc3247de8b Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sun, 16 Apr 2023 22:54:22 +0200 Subject: [PATCH 14/22] Copy tensors after loading ckpt --- megatron/checkpointing.py | 2 +- megatron/optimizer/optimizer.py | 3 +++ megatron/training.py | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index f57e3970b8..59fd715979 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -581,7 +581,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri set_checkpoint_version(model_state_dict.get('checkpoint_version', 0)) # Set iteration. 
- if args.finetune or release: + if args.finetune or release or args.reset_progress: iteration = 0 else: try: diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index efa1bd36f8..f075ae2cb5 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -538,6 +538,9 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, use_contiguous_buffers_in_local_ddp, fp16, bf16, grad_scaler, models) + + def init_param_groups(self): + # ====================== # main parameter stuff # ====================== diff --git a/megatron/training.py b/megatron/training.py index 468a600291..5e621e7916 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -394,6 +394,10 @@ def setup_model_and_optimizer(model_provider_func, else: args.iteration = 0 + # Init param groups + if hasattr(optimizer, 'init_param_groups'): + optimizer.init_param_groups() + # We only support local DDP with multiple micro-batches. if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1: assert args.DDP_impl == 'local' From ad603d615447994d99179e798e590e66509fd47a Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Mon, 17 Apr 2023 10:48:18 +0200 Subject: [PATCH 15/22] Fix spec toks; Opt ckpt loading; newline --- megatron/optimizer/optimizer.py | 2 ++ megatron/tokenizer/tokenizer.py | 8 ++++---- megatron/utils.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index f075ae2cb5..2ad3c12f76 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -699,6 +699,8 @@ def state_dict(self): def load_state_dict(self, state_dict): + if not hasattr(self, 'fp32_from_float16_groups'): + self.init_param_groups() # Optimizer. 
optimizer_key = 'optimizer' if optimizer_key not in state_dict: diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 3dc17719ea..cd3343a7fc 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -23,10 +23,10 @@ from .gpt2_tokenization import GPT2Tokenizer -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_SUFFIX = "" -FIM_PAD = "" +FIM_PREFIX = "" +FIM_MIDDLE = "" +FIM_SUFFIX = "" +FIM_PAD = "" EOD = "<|endoftext|>" diff --git a/megatron/utils.py b/megatron/utils.py index 6c9d3452c5..773a275bb7 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -352,4 +352,4 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode attention_mask = causal_inputs_mask * padding_mask * segment_mask # True for places we do not want to attend to - return ~attention_mask \ No newline at end of file + return ~attention_mask From 14bc3e5f1ff488a4b0d8c9bd4efefdf53bac0121 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Wed, 19 Apr 2023 11:24:51 +0200 Subject: [PATCH 16/22] Fix tok --- finetune_mtf.py | 29 +++++++++++++++++++++++++++-- megatron/tokenizer/tokenizer.py | 10 +++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/finetune_mtf.py b/finetune_mtf.py index 921e89affb..43370fc390 100644 --- a/finetune_mtf.py +++ b/finetune_mtf.py @@ -19,6 +19,25 @@ #import deepspeed #from deepspeed.runtime.utils import see_memory_usage +### Debugging Helpers ### + +def visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids): + tok = get_tokenizer() + print("TOKENS:", ",".join([tok.detokenize(tokens[0, i]) for i in range(100)])) + print("ATTN:", attention_mask[0, :, :100, :100]) + print("LABS:", labels[0, :100]) + print("LOSSMSK:", loss_mask[:100]) + print("POSIDS:", position_ids[0, :100]) + +def save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids): + """Save as tensors for debugging""" + torch.save(tokens, "tokens.pt") + torch.save(attention_mask, "attention_mask.pt") + torch.save(labels, "labels.pt") + torch.save(loss_mask, "loss_mask.pt") + torch.save(position_ids, "position_ids.pt") + torch.save(segment_ids, "segment_ids.pt") + exit() def model_provider(pre_process=True, post_process=True): """Build the model.""" @@ -81,9 +100,12 @@ def get_batch(data): #loss_on_targets_only=False # This is done below ) # Only compute loss over causal target tokens, i.e. 
ignore input_tokens & padding - loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:] loss_on_non_pad_only = (labels != tokenizer.pad) - loss_mask *= loss_on_targets_only * loss_on_non_pad_only + if args.loss_on_targets_only: + loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:] + loss_mask *= loss_on_targets_only * loss_on_non_pad_only + else: + loss_mask *= loss_on_non_pad_only attention_mask = get_packed_attention_mask( is_causal=True, # Always make it causal for now; Could ablate this @@ -112,6 +134,9 @@ def get_batch(data): #if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]: # raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.") + # visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids) + # save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids) + return tokens, labels, loss_mask, attention_mask, position_ids #return (tokens, position_ids, attention_mask), (labels, loss_mask) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index cd3343a7fc..4e92596ab3 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -26,7 +26,9 @@ FIM_PREFIX = "" FIM_MIDDLE = "" FIM_SUFFIX = "" -FIM_PAD = "" +# SantaCoder & BigCode discrepancy +FIM_PAD_SC = "" +FIM_PAD_BC = "" EOD = "<|endoftext|>" @@ -326,8 +328,10 @@ def __init__(self, tokenizer_file, special_tokens=None): self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file, errors='replace', max_len=None) self.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens}) self.eod_id = self.tokenizer.vocab[EOD] - if FIM_PAD in self.tokenizer.vocab: - self.pad_id = self.tokenizer.vocab[FIM_PAD] + if FIM_PAD_SC in self.tokenizer.vocab: + self.pad_id = self.tokenizer.vocab[FIM_PAD_SC] + elif FIM_PAD_BC in self.tokenizer.vocab: + self.pad_id = self.tokenizer.vocab[FIM_PAD_BC] # Token->id mapping for additional special-tokens self.special_tokens = { tok: self.tokenizer.vocab[tok] for tok in special_tokens From 5dc080fd12db1f6c91b88cf6b29ef312ff4b456e Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Thu, 20 Apr 2023 08:06:45 +0200 Subject: [PATCH 17/22] Fix var --- megatron/tokenizer/tokenizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 4e92596ab3..1ff3a1b065 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -56,13 +56,13 @@ def build_tokenizer(args): tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) elif args.tokenizer_type == 'GPT2BPETokenizerWithFIM': assert args.merge_file is not None - tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=[FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]) + tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=[FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD_BC]) elif args.tokenizer_type == "TokenizerFromFile": assert args.tokenizer_file is not None tokenizer = _HFTokenizer(args.tokenizer_file, special_tokens=[EOD]) elif args.tokenizer_type == "TokenizerFromFileWithFIM": assert args.tokenizer_file is not None - tokenizer = _HFTokenizer(args.tokenizer_file, special_tokens=[EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]) + tokenizer = _HFTokenizer(args.tokenizer_file, special_tokens=[EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD_BC]) else: raise NotImplementedError('{} 

From b4609a79f6532f30194862923eb0f3ed33739160 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Thu, 20 Apr 2023 18:01:58 +0200
Subject: [PATCH 18/22] Clean init of optimizer when not loaded

---
 megatron/optimizer/optimizer.py |  5 -----
 megatron/training.py            | 40 ++++++++++++++++++++++++---------
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 2ad3c12f76..efa1bd36f8 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -538,9 +538,6 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                          params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                          fp16, bf16, grad_scaler, models)
-
-    def init_param_groups(self):
-
         # ======================
         # main parameter stuff
         # ======================
@@ -699,8 +696,6 @@ def state_dict(self):
 
     def load_state_dict(self, state_dict):
-        if not hasattr(self, 'fp32_from_float16_groups'):
-            self.init_param_groups()
         # Optimizer.
         optimizer_key = 'optimizer'
         if optimizer_key not in state_dict:
diff --git a/megatron/training.py b/megatron/training.py
index 5e621e7916..1692e0b27f 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -153,6 +153,11 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('done with setup ...')
     timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
     print_rank_0('training ...')
+
+    print(f"{len(model)}")
+    print(f"{len(model[0])}")
+    print(f"{model[0].keys()}")
+    print(f"{model[0]['language_model']['embedding']['word_embeddings']['weight'][0,:10]}")
 
     iteration = 0
     if args.do_train and args.train_iters > 0:
@@ -377,26 +382,39 @@ def setup_model_and_optimizer(model_provider_func,
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
 
-    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
-                                       scale_lr_cond, lr_mult)
-    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
-
-    if args.load is not None:
+    if (args.no_load_optim) and (args.load is not None):
+        # Load checkpoint first to copy over correct model params in the init of mix. prec. optimizers
         timers = get_timers()
         # Extra barrier is added to make sure all ranks report the
         # max time.
         torch.distributed.barrier()
         timers('load-checkpoint').start()
-        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
+        # Optimizer is not loaded hence not needed
+        args.iteration = load_checkpoint(unwrapped_model, None, None)
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
+        optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
+                                           scale_lr_cond, lr_mult)
+        opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
     else:
-        args.iteration = 0
-
-    # Init param groups
-    if hasattr(optimizer, 'init_param_groups'):
-        optimizer.init_param_groups()
+        optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
+                                           scale_lr_cond, lr_mult)
+        opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
+
+        if args.load is not None:
+            # In these cases the Optimizer is not loaded
+            timers = get_timers()
+            # Extra barrier is added to make sure all ranks report the
+            # max time.
+            torch.distributed.barrier()
+            timers('load-checkpoint').start()
+            args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
+            torch.distributed.barrier()
+            timers('load-checkpoint').stop()
+            timers.log(['load-checkpoint'])
+        else:
+            args.iteration = 0
 
     # We only support local DDP with multiple micro-batches.
     if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:

From a828682e89d957f28c114c2f9e5ae03832205422 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Thu, 20 Apr 2023 18:13:02 +0200
Subject: [PATCH 19/22] Rmv debug

---
 megatron/training.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 1692e0b27f..1525aa43d6 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -153,12 +153,6 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('done with setup ...')
     timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
     print_rank_0('training ...')
-
-    print(f"{len(model)}")
-    print(f"{len(model[0])}")
-    print(f"{model[0].keys()}")
-    print(f"{model[0]['language_model']['embedding']['word_embeddings']['weight'][0,:10]}")
-
     iteration = 0
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,

From 57184fce1cc57430c143b57638d739d8f2daf720 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Thu, 20 Apr 2023 22:06:15 +0200
Subject: [PATCH 20/22] Debug

---
 finetune_mtf.py      | 27 +++++++++++++++------------
 megatron/training.py |  2 +-
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/finetune_mtf.py b/finetune_mtf.py
index 43370fc390..bcb4efb029 100644
--- a/finetune_mtf.py
+++ b/finetune_mtf.py
@@ -29,15 +29,15 @@ def visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_i
     print("LOSSMSK:", loss_mask[:100])
     print("POSIDS:", position_ids[0, :100])
 
-def save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids):
+def save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids, iteration):
     """Save as tensors for debugging"""
-    torch.save(tokens, "tokens.pt")
-    torch.save(attention_mask, "attention_mask.pt")
-    torch.save(labels, "labels.pt")
-    torch.save(loss_mask, "loss_mask.pt")
-    torch.save(position_ids, "position_ids.pt")
-    torch.save(segment_ids, "segment_ids.pt")
-    exit()
+    torch.save(tokens, f"tokens_{iteration}.pt")
+    torch.save(attention_mask, f"attention_mask_{iteration}.pt")
+    torch.save(labels, f"labels_{iteration}.pt")
+    torch.save(loss_mask, f"loss_mask_{iteration}.pt")
+    torch.save(position_ids, f"position_ids_{iteration}.pt")
+    torch.save(segment_ids, f"segment_ids_{iteration}.pt")
+    #exit()
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -48,7 +48,8 @@ def model_provider(pre_process=True, post_process=True):
         parallel_output=True,
         pre_process=pre_process,
         post_process=post_process,
-        attn_mask_type=AttnMaskType.custom,
+        #attn_mask_type=AttnMaskType.custom,
+        attn_mask_type=AttnMaskType.causal,
     )
     return model
 
@@ -133,9 +134,11 @@ def get_batch(data):
 
     #if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
     #    raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
-
-    # visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids)
-    # save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids)
+
+    #if (7140 < args.curr_iteration < 7150) or (6420 < args.curr_iteration < 6430):
+    #visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids)
+    #if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
+    #    save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids, args.curr_iteration)
 
     return tokens, labels, loss_mask, attention_mask, position_ids
     #return (tokens, position_ids, attention_mask), (labels, loss_mask)
diff --git a/megatron/training.py b/megatron/training.py
index 1525aa43d6..64468498fb 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -384,7 +384,7 @@ def setup_model_and_optimizer(model_provider_func,
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         # Optimizer is not loaded hence not needed
-        args.iteration = load_checkpoint(unwrapped_model, None, None)
+        args.iteration = load_checkpoint(model, None, None)
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])

From 6c80f7979944ee34534c93b1cb24edfb0230d86c Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Thu, 20 Apr 2023 22:12:42 +0200
Subject: [PATCH 21/22] Clean comments

---
 finetune_mtf.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/finetune_mtf.py b/finetune_mtf.py
index bcb4efb029..e33bb684e9 100644
--- a/finetune_mtf.py
+++ b/finetune_mtf.py
@@ -10,15 +10,11 @@
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
 from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.model.enums import PositionEmbeddingType, AttnMaskType
-#from megatron.model import GPTModelPipe
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids, get_packed_attention_mask
 from megatron.utils import average_losses_across_data_parallel_group
 
-#import deepspeed
-#from deepspeed.runtime.utils import see_memory_usage
-
 ### Debugging Helpers ###
 
 def visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids):
@@ -37,7 +33,7 @@ def save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, s
     torch.save(loss_mask, f"loss_mask_{iteration}.pt")
     torch.save(position_ids, f"position_ids_{iteration}.pt")
     torch.save(segment_ids, f"segment_ids_{iteration}.pt")
-    #exit()
+    # exit() # Optionally exit right after
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -97,8 +93,6 @@ def get_batch(data):
         args.reset_position_ids,
         args.reset_attention_mask,
         args.eod_mask_loss,
-        #prefix_indices=None,
-        #loss_on_targets_only=False # This is done below
     )
     # Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
     loss_on_non_pad_only = (labels != tokenizer.pad)
@@ -119,6 +113,7 @@ def get_batch(data):
         loss_mask = loss_mask.view(-1)
         loss_mask = fast_normalize(loss_mask)
 
+    # For Alibi / Rotary, position ids are not used so it does not matter
     if args.position_embedding_type == PositionEmbeddingType.absolute:
         # Create position ids from segment_ids
         # segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 2, 0]]) (Shape: (batch_size, seq_len))
@@ -129,13 +124,8 @@ def get_batch(data):
             counts = torch.unique_consecutive(b, return_counts=True, dim=-1)[1]
             p = torch.cat([torch.arange(c) for c in counts])
             position_ids.append(p)
-        position_ids = torch.stack(position_ids).to(tokens.device) 
-
+        position_ids = torch.stack(position_ids).to(tokens.device)
-    #if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
-    #    raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
-
-    #if (7140 < args.curr_iteration < 7150) or (6420 < args.curr_iteration < 6430):
     #visualize_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids)
     #if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
     #    save_model_inputs(tokens, attention_mask, labels, loss_mask, position_ids, segment_ids, args.curr_iteration)
 
     return tokens, labels, loss_mask, attention_mask, position_ids
     #return (tokens, position_ids, attention_mask), (labels, loss_mask)
@@ -192,7 +182,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
             seed=args.seed,
             skip_warmup=(not args.mmap_warmup)
         )
-    # Option 2 of data loading using --(train|valid|test)-weighted-split-paths 
+    # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
     elif args.train_weighted_split_paths:
         assigned_train_valid_test = []
         if args.train_weighted_split_paths is not None:

From 9192641368a91efe8ac61b13d31665470f85c917 Mon Sep 17 00:00:00 2001
From: Muennighoff
Date: Thu, 20 Apr 2023 22:54:04 +0200
Subject: [PATCH 22/22] Attn

---
 finetune_mtf.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/finetune_mtf.py b/finetune_mtf.py
index e33bb684e9..cade412e19 100644
--- a/finetune_mtf.py
+++ b/finetune_mtf.py
@@ -44,8 +44,7 @@ def model_provider(pre_process=True, post_process=True):
         parallel_output=True,
         pre_process=pre_process,
         post_process=post_process,
-        #attn_mask_type=AttnMaskType.custom,
-        attn_mask_type=AttnMaskType.causal,
+        attn_mask_type=AttnMaskType.custom,
    )
     return model
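
Note on the logic that the last few patches keep toggling (custom vs. causal attention masks, loss_on_targets_only, fast_normalize): the sketch below is a minimal, self-contained re-implementation of packed-sequence masking and target-loss normalization, useful only for sanity-checking the idea outside Megatron. It is NOT the code in megatron/utils.py or finetune_mtf.py; the function names are invented for illustration, it ignores the decoder_is_inputs / prefix-LM branch (the patches above end up running fully causal anyway), and it assumes padding is marked with segment id 0.

import torch

def packed_attention_mask(segment_ids: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
    """Return a [batch, 1, seq, seq] bool mask where True marks positions to mask out."""
    _, seq_len = segment_ids.shape
    # Tokens may only attend within their own packed example.
    segment_mask = segment_ids.unsqueeze(-1) == segment_ids.unsqueeze(-2)
    # Padding (segment id 0) neither attends nor is attended to.
    not_pad = segment_ids != 0
    padding_mask = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2)
    allowed = segment_mask & padding_mask
    if is_causal:
        # Lower-triangular causal constraint, shared across the batch.
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=segment_ids.device))
        allowed = allowed & causal
    return ~allowed.unsqueeze(1)

def normalize_target_loss_mask(loss_mask: torch.Tensor) -> torch.Tensor:
    """Give every consecutive run of target tokens a total weight of 1 (cf. fast_normalize)."""
    _, inverse, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
    return loss_mask / counts[inverse]

if __name__ == "__main__":
    segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 2, 0]])
    print(packed_attention_mask(segment_ids).shape)   # torch.Size([1, 1, 8, 8])
    loss_mask = torch.tensor([0., 0., 0., 1., 1., 0., 0., 1.])
    print(normalize_target_loss_mask(loss_mask))      # -> 0, 0, 0, 0.5, 0.5, 0, 0, 1

With this toy version it is easy to check, for instance, that tokens in segment 2 never attend to segment 1 and that each target span contributes equally to the loss regardless of its length, which is the behaviour the norm_target_loss / fast_normalize path in the patches is meant to provide.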