From 8c86034b6918636681235cd924b0f9efb3031e76 Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Fri, 19 May 2023 14:05:03 -0700 Subject: [PATCH 01/15] Add option to specify a data cache path separate from data directory. Switches the cache to using md5 hashes of a text description instead of crafted filenames to determine a "cache hit". Changes the default location of these files to be an "index-cache" directory inside the data root. Should leave the data directories a bit cleaner, especially with these filenames being a bit "uglier". For GPT the code will first look in this default location before building a new index and caching it the specified data cache path (or this default if none is given). For Blendable dataset it will only look for and save the indices if a data cache path is provided, otherwise it will just rebuild every time. --- megatron/arguments.py | 2 + megatron/data/blendable_dataset.py | 77 ++++++-- megatron/data/gpt_dataset.py | 285 +++++++++++++++++------------ pretrain_gpt.py | 3 +- 4 files changed, 235 insertions(+), 132 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index a6e81b3e0aa..d755fe3e5d8 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1030,6 +1030,8 @@ def _add_data_args(parser): '1) a single data path, 2) multiple datasets in the' 'form: dataset1-weight dataset1-path dataset2-weight ' 'dataset2-path ...') + group.add_argument('--data-cache-path', default=None, + help='Path to a directory to hold cached index files.') group.add_argument('--vocab-size', type=int, default=None, help='Size of vocab before EOD or padding.') diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py index 453b362f3ec..61a00039bbb 100644 --- a/megatron/data/blendable_dataset.py +++ b/megatron/data/blendable_dataset.py @@ -2,17 +2,21 @@ """Blendable dataset.""" +import hashlib +import os import time import numpy as np import torch from megatron import print_rank_0 +from megatron.core import mpu class BlendableDataset(torch.utils.data.Dataset): - def __init__(self, datasets, weights, size): + def __init__(self, datasets, weights, size, *, + data_cache_path=None): self.datasets = datasets num_datasets = len(datasets) @@ -27,18 +31,65 @@ def __init__(self, datasets, weights, size): weights /= sum_weights # Build indicies. - start_time = time.time() - assert num_datasets < 255 - self.dataset_index = np.zeros(self.size, dtype=np.uint8) - self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) - - from megatron.data import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, - torch.distributed.get_rank() == 0) - print_rank_0('> elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + def _build_indices(): + start_time = time.time() + assert num_datasets < 255 + dataset_index = np.zeros(self.size, dtype=np.uint8) + dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from megatron.data import helpers + helpers.build_blending_indices(dataset_index, dataset_sample_index, + weights, num_datasets, self.size, + torch.distributed.get_rank() == 0) + print_rank_0('> elapsed time for building blendable dataset indices: ' + '{:.2f} (sec)'.format(time.time() - start_time)) + return dataset_index, dataset_sample_index + + desc = "Blendable dataset\n\n" + desc += "Datasets:\n" + for dataset in datasets: + desc += dataset.desc + "\n\n" + desc += f"Weights: {weights}\n" + desc += f"Size: {size}\n" + self.desc = desc + + if data_cache_path: + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_path = os.path.join(data_cache_path, desc_hash + ".dsc") + index_path = os.path.join(data_cache_path, desc_hash + "_index.npy") + sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy") + cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path) + if torch.distributed.get_rank() == 0 and not cache_hit: + print(' > WARNING: could not find index map files for blendable' + ' dataset, building indices on rank 0 ...', flush=True) + dataset_index, dataset_sample_index = _build_indices() + os.makedirs(os.path.dirname(index_path), exist_ok=True) + with open(desc_path, 'wt') as fd: + fd.write(desc) + np.save(index_path, dataset_index, allow_pickle=True) + np.save(sample_index_path, dataset_sample_index, + allow_pickle=True) + + # This should be a barrier but nccl barrier assumes device_index=rank which is not the + # case for model parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load on all ranks. + print_rank_0(f'> loading blendable dataset index: {index_path}') + self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r') + assert self.dataset_index.size == self.size + + print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}') + self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r') + assert self.dataset_sample_index.size == self.size + else: + self.dataset_index, self.dataset_sample_index = _build_indices() + # Check size _ = self.__getitem__(self.size - 1) diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 3e4651c8837..cda6060b160 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -2,6 +2,7 @@ """GPT style dataset.""" +import hashlib import os import time @@ -22,7 +23,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, train_data_prefix=None, valid_data_prefix=None, test_data_prefix=None, - return_doc_ids=False): + return_doc_ids=False, *, + data_cache_path=None): """Build train, valid, and test datasets.""" if data_prefix: @@ -33,7 +35,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, return _build_train_valid_test_datasets(data_prefix[0], data_impl, splits_string, train_valid_test_num_samples, - seq_length, seed, skip_warmup) + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) # Blending dataset. # Parse the values. @@ -54,7 +57,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, prefixes[i], data_impl, splits_string, datasets_train_valid_test_num_samples[i], seq_length, seed, skip_warmup, - return_doc_ids) + return_doc_ids, + data_cache_path=data_cache_path) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -65,13 +69,16 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, # Blend. blending_train_dataset = None if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples) + blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples, + data_cache_path=data_cache_path) blending_valid_dataset = None if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples) + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples, + data_cache_path=data_cache_path) blending_test_dataset = None if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples) + blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples, + data_cache_path=data_cache_path) return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) @@ -84,17 +91,21 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if train_data_prefix is not None: train_dataset = build_dataset("train", train_data_prefix, data_impl, train_valid_test_num_samples[0], - seq_length, seed, skip_warmup) + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) if valid_data_prefix is not None: valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, train_valid_test_num_samples[1], - seq_length, seed, False) + seq_length, seed, False, + data_cache_path=data_cache_path) + if test_data_prefix is not None: test_dataset = build_dataset("test", test_data_prefix, data_impl, train_valid_test_num_samples[2], - seq_length, seed, False) + seq_length, seed, False, + data_cache_path=data_cache_path) return (train_dataset, valid_dataset, test_dataset) @@ -102,7 +113,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup, - return_doc_ids=False): + return_doc_ids=False, *, + data_cache_path=None): """Build train, valid, and test datasets.""" # Indexed dataset. @@ -134,7 +146,8 @@ def build_dataset(index, name): documents, indexed_dataset, train_valid_test_num_samples[index], seq_length, seed, - return_doc_ids) + return_doc_ids, + data_cache_path=data_cache_path) return dataset train_dataset = build_dataset(0, 'train') @@ -145,13 +158,15 @@ def build_dataset(index, name): def build_dataset(dataset_name, data_prefix, data_impl, num_samples, - seq_length, seed, skip_warmup): + seq_length, seed, skip_warmup, *, + data_cache_path=None): dataset = None if len(data_prefix) == 1: dataset = _build_dataset(dataset_name, - data_prefix[0], data_impl, - num_samples, seq_length, - seed, skip_warmup) + data_prefix[0], data_impl, + num_samples, seq_length, + seed, skip_warmup, + data_cache_path=data_cache_path) else: # Blending dataset. # Parse the values. @@ -163,19 +178,22 @@ def build_dataset(dataset_name, data_prefix, data_impl, num_samples, datasets = [] for i in range(len(prefixes)): ds = _build_dataset(dataset_name, prefixes[i], - data_impl, dataset_num_samples[i], - seq_length, seed, skip_warmup) + data_impl, dataset_num_samples[i], + seq_length, seed, skip_warmup, + data_cache_path=data_cache_path) if ds: datasets.append(ds) if datasets: - dataset = BlendableDataset(datasets, weights, num_samples) + dataset = BlendableDataset(datasets, weights, num_samples, + data_cache_path=data_cache_path) return dataset def _build_dataset(dataset_name, data_prefix, data_impl, - num_samples, seq_length, seed, skip_warmup): + num_samples, seq_length, seed, skip_warmup, *, + data_cache_path=None): """ Build dataset. This method is called when individual train, valid, test datasets are provided @@ -196,8 +214,9 @@ def _build_dataset(dataset_name, data_prefix, data_impl, step=1, dtype=np.int32) dataset = GPTDataset(dataset_name, data_prefix, - documents, indexed_dataset, - num_samples, seq_length, seed) + documents, indexed_dataset, + num_samples, seq_length, seed, + data_cache_path=data_cache_path) return dataset @@ -220,9 +239,10 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): class GPTDataset(torch.utils.data.Dataset): - def __init__(self, name, data_prefix, documents, indexed_dataset, - num_samples, seq_length, seed, - return_doc_ids=False): + def __init__(self, name, data_prefix, documents, + indexed_dataset, num_samples, seq_length, seed, + return_doc_ids=False, *, + data_cache_path=None): self.name = name self.indexed_dataset = indexed_dataset @@ -233,10 +253,11 @@ def __init__(self, name, data_prefix, documents, indexed_dataset, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \ + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc = \ _build_index_mappings(self.name, data_prefix, documents, self.indexed_dataset.sizes, - num_samples, seq_length, seed) + num_samples, seq_length, seed, + data_cache_path=data_cache_path) def __len__(self): @@ -283,7 +304,8 @@ def __getitem__(self, idx): def _build_index_mappings(name, data_prefix, documents, sizes, - num_samples, seq_length, seed): + num_samples, seq_length, seed, *, + data_cache_path): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. sample-idx: is the start document index and document offset for each @@ -298,94 +320,121 @@ def _build_index_mappings(name, data_prefix, documents, sizes, np_rng = np.random.RandomState(seed=seed) # Filename of the index mappings. - index_prefix = '{}_indexmap'.format(name) - index_prefix += '_{}ns'.format(num_samples) - index_prefix += '_{}sl'.format(seq_length) - index_prefix += '_{}s'.format(seed) - _filename = data_prefix + '_' + index_prefix - doc_idx_filename = _filename + '_doc_idx.npy' - sample_idx_filename = _filename + '_sample_idx.npy' - shuffle_idx_filename = _filename + '_shuffle_idx.npy' + desc = "GPT Dataset\n\n" + desc += f"Data prefix {data_prefix}\n" + desc += f"Dataset name {name}\n" + desc += f"Number of samples {num_samples}\n" + desc += f"Sequence length {seq_length}\n" + desc += f"Random seed {seed}\n" + desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() + desc_filename = desc_hash + ".dsc" + doc_idx_filename = desc_hash + '_doc_idx.npy' + sample_idx_filename = desc_hash + '_sample_idx.npy' + shuffle_idx_filename = desc_hash + '_shuffle_idx.npy' + + # Look for cache in main data dir first to avoid unnecessary + # duplication, then look in data-cache-path if specified, + # If nothing is found, use the last path looked in + build_indices = True + prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')] + if data_cache_path is not None: + prefixes.append(data_cache_path) + for prefix in prefixes: + idx_path = { + 'desc': os.path.join(prefix, desc_filename), + 'doc': os.path.join(prefix, doc_idx_filename), + 'sample': os.path.join(prefix, sample_idx_filename), + 'shuffle': os.path.join(prefix, shuffle_idx_filename) + } + for f in idx_path.values(): + if not os.path.isfile(f): + break + else: + # Found our files! + build_indices = False + break # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0: - if (not os.path.isfile(doc_idx_filename)) or \ - (not os.path.isfile(sample_idx_filename)) or \ - (not os.path.isfile(shuffle_idx_filename)): - - print_rank_0(' > WARNING: could not find index map files, building ' - 'the indices on rank 0 ...') + if build_indices and torch.distributed.get_rank() == 0: + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') - # For the last epoch, decide whether include the entire epoch - # in the global shuffle or not. + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. - # If we need only one epoch, then separating last epoch does - # not mean anything. - if num_epochs == 1: - separate_last_epoch = False - print(' > only one epoch required, setting ' - 'separate_last_epoch to False', flush=True) + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' + 'separate_last_epoch to False', flush=True) - else: - # Get the number of samples for the last epoch - num_samples_from_epochs_minus_one = ( - (num_epochs - 1) * tokens_per_epoch - 1) // seq_length - last_epoch_num_samples = num_samples - \ - num_samples_from_epochs_minus_one - assert last_epoch_num_samples >= 0, \ - 'last epoch number of samples should be non-negative.' - num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length - assert last_epoch_num_samples < (num_samples_per_epoch + 1), \ - 'last epoch number of samples exceeded max value.' - # If we have less than 80% of the samples for the last epoch, - # seperate out the epoch and treat it differently. - # Note: the 80% number is just based on common sense and can - # be adjusted if needed. - separate_last_epoch = (last_epoch_num_samples < - int(0.80 * num_samples_per_epoch)) - if separate_last_epoch: - string = ' > last epoch number of samples ({}) is smaller '\ - 'than 80% of number of samples per epoch ({}), '\ - 'setting separate_last_epoch to True' - else: - string = ' > last epoch number of samples ({}) is larger '\ - 'than 80% of number of samples per epoch ({}), '\ - 'setting separate_last_epoch to False' - print(string.format(last_epoch_num_samples, - num_samples_per_epoch), flush=True) - - # doc-idx. - start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng, - separate_last_epoch) - np.save(doc_idx_filename, doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - # sample-idx. - start_time = time.time() - # Use C++ implementation for speed. - # First compile and then import. - from megatron.data import helpers - assert doc_idx.dtype == np.int32 - assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch) - np.save(sample_idx_filename, sample_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save sample-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - # shuffle-idx. - start_time = time.time() - # -1 is due to data structure used to retieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - \ + num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, \ + 'last epoch number of samples should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples < (num_samples_per_epoch + 1), \ + 'last epoch number of samples exceeded max value.' + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = (last_epoch_num_samples < + int(0.80 * num_samples_per_epoch)) if separate_last_epoch: - num_samples_ = num_samples_from_epochs_minus_one + string = ' > last epoch number of samples ({}) is smaller '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to True' else: - num_samples_ = sample_idx.shape[0] - 1 - shuffle_idx = _build_shuffle_idx(num_samples_, - sample_idx.shape[0] - 1, np_rng) - np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) + string = ' > last epoch number of samples ({}) is larger '\ + 'than 80% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to False' + print(string.format(last_epoch_num_samples, + num_samples_per_epoch), flush=True) + + os.makedirs(os.path.dirname(idx_path['desc']), exist_ok=True) + + # description + with open(idx_path['desc'], 'wt') as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(idx_path['doc'], doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + np.save(idx_path['sample'], sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model @@ -399,22 +448,22 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # Load mappings. start_time = time.time() - print_rank_0(' > loading doc-idx mapping from {}'.format( - doc_idx_filename)) - doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading sample-idx mapping from {}'.format( - sample_idx_filename)) - sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') - print_rank_0(' > loading shuffle-idx mapping from {}'.format( - shuffle_idx_filename)) - shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r') + + print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( time.time() - start_time)) print_rank_0(' total number of samples: {}'.format( sample_idx.shape[0])) print_rank_0(' total number of epochs: {}'.format(num_epochs)) - return doc_idx, sample_idx, shuffle_idx, index_prefix + return doc_idx, sample_idx, shuffle_idx, desc def _num_tokens(documents, sizes): diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 16339677e1c..18c763f44bb 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -104,7 +104,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): skip_warmup=(not args.mmap_warmup), train_data_prefix=args.train_data_path, valid_data_prefix=args.valid_data_path, - test_data_prefix=args.test_data_path) + test_data_prefix=args.test_data_path, + data_cache_path=args.data_cache_path) print_rank_0("> finished creating GPT datasets ...") return train_ds, valid_ds, test_ds From ae37924084545be3a92c8c4295a82002a1fe15bb Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Mon, 22 May 2023 22:06:02 -0700 Subject: [PATCH 02/15] Check for write failure of index cache and print error message. --- megatron/data/blendable_dataset.py | 33 ++++++---- megatron/data/gpt_dataset.py | 100 ++++++++++++++++------------- 2 files changed, 76 insertions(+), 57 deletions(-) diff --git a/megatron/data/blendable_dataset.py b/megatron/data/blendable_dataset.py index 61a00039bbb..8ff5ce3da87 100644 --- a/megatron/data/blendable_dataset.py +++ b/megatron/data/blendable_dataset.py @@ -59,25 +59,34 @@ def _build_indices(): index_path = os.path.join(data_cache_path, desc_hash + "_index.npy") sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy") cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path) + cache_success = True if torch.distributed.get_rank() == 0 and not cache_hit: print(' > WARNING: could not find index map files for blendable' ' dataset, building indices on rank 0 ...', flush=True) dataset_index, dataset_sample_index = _build_indices() - os.makedirs(os.path.dirname(index_path), exist_ok=True) - with open(desc_path, 'wt') as fd: - fd.write(desc) - np.save(index_path, dataset_index, allow_pickle=True) - np.save(sample_index_path, dataset_sample_index, - allow_pickle=True) - - # This should be a barrier but nccl barrier assumes device_index=rank which is not the - # case for model parallel case - counts = torch.cuda.LongTensor([1]) + try: + os.makedirs(os.path.dirname(index_path), exist_ok=True) + with open(desc_path, 'wt') as fd: + fd.write(desc) + np.save(index_path, dataset_index, allow_pickle=True) + np.save(sample_index_path, dataset_sample_index, + allow_pickle=True) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_path})') + print('or a file in it. This is set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + cache_success = False + + + counts = torch.cuda.LongTensor([cache_success]) torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( + if counts[0].item() != ( torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() # Load on all ranks. print_rank_0(f'> loading blendable dataset index: {index_path}') diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index cda6060b160..0962ce326b6 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -353,6 +353,8 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # Found our files! build_indices = False break + data_cache_dir = os.path.dirname(idx_path['desc']) + data_cache_success = True # Build the indexed mapping if not exist. if build_indices and torch.distributed.get_rank() == 0: @@ -397,54 +399,62 @@ def _build_index_mappings(name, data_prefix, documents, sizes, print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True) - os.makedirs(os.path.dirname(idx_path['desc']), exist_ok=True) - - # description - with open(idx_path['desc'], 'wt') as fd: - fd.write(desc) - - # doc-idx. - start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng, - separate_last_epoch) - np.save(idx_path['doc'], doc_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save doc-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - # sample-idx. - start_time = time.time() - # Use C++ implementation for speed. - # First compile and then import. - from megatron.data import helpers - assert doc_idx.dtype == np.int32 - assert sizes.dtype == np.int32 - sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, - num_epochs, tokens_per_epoch) - np.save(idx_path['sample'], sample_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save sample-idx mapping ' - '(seconds): {:4f}'.format(time.time() - start_time)) - # shuffle-idx. - start_time = time.time() - # -1 is due to data structure used to retieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - if separate_last_epoch: - num_samples_ = num_samples_from_epochs_minus_one - else: - num_samples_ = sample_idx.shape[0] - 1 - shuffle_idx = _build_shuffle_idx(num_samples_, - sample_idx.shape[0] - 1, np_rng) - np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) - print_rank_0(' > elasped time to build and save shuffle-idx mapping' - ' (seconds): {:4f}'.format(time.time() - start_time)) - - # This should be a barrier but nccl barrier assumes - # device_index=rank which is not the case for model - # parallel case - counts = torch.cuda.LongTensor([1]) + + try: + os.makedirs(data_cache_dir, exist_ok=True) + + # description + with open(idx_path['desc'], 'wt') as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(idx_path['doc'], doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + np.save(idx_path['sample'], sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + except OSError: + print(f'There was an error trying to create the data cache directory ({data_cache_dir})') + print('or a file in it. This defaults to a directory "index-cache" within the directory') + print('the data files are in and can be set with the --data-cache-path argument. Please') + print('ensure you have write access to this directory or specify one that you do have') + print('write access to.') + data_cache_success = False + + counts = torch.cuda.LongTensor([data_cache_success]) torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( + if counts[0].item() != ( torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())): + print_rank_0("Data index creation unsuccessful, exiting.") + exit() # Load mappings. start_time = time.time() From f9283c5a8a1dc61d97d5873807c6614d0ec5e631 Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Wed, 31 May 2023 15:27:34 -0700 Subject: [PATCH 03/15] Add option to overlap p2p communication. --- megatron/arguments.py | 4 + .../pipeline_parallel/p2p_communication.py | 229 ++++++++++--- megatron/core/pipeline_parallel/schedules.py | 314 ++++++++++++++---- megatron/training.py | 2 + 4 files changed, 435 insertions(+), 114 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 84a007c0262..78a01ea964c 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -935,6 +935,10 @@ def _add_distributed_args(parser): '--tensor-model-parallel-size instead.') group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, help='Number of layers per virtual pipeline stage') + group.add_argument('--overlap-p2p-communication', + action='store_true', + help='overlap pipeline parallel communication with forward and backward chunks', + dest='overlap_p2p_comm') group.add_argument('--distributed-backend', default='nccl', choices=['nccl', 'gloo'], help='Which backend to use for distributed training.') diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index 301583132a6..6a461ad8d47 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -9,6 +9,7 @@ from megatron import core from megatron.core.parallel_state import ( get_pipeline_model_parallel_group, + get_pipeline_model_parallel_rank, get_pipeline_model_parallel_prev_rank, get_pipeline_model_parallel_next_rank, ) @@ -63,28 +64,28 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev, tensor_recv_prev=recv_prev_shape_tensor, tensor_send_next=send_next_shape_tensor, tensor_recv_next=recv_next_shape_tensor, - group=mpu.get_pipeline_model_parallel_group()) + group=get_pipeline_model_parallel_group()) else: ops = [] if send_prev_shape_tensor is not None: send_prev_op = torch.distributed.P2POp( torch.distributed.isend, send_prev_shape_tensor, - mpu.get_pipeline_model_parallel_prev_rank()) + get_pipeline_model_parallel_prev_rank()) ops.append(send_prev_op) if recv_prev_shape_tensor is not None: recv_prev_op = torch.distributed.P2POp( torch.distributed.irecv, recv_prev_shape_tensor, - mpu.get_pipeline_model_parallel_prev_rank()) + get_pipeline_model_parallel_prev_rank()) ops.append(recv_prev_op) if send_next_shape_tensor is not None: send_next_op = torch.distributed.P2POp( torch.distributed.isend, send_next_shape_tensor, - mpu.get_pipeline_model_parallel_next_rank()) + get_pipeline_model_parallel_next_rank()) ops.append(send_next_op) if recv_next_shape_tensor is not None: recv_next_op = torch.distributed.P2POp( torch.distributed.irecv, recv_next_shape_tensor, - mpu.get_pipeline_model_parallel_next_rank()) + get_pipeline_model_parallel_next_rank()) ops.append(recv_next_op) if len(ops) > 0: reqs = torch.distributed.batch_isend_irecv(ops) @@ -105,12 +106,125 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev, return recv_prev_shape, recv_next_shape +def _batched_p2p_ops(*, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + group: torch.distributed.ProcessGroup): + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_prev, + get_pipeline_model_parallel_prev_rank(), + group) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, + get_pipeline_model_parallel_prev_rank(), + group) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, + get_pipeline_model_parallel_next_rank(), + group) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_next, + get_pipeline_model_parallel_next_rank(), + group) + ops.append(recv_next_op) + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + else: + reqs = [] + return reqs + +def _p2p_ops(*, + tensor_send_prev: Optional[torch.Tensor], + tensor_recv_prev: Optional[torch.Tensor], + tensor_send_next: Optional[torch.Tensor], + tensor_recv_next: Optional[torch.Tensor], + group: torch.distributed.ProcessGroup): + reqs = [] + rank = get_pipeline_model_parallel_rank() + if get_pipeline_model_parallel_rank() % 2 == 0: + if tensor_send_next is not None: + send_next_req = torch.distributed.isend( + tensor=tensor_send_next, + dst=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(send_next_req) + + if tensor_recv_prev is not None: + recv_prev_req = torch.distributed.irecv( + tensor=tensor_recv_prev, + src=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(recv_prev_req) + + if tensor_send_prev is not None: + send_prev_req = torch.distributed.isend( + tensor=tensor_send_prev, + dst=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(send_prev_req) + + if tensor_recv_next is not None: + recv_next_req = torch.distributed.irecv( + tensor=tensor_recv_next, + src=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(recv_next_req) + + else: + if tensor_recv_prev is not None: + recv_prev_req = torch.distributed.irecv( + tensor=tensor_recv_prev, + src=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(recv_prev_req) + + if tensor_send_next is not None: + send_next_req = torch.distributed.isend( + tensor=tensor_send_next, + dst=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(send_next_req) + + if tensor_recv_next is not None: + recv_next_req = torch.distributed.irecv( + tensor=tensor_recv_next, + src=get_pipeline_model_parallel_next_rank(), + group=group, + ) + reqs.append(recv_next_req) + + if tensor_send_prev is not None: + send_prev_req = torch.distributed.isend( + tensor=tensor_send_prev, + dst=get_pipeline_model_parallel_prev_rank(), + group=group, + ) + reqs.append(send_prev_req) + return reqs def _communicate(*, tensor_send_next: Optional[torch.Tensor], tensor_send_prev: Optional[torch.Tensor], recv_prev: bool, recv_next: bool, tensor_shape: Shape, + batch_p2p_comm: bool = True, + wait_on_reqs: bool = True, dtype: Optional[torch.dtype], variable_seq_lengths: bool = False, use_ring_exchange_p2p: bool = False, @@ -136,6 +250,14 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor], tensors sent and received in a single function call are the same shape). + batch_p2p_comm (boolean, required): + If true use batch_isend_irecv, otherwise use individual + isend and irecv calls. + + wait_on_reqs (boolean, optional, default=False): + For non-batched p2p communication, wait on each request + before returning. + dtype (torch.dtype, required if either recv_{prev,next} is True): this must be the type of the tensors that will be received, will typically be params_dtype, but in the case @@ -167,6 +289,10 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor], tensor_recv_prev = None tensor_recv_next = None + # This will come from config in the next version, for now hard + # code it here to match existing functionality. + batch_p2p_sync = True + if not variable_seq_lengths: recv_prev_shape = tensor_shape recv_next_shape = tensor_shape @@ -204,46 +330,38 @@ def _communicate(*, tensor_send_next: Optional[torch.Tensor], # Send tensors in both the forward and backward directions as appropriate. if use_ring_exchange_p2p: - torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev, - tensor_recv_prev=tensor_recv_prev, - tensor_send_next=tensor_send_next, - tensor_recv_next=tensor_recv_next, - group=get_pipeline_model_parallel_group()) + def _ring_exchange_wrapper(**kwargs): + torch.distributed.ring_exchange(**kwargs) + return [] + p2p_func = _ring_exchange_wrapper + elif batch_p2p_comm: + assert wait_on_reqs + p2p_func = _batched_p2p_ops else: - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, - get_pipeline_model_parallel_prev_rank()) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, - get_pipeline_model_parallel_prev_rank()) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, - get_pipeline_model_parallel_next_rank()) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, - get_pipeline_model_parallel_next_rank()) - ops.append(recv_next_op) - if len(ops) > 0: - reqs = torch.distributed.batch_isend_irecv(ops) - for req in reqs: - req.wait() + p2p_func = _p2p_ops + + reqs = p2p_func(tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=get_pipeline_model_parallel_group()) + + if wait_on_reqs and len(reqs) > 0: + for req in reqs: + req.wait() + reqs = None + + if batch_p2p_comm and batch_p2p_sync: # To protect against race condition when using batch_isend_irecv(). # User should assert that we have a modern enough PyTorch to not need this torch.cuda.synchronize() - return tensor_recv_prev, tensor_recv_next + return tensor_recv_prev, tensor_recv_next, reqs def recv_forward(tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """ Receive tensor from previous rank in pipeline (forward receive). @@ -256,12 +374,13 @@ def recv_forward(tensor_shape: Shape, else: if timers is not None: timers('forward-recv', log_level=2).start() - input_tensor, _ = _communicate( + input_tensor, _, _ = _communicate( tensor_send_next=None, tensor_send_prev=None, recv_prev=True, recv_next=False, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('forward-recv').stop() @@ -270,6 +389,7 @@ def recv_forward(tensor_shape: Shape, def recv_backward(tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """Receive tensor from next rank in pipeline (backward receive). @@ -280,12 +400,13 @@ def recv_backward(tensor_shape: Shape, else: if timers is not None: timers('backward-recv', log_level=2).start() - _, output_tensor_grad = _communicate( + _, output_tensor_grad, _ = _communicate( tensor_send_next=None, tensor_send_prev=None, recv_prev=False, recv_next=True, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('backward-recv').stop() @@ -293,6 +414,7 @@ def recv_backward(tensor_shape: Shape, def send_forward(output_tensor: torch.Tensor, + batch_p2p_comm: bool = True, timers: Callable = None) -> None: """Send tensor to next rank in pipeline (forward send). @@ -308,12 +430,14 @@ def send_forward(output_tensor: torch.Tensor, recv_prev=False, recv_next=False, tensor_shape=None, + batch_p2p_comm=batch_p2p_comm, dtype=None) if timers is not None: timers('forward-send').stop() def send_backward(input_tensor_grad: torch.Tensor, + batch_p2p_comm: bool = True, timers: Callable = None) -> None: """Send tensor to previous rank in pipeline (backward send). @@ -328,6 +452,7 @@ def send_backward(input_tensor_grad: torch.Tensor, recv_prev=False, recv_next=False, tensor_shape=None, + batch_p2p_comm=batch_p2p_comm, dtype=None) if timers is not None: timers('backward-send').stop() @@ -336,6 +461,7 @@ def send_backward(input_tensor_grad: torch.Tensor, def send_forward_recv_backward(output_tensor: torch.Tensor, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """Batched send and recv with next rank in pipeline. @@ -346,12 +472,13 @@ def send_forward_recv_backward(output_tensor: torch.Tensor, else: if timers is not None: timers('forward-send-backward-recv', log_level=2).start() - _, output_tensor_grad = _communicate( + _, output_tensor_grad,_ = _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, recv_next=True, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('forward-send-backward-recv').stop() @@ -361,6 +488,7 @@ def send_forward_recv_backward(output_tensor: torch.Tensor, def send_backward_recv_forward(input_tensor_grad: torch.Tensor, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> torch.Tensor: """Batched send and recv with previous rank in pipeline. @@ -371,12 +499,13 @@ def send_backward_recv_forward(input_tensor_grad: torch.Tensor, else: if timers is not None: timers('backward-send-forward-recv', log_level=2).start() - input_tensor, _ = _communicate( + input_tensor, _, _ = _communicate( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=True, recv_next=False, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('backward-send-forward-recv').stop() @@ -387,6 +516,8 @@ def send_forward_recv_forward(output_tensor: torch.Tensor, recv_prev: bool, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, + overlap_p2p_comm: bool = False, timers: Callable = None) -> torch.Tensor: """Batched recv from previous rank and send to next rank in pipeline. @@ -394,15 +525,19 @@ def send_forward_recv_forward(output_tensor: torch.Tensor, """ if timers is not None: timers('forward-send-forward-recv', log_level=2).start() - input_tensor, _ = _communicate( + input_tensor, _, wait_handles = _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=recv_prev, recv_next=False, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, + wait_on_reqs=(not overlap_p2p_comm), dtype=dtype) if timers is not None: timers('forward-send-forward-recv').stop() + if overlap_p2p_comm: + return input_tensor, wait_handles return input_tensor @@ -410,6 +545,8 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor, recv_next: bool, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, + overlap_p2p_comm: bool = False, timers: Callable = None) -> torch.Tensor: """Batched recv from next rank and send to previous rank in pipeline. @@ -417,15 +554,19 @@ def send_backward_recv_backward(input_tensor_grad: torch.Tensor, """ if timers is not None: timers('backward-send-backward-recv', log_level=2).start() - _, output_tensor_grad = _communicate( + _, output_tensor_grad, wait_handles = _communicate( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, recv_next=recv_next, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, + wait_on_reqs=(not overlap_p2p_comm), dtype=dtype) if timers is not None: timers('backward-send-backward-recv').stop() + if overlap_p2p_comm: + return output_tensor_grad, wait_handles return output_tensor_grad @@ -436,6 +577,7 @@ def send_forward_backward_recv_forward_backward( recv_next: bool, tensor_shape: Shape, dtype: torch.dtype, + batch_p2p_comm: bool = True, timers: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]: """Batched send and recv with previous and next ranks in pipeline. @@ -444,12 +586,13 @@ def send_forward_backward_recv_forward_backward( if timers is not None: timers('forward-backward-send-forward-backward-recv', log_level=2).start() - input_tensor, output_tensor_grad = _communicate( + input_tensor, output_tensor_grad, _ = _communicate( tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_comm, dtype=dtype) if timers is not None: timers('forward-backward-send-forward-backward-recv').stop() diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 5007a44cd2c..174b8a5ea69 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -85,6 +85,15 @@ def forward_step(data_iterator, model): tensor\_model\_parallel\_world\_size`. TODO: Do we need this? Just roll into tensor_shape arg? + overlap_p2p_communication (optional, default=False): When True + some of the peer to peer communication for pipeline + parallelism will overlap with compuation. Must be False if + batch_p2p_communication is true. + + batch_p2p_communication (optional, default=True): When true use + batch_isend_irecv, otherwise use individual isend and irecv + calls. Must be false if overlap_p2p_communication is True. + forward_only (optional, default=False): Perform only the forward step timers (optional, default=None): TODO @@ -94,11 +103,11 @@ def forward_step(data_iterator, model): enable_autocast (optional, default=False): If True, runs the forward_step_func call inside torch.autocast context - deallocate_pipeline_outputs (optional, default=False): If True, output data + deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent to the next pipeline stage. - Helps with saving memory, does nothing when pipeline parallel is + Helps with saving memory, does nothing when pipeline parallel is not used. - + no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel communication. If the model is an instance of torch.nn.DistributedDataParallel, the @@ -277,7 +286,7 @@ def backward_step(grad_scaler, input_tensor, output_tensor, # Backward pass. if output_tensor_grad[0] is None and grad_scaler is not None: output_tensor = grad_scaler(output_tensor[0]) - + if deallocate_pipeline_outputs: custom_backward(output_tensor[0], output_tensor_grad[0]) else: @@ -319,6 +328,8 @@ def forward_backward_no_pipelining(*, decoder_seq_length: Optional[int] = None, # unused grad_scaler: Callable = None, sequence_parallel: bool = False, # unused + overlap_p2p_communication: bool = False, # unused + batch_p2p_communication: bool = True, # unused forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -387,6 +398,8 @@ def forward_backward_pipelining_with_interleaving(*, decoder_seq_length: Optional[int] = None, grad_scaler: Callable = None, sequence_parallel: bool = False, + overlap_p2p_communication: bool = False, + batch_p2p_communication: bool = True, forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -407,6 +420,9 @@ def forward_backward_pipelining_with_interleaving(*, assert isinstance(data_iterator, list), \ "interleaved pipeline parallelism expected each model chunk to have a data iterator" + if overlap_p2p_communication and batch_p2p_communication: + raise ValueError("Can not use both overlap_p2p_communication and batch_p2p_communication") + # Disable async grad reductions if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model): def multi_no_sync(): @@ -617,8 +633,20 @@ def backward_step_helper(microbatch_id): # Run warmup forward passes. parallel_state.set_virtual_pipeline_model_parallel_rank(0) input_tensors[0].append( - p2p_communication.recv_forward(tensor_shape, dtype, timers=timers)) + p2p_communication.recv_forward(tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers)) + + fwd_wait_handles = None + bwd_wait_handles = None + for k in range(num_warmup_microbatches): + + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + output_tensor = forward_step_helper(k) # Determine if tensor should be received from previous stage. @@ -636,91 +664,216 @@ def backward_step_helper(microbatch_id): # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). - if k == (num_warmup_microbatches - 1) and not forward_only and \ - not all_warmup_microbatches: - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - input_tensor, output_tensor_grad = \ - p2p_communication.send_forward_backward_recv_forward_backward( + if not overlap_p2p_communication: + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + input_tensor, output_tensor_grad = \ + p2p_communication.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, - tensor_shape=tensor_shape, dtype=dtype, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, timers=timers) - output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + else: + input_tensor = \ + p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers) + input_tensors[next_forward_model_chunk_id].append(input_tensor) else: - input_tensor = \ + input_tensor, fwd_wait_handles = \ p2p_communication.send_forward_recv_forward( output_tensor, recv_prev=recv_prev, - tensor_shape=tensor_shape, dtype=dtype, - timers=timers) - input_tensors[next_forward_model_chunk_id].append(input_tensor) + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers, + overlap_p2p_comm=True) + + if k == (num_warmup_microbatches - 1) and not forward_only and \ + not all_warmup_microbatches: + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + + output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, + tensor_shape=tensor_shape, + batch_p2p_comm=batch_p2p_communication, + dtype=dtype, + timers=timers, + overlap_p2p_comm=True) + + output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) # Run 1F1B in steady state. for k in range(num_microbatches_remaining): # Forward pass. forward_k = k + num_warmup_microbatches - output_tensor = forward_step_helper(forward_k) - # Backward pass. - backward_k = k - input_tensor_grad = backward_step_helper(backward_k) + if overlap_p2p_communication: + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) + + output_tensor = forward_step_helper(forward_k) + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + + # Last virtual stage no activation tensor to send + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, + forward=True) - # Send output_tensor and input_tensor_grad, receive input_tensor - # and output_tensor_grad. + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - if parallel_state.is_pipeline_last_stage(): - output_tensor = None + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + input_tensor, fwd_wait_handles = \ + p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers, + overlap_p2p_comm=True) + # assert fwd_wait_handles is not None - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, - forward=True) - - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, - forward=False) + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + + # First virtual stage no activation gradient tensor to send + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if the current virtual stage has an activation gradient tensor to receive + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id( + backward_k + 1, forward=False + ) + + output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers, + overlap_p2p_comm=True) + + else: # no p2p overlap + output_tensor = forward_step_helper(forward_k) + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + # Send output_tensor and input_tensor_grad, receive input_tensor + # and output_tensor_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, + forward=True) - # Communicate tensors. - input_tensor, output_tensor_grad = \ - p2p_communication.send_forward_backward_recv_forward_backward( + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, + forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Communicate tensors. + input_tensor, output_tensor_grad = \ + p2p_communication.send_forward_backward_recv_forward_backward( output_tensor, input_tensor_grad, recv_prev=recv_prev, recv_next=recv_next, - tensor_shape=tensor_shape, dtype=dtype, timers=timers) - deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers) + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) # Put input_tensor and output_tensor_grad in data structures in the # right location. @@ -730,11 +883,20 @@ def backward_step_helper(microbatch_id): output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grad) + deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) + # Run cooldown backward passes (flush out pipeline). if not forward_only: + if overlap_p2p_communication and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + if all_warmup_microbatches: output_tensor_grads[num_model_chunks-1].append( - p2p_communication.recv_backward(tensor_shape, dtype=dtype, timers=timers)) + p2p_communication.recv_backward(tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, + timers=timers)) for k in range(num_microbatches_remaining, total_num_microbatches): input_tensor_grad = backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) @@ -747,7 +909,9 @@ def backward_step_helper(microbatch_id): output_tensor_grads[next_backward_model_chunk_id].append( p2p_communication.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next, - tensor_shape=tensor_shape, dtype=dtype, + tensor_shape=tensor_shape, + dtype=dtype, + batch_p2p_comm=batch_p2p_communication, timers=timers)) # Launch any remaining grad reductions @@ -881,6 +1045,8 @@ def forward_backward_pipelining_without_interleaving(*, decoder_seq_length: Optional[int] = None, grad_scaler: Callable = None, sequence_parallel: bool = False, + overlap_p2p_communication: bool = False, + batch_p2p_communication: bool = True, forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -904,6 +1070,12 @@ def forward_backward_pipelining_without_interleaving(*, "non-pipeline-parallel schedule does not support model chunking" data_iterator = data_iterator[0] + if overlap_p2p_communication: + raise ValueError("Non-interleaved pipeline parallelism does not support overlapping p2p communication") + + if not batch_p2p_communication: + raise ValueError("Non-interleaved pipeline parallelism only supports using batched p2p communication") + # Disable async grad reductions if no_sync_func is None and isinstance(model, torchDDP): no_sync_func = model.no_sync diff --git a/megatron/training.py b/megatron/training.py index 14bca152f0a..88b7d6256bf 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -427,6 +427,8 @@ def train_step(forward_step_func, data_iterator, tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size), grad_scaler=optimizer.scale_loss, sequence_parallel=args.sequence_parallel, + overlap_p2p_communication=args.overlap_p2p_comm, + batch_p2p_communication=not args.overlap_p2p_comm, forward_only=False, timers=fwd_bwd_timers) timers('forward-backward').stop() From 621c9de29b37d0211ef7f4b91058e25e6e9a5d57 Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Wed, 31 May 2023 15:57:39 -0700 Subject: [PATCH 04/15] typo --- megatron/core/pipeline_parallel/schedules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 174b8a5ea69..f5c921c7d73 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -87,7 +87,7 @@ def forward_step(data_iterator, model): overlap_p2p_communication (optional, default=False): When True some of the peer to peer communication for pipeline - parallelism will overlap with compuation. Must be False if + parallelism will overlap with computation. Must be False if batch_p2p_communication is true. batch_p2p_communication (optional, default=True): When true use From 2c13d1f95b9d20f6ab4b6fa7d4d571ba052c122c Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Wed, 31 May 2023 16:20:01 -0700 Subject: [PATCH 05/15] Consistent arg names. --- megatron/core/pipeline_parallel/schedules.py | 54 ++++++++++---------- megatron/training.py | 4 +- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index f5c921c7d73..375acef1afd 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -85,14 +85,14 @@ def forward_step(data_iterator, model): tensor\_model\_parallel\_world\_size`. TODO: Do we need this? Just roll into tensor_shape arg? - overlap_p2p_communication (optional, default=False): When True + overlap_p2p_comm (optional, default=False): When True some of the peer to peer communication for pipeline parallelism will overlap with computation. Must be False if - batch_p2p_communication is true. + batch_p2p_comm is true. - batch_p2p_communication (optional, default=True): When true use + batch_p2p_comm (optional, default=True): When true use batch_isend_irecv, otherwise use individual isend and irecv - calls. Must be false if overlap_p2p_communication is True. + calls. Must be false if overlap_p2p_comm is True. forward_only (optional, default=False): Perform only the forward step @@ -328,8 +328,8 @@ def forward_backward_no_pipelining(*, decoder_seq_length: Optional[int] = None, # unused grad_scaler: Callable = None, sequence_parallel: bool = False, # unused - overlap_p2p_communication: bool = False, # unused - batch_p2p_communication: bool = True, # unused + overlap_p2p_comm: bool = False, # unused + batch_p2p_comm: bool = True, # unused forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -398,8 +398,8 @@ def forward_backward_pipelining_with_interleaving(*, decoder_seq_length: Optional[int] = None, grad_scaler: Callable = None, sequence_parallel: bool = False, - overlap_p2p_communication: bool = False, - batch_p2p_communication: bool = True, + overlap_p2p_comm: bool = False, + batch_p2p_comm: bool = True, forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -420,8 +420,8 @@ def forward_backward_pipelining_with_interleaving(*, assert isinstance(data_iterator, list), \ "interleaved pipeline parallelism expected each model chunk to have a data iterator" - if overlap_p2p_communication and batch_p2p_communication: - raise ValueError("Can not use both overlap_p2p_communication and batch_p2p_communication") + if overlap_p2p_comm and batch_p2p_comm: + raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") # Disable async grad reductions if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model): @@ -635,7 +635,7 @@ def backward_step_helper(microbatch_id): input_tensors[0].append( p2p_communication.recv_forward(tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers)) fwd_wait_handles = None @@ -664,7 +664,7 @@ def backward_step_helper(microbatch_id): # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). - if not overlap_p2p_communication: + if not overlap_p2p_comm: if k == (num_warmup_microbatches - 1) and not forward_only and \ not all_warmup_microbatches: input_tensor_grad = None @@ -677,7 +677,7 @@ def backward_step_helper(microbatch_id): recv_prev=recv_prev, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers) output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) else: @@ -686,7 +686,7 @@ def backward_step_helper(microbatch_id): output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers) input_tensors[next_forward_model_chunk_id].append(input_tensor) else: @@ -695,7 +695,7 @@ def backward_step_helper(microbatch_id): output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers, overlap_p2p_comm=True) @@ -709,7 +709,7 @@ def backward_step_helper(microbatch_id): output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, dtype=dtype, timers=timers, overlap_p2p_comm=True) @@ -724,7 +724,7 @@ def backward_step_helper(microbatch_id): # Forward pass. forward_k = k + num_warmup_microbatches - if overlap_p2p_communication: + if overlap_p2p_comm: if fwd_wait_handles is not None: for req in fwd_wait_handles: req.wait() @@ -768,7 +768,7 @@ def backward_step_helper(microbatch_id): output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers, overlap_p2p_comm=True) # assert fwd_wait_handles is not None @@ -807,7 +807,7 @@ def backward_step_helper(microbatch_id): input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers, overlap_p2p_comm=True) @@ -871,7 +871,7 @@ def backward_step_helper(microbatch_id): recv_prev=recv_prev, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers) deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs) @@ -887,7 +887,7 @@ def backward_step_helper(microbatch_id): # Run cooldown backward passes (flush out pipeline). if not forward_only: - if overlap_p2p_communication and bwd_wait_handles is not None: + if overlap_p2p_comm and bwd_wait_handles is not None: for wait_handle in bwd_wait_handles: wait_handle.wait() @@ -895,7 +895,7 @@ def backward_step_helper(microbatch_id): output_tensor_grads[num_model_chunks-1].append( p2p_communication.recv_backward(tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers)) for k in range(num_microbatches_remaining, total_num_microbatches): input_tensor_grad = backward_step_helper(k) @@ -911,7 +911,7 @@ def backward_step_helper(microbatch_id): input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype, - batch_p2p_comm=batch_p2p_communication, + batch_p2p_comm=batch_p2p_comm, timers=timers)) # Launch any remaining grad reductions @@ -1045,8 +1045,8 @@ def forward_backward_pipelining_without_interleaving(*, decoder_seq_length: Optional[int] = None, grad_scaler: Callable = None, sequence_parallel: bool = False, - overlap_p2p_communication: bool = False, - batch_p2p_communication: bool = True, + overlap_p2p_comm: bool = False, + batch_p2p_comm: bool = True, forward_only: bool = False, timers: Callable = None, collect_non_loss_data: bool = False, @@ -1070,10 +1070,10 @@ def forward_backward_pipelining_without_interleaving(*, "non-pipeline-parallel schedule does not support model chunking" data_iterator = data_iterator[0] - if overlap_p2p_communication: + if overlap_p2p_comm: raise ValueError("Non-interleaved pipeline parallelism does not support overlapping p2p communication") - if not batch_p2p_communication: + if not batch_p2p_comm: raise ValueError("Non-interleaved pipeline parallelism only supports using batched p2p communication") # Disable async grad reductions diff --git a/megatron/training.py b/megatron/training.py index 88b7d6256bf..9a5190b4a71 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -427,8 +427,8 @@ def train_step(forward_step_func, data_iterator, tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size), grad_scaler=optimizer.scale_loss, sequence_parallel=args.sequence_parallel, - overlap_p2p_communication=args.overlap_p2p_comm, - batch_p2p_communication=not args.overlap_p2p_comm, + overlap_p2p_comm=args.overlap_p2p_comm, + batch_p2p_comm=not args.overlap_p2p_comm, forward_only=False, timers=fwd_bwd_timers) timers('forward-backward').stop() From 4d564cf2e48d7b48dad3b0083cf3ac6580486d2c Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Fri, 2 Jun 2023 12:21:38 -0700 Subject: [PATCH 06/15] Supporting loading checkpoints without add_position_embedding arg. --- megatron/checkpointing.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 41b0535704b..3ab0e5ba3e4 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -37,11 +37,15 @@ def check_checkpoint_args(checkpoint_args): arguments and the one retrieved from checkpoint.""" args = get_args() - def _compare(arg_name, old_arg_name=None): + def _compare(arg_name, old_arg_name=None, default=None): if old_arg_name is not None: - checkpoint_value = getattr(checkpoint_args, old_arg_name) + ckpt_arg_name = old_arg_name else: - checkpoint_value = getattr(checkpoint_args, arg_name) + ckpt_arg_name = arg_name + if default is not None: + checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default) + else: + checkpoint_value = getattr(checkpoint_args, ckpt_arg_name) args_value = getattr(args, arg_name) error_message = '{} value from checkpoint ({}) is not equal to the ' \ 'input argument value ({}).'.format( @@ -51,7 +55,7 @@ def _compare(arg_name, old_arg_name=None): _compare('num_layers') _compare('hidden_size') _compare('num_attention_heads') - _compare('add_position_embedding') + _compare('add_position_embedding', default=True) if args.vocab_file: _compare('max_position_embeddings') _compare('make_vocab_size_divisible_by') From 1997e94986810e768257dcaa3f7ccc76a3dc6584 Mon Sep 17 00:00:00 2001 From: Jared Casper Date: Fri, 2 Jun 2023 13:08:36 -0700 Subject: [PATCH 07/15] Fix GPTDataset assert. --- megatron/data/gpt_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 602e5116781..aa397a3a815 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -335,7 +335,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes, assert last_epoch_num_samples >= 0, \ 'last epoch number of samples should be non-negative.' num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length - assert last_epoch_num_samples < (num_samples_per_epoch + 1), \ + assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ 'last epoch number of samples exceeded max value.' # If we have less than 80% of the samples for the last epoch, # seperate out the epoch and treat it differently. From a6c574d4fb72f4d1877d489ef2ffa094d4258d95 Mon Sep 17 00:00:00 2001 From: Lawrence McAfee Date: Mon, 5 Jun 2023 13:01:40 -0700 Subject: [PATCH 08/15] Fixed rotary_pos_emb's position in layer's forward args. --- megatron/model/transformer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4d744e7a259..9ed2d6ffd77 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1335,6 +1335,8 @@ def __init__(self, init_method, output_layer_init_method, # Transformer layers. if args.retro_add_retriever: + assert self.recompute_granularity != 'full', \ + "Full recompute not supported for Retro." assert args.transformer_impl == 'local', \ "Transformer engine does not support Retro layers." def build_layer(layer_number): @@ -1485,8 +1487,9 @@ def custom_forward(*args, **kwargs): hidden_states = tensor_parallel.checkpoint( custom(l, l + self.recompute_num_layers), self.distribute_saved_activations, - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, rotary_pos_emb) + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) l += self.recompute_num_layers @@ -1508,8 +1511,9 @@ def custom_forward(*args, **kwargs): hidden_states = tensor_parallel.checkpoint( custom(l, l + 1), self.distribute_saved_activations, - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, rotary_pos_emb) + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) else: if self.transformer_impl == 'transformer_engine': hidden_states = custom(l, l + 1)( @@ -1517,8 +1521,9 @@ def custom_forward(*args, **kwargs): enc_dec_attn_mask, **te_forward_kwargs) else: hidden_states = custom(l, l + 1)( - hidden_states, attention_mask, encoder_output, - enc_dec_attn_mask, rotary_pos_emb) + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb) else: raise ValueError("Invalid activation recompute method.") @@ -1596,8 +1601,6 @@ def forward(self, hidden_states, attention_mask, # Forward pass. if self.recompute_granularity == 'full': - assert not self.retro_add_retriever, \ - "full recompute not supported for retro." hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, From 41221b879d576decb884c72ba918f29f5aa3a2b9 Mon Sep 17 00:00:00 2001 From: Abhinav Khattar Date: Mon, 5 Jun 2023 13:09:35 -0700 Subject: [PATCH 09/15] fix indexation for output tensor after gradscaler call Signed-off-by: Abhinav Khattar --- megatron/core/pipeline_parallel/schedules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 375acef1afd..20ae496ee8d 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -285,7 +285,7 @@ def backward_step(grad_scaler, input_tensor, output_tensor, # Backward pass. if output_tensor_grad[0] is None and grad_scaler is not None: - output_tensor = grad_scaler(output_tensor[0]) + output_tensor[0] = grad_scaler(output_tensor[0]) if deallocate_pipeline_outputs: custom_backward(output_tensor[0], output_tensor_grad[0]) From ea76ecde2e5d559df4374d5d0ca19a34c8e80235 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 5 Jun 2023 17:45:01 -0700 Subject: [PATCH 10/15] Perform grad sync at correct place in interleaved pipeline parallelism --- megatron/core/pipeline_parallel/schedules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 20ae496ee8d..484d398fd82 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -523,7 +523,7 @@ def get_model_chunk_id(microbatch_id, forward): def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: """Check if an iteration is the first for a model chunk.""" microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = num_microbatches // microbatch_group_size + num_microbatch_groups = total_num_microbatches // microbatch_group_size microbatch_group_id = microbatch_id // microbatch_group_size microbatch_id_in_group = microbatch_id % microbatch_group_size if microbatch_group_id == 0: @@ -534,7 +534,7 @@ def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: """Check if an iteration is the last for a model chunk.""" microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = num_microbatches // microbatch_group_size + num_microbatch_groups = total_num_microbatches // microbatch_group_size microbatch_group_id = microbatch_id // microbatch_group_size microbatch_id_in_group = microbatch_id % microbatch_group_size if microbatch_group_id == num_microbatch_groups - 1: From 28802670f928e9b77f6454c9348487fd616d6297 Mon Sep 17 00:00:00 2001 From: Jon Barker Date: Thu, 8 Jun 2023 14:33:44 -0700 Subject: [PATCH 11/15] Add workarounds for non-determinism in Megatron training --- README.md | 13 ++++++++++++- megatron/arguments.py | 8 +++++--- megatron/model/language_model.py | 21 +++++++++++++++++---- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 6bb334e8e15..cdb5bd3f074 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization * [Datasets](#datasets) * [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data) * [Collecting GPT Webtext Data](#collecting-gpt-webtext-data) + * [Reproducibility](#reproducibility) # Setup We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) with DGX nodes. If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks. @@ -365,7 +366,7 @@ See [megatron/text_generation_server.py](megatron/text_generation_server.py) for ### Detoxify GPT via Self-generation We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models. -See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. +See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. ## GPT Evaluation @@ -513,3 +514,13 @@ We recommend using the `--json` argument when using WikiExtractor, which will du ## Collecting GPT Webtext Data We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content. + +# Reproducibility +Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary). + +There are currently three known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. They are only applicable when using NGC containers >=22.05. The following workarounds should be applied in cases where reproducibility is required: +1. When training using the `--bf16` option the backward pass of `torch.nn.functional.embedding` is non-deterministic. If reproducibility is required you should also use the option `--embedding-weights-in-fp32`. The speed and memory impact of this change is negligible. +2. Also when training using `--bf16`, reproducbility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used. +3. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`. + +These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue. diff --git a/megatron/arguments.py b/megatron/arguments.py index e6cc4a60194..9eda475ca67 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -49,7 +49,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): # Args from environment args.rank = int(os.getenv('RANK', '0')) args.world_size = int(os.getenv("WORLD_SIZE", '1')) - + return args def validate_args(args, defaults={}): @@ -553,6 +553,8 @@ def _add_network_size_args(parser): help='Number of Experts in Switch Transformer (None means no Switch)') group.add_argument('--untie-embeddings-and-output-weights', action='store_true', help='Untie embeddings and output weights.'), + group.add_argument('--embedding-weights-in-fp32', action='store_true', + help='Cast word embedding weights to fp32 before embedding fwd.'), return parser @@ -1193,14 +1195,14 @@ def _add_vision_args(parser): group.add_argument('--swin-backbone-type', type=str, default='tiny', choices=['tiny', 'base', 'h3'], help='pretraining objectives') - + # inpainting arguments group.add_argument('--mask-type', type=str, default='random', choices=['random', 'row'], help='mask types') group.add_argument('--mask-factor', type=float, default=1.0, help='mask size scaling parameter') - + # dino arguments group.add_argument('--iter-per-epoch', type=int, default=1250, help='iterations per epoch') diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 61f2501bcb2..353f6e00206 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -131,6 +131,10 @@ class Embedding(MegatronModule): init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value will ignore this embedding + embedding_weights_in_fp32: casts word embedding weights to + fp32 before sampling. Required to + maintain reproducibility when + training in bf16. """ def __init__(self, @@ -139,7 +143,8 @@ def __init__(self, max_sequence_length, embedding_dropout_prob, init_method, - num_tokentypes=0): + num_tokentypes=0, + embedding_weights_in_fp32=False): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -149,12 +154,14 @@ def __init__(self, args = get_args() # Word embeddings (parallel). + self.embedding_weights_in_fp32 = embedding_weights_in_fp32 + self.params_dtype = args.params_dtype self.word_embeddings = tensor_parallel.VocabParallelEmbedding( vocab_size, self.hidden_size, init_method=self.init_method, params_dtype=args.params_dtype, use_cpu_initialization=args.use_cpu_initialization, - perform_initialization=args.perform_initialization + perform_initialization=args.perform_initialization, ) self._word_embeddings_key = 'word_embeddings' @@ -182,7 +189,7 @@ def __init__(self, else: self.tokentype_embeddings = None - self.fp32_residual_connection = args.fp32_residual_connection + self.fp32_residual_connection = args.fp32_residual_connection self.sequence_parallel = args.sequence_parallel # Embeddings dropout self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) @@ -217,7 +224,12 @@ def add_tokentype_embeddings(self, num_tokentypes): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. + if self.embedding_weights_in_fp32: + self.word_embeddings = self.word_embeddings.to(torch.float32) words_embeddings = self.word_embeddings(input_ids) + if self.embedding_weights_in_fp32: + words_embeddings = words_embeddings.to(self.params_dtype) + self.word_embeddings = self.word_embeddings.to(self.params_dtype) if self.add_position_embedding: position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings @@ -362,7 +374,8 @@ def __init__(self, args.max_position_embeddings, args.hidden_dropout, self.init_method, - self.num_tokentypes) + self.num_tokentypes, + args.embedding_weights_in_fp32) self._embedding_key = 'embedding' # Rotary positional embeddings From 1af380d7b7726910782cca1adc708ed962ae881b Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 9 Jun 2023 17:21:15 -0700 Subject: [PATCH 12/15] Update gitlab to catch pytest errors --- .gitlab-ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3cd1c2f2e69..0c0bc711f0a 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -69,7 +69,8 @@ unit_tests: - echo "Slurm job state $SLURM_STATE" - if [[ "$SLURM_STATE" != "COMPLETED" ]]; then echo "Slurm job did not complete. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs. Skipping pytest."; exit 1; fi - source $PYTHON_VIRTUAL_ENV - - pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py || echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs." + - cmd='pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_resume_checkpoint_pipeline.py' + - if $cmd; then echo "Pytest succeded"; else echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs"; fi - echo "Completed the job" rules: - if: $TEST_LEVEL =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TESTS_TO_RUN_ON_THIS_COMMIT || $CI_JOB_NAME =~ $TEST_REGEX_ON_THIS_COMMIT @@ -134,7 +135,8 @@ unit_tests: if [[ $USE_TE -ne 1 ]]; then echo "Checking against ground truth file" export EXPECTED_METRICS_FILE=$BUILD_DIR/tests/functional_tests/test_results/$RUN_MODEL/$RUN_NAME.json - pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_ci_pipeline.py || echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs." + cmd='pytest $BUILD_DIR/tests/functional_tests/python_test_utils/test_ci_pipeline.py' + if $cmd; then echo "Pytest succeded"; else echo "Pytest failed. See ${SELENE_ADLR_CI_PATH}/${CI_PIPELINE_ID}/${RUN_NAME}/results directory for result logs"; fi fi - echo "Completed the job" rules: From bf5206e06608d4457bf2d0d111ac7910aa22b774 Mon Sep 17 00:00:00 2001 From: Jon Barker Date: Mon, 12 Jun 2023 11:08:30 -0700 Subject: [PATCH 13/15] Remove use of deprecated np.float in indexed_dataset.py --- megatron/data/indexed_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 4286e69b45c..ebe3fab81aa 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -95,9 +95,9 @@ def write_longs(f, a): 3: np.int16, 4: np.int32, 5: np.int64, - 6: np.float32, - 7: np.double, - 8: np.uint16 + 6: np.float64, + 7: np.float32, + 8: np.uint16, } @@ -268,8 +268,8 @@ class IndexedDatasetBuilder(object): np.int16: 2, np.int32: 4, np.int64: 8, - np.float: 4, - np.double: 8 + np.float32: 4, + np.float64: 8, } def __init__(self, out_file, dtype=np.int32): From f479999f56b6a5bdd5ff8783ae1ba22d0dcfda6a Mon Sep 17 00:00:00 2001 From: Lawrence McAfee Date: Tue, 13 Jun 2023 10:19:41 -0700 Subject: [PATCH 14/15] Retro fix for tensor parallelism. --- megatron/data/gpt_dataset.py | 46 +++++++++++++++------------ megatron/training.py | 51 ++++++++++++++++++------------ pretrain_gpt.py | 2 +- tools/retro/main.py | 29 +++++++++++++++-- tools/retro/query/chunk_dataset.py | 50 ++++++++++++++--------------- tools/retro/query/retro_dataset.py | 8 ++--- tools/retro/query/utils.py | 7 ++++ 7 files changed, 119 insertions(+), 74 deletions(-) diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index b0cf4df57e3..2662b5f80ac 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -90,12 +90,14 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, # Single dataset. if train_data_prefix is not None: train_dataset = build_dataset("train", train_data_prefix, data_impl, + splits_string, train_valid_test_num_samples[0], seq_length, seed, skip_warmup, data_cache_path=data_cache_path) if valid_data_prefix is not None: valid_dataset = build_dataset("valid", valid_data_prefix, data_impl, + splits_string, train_valid_test_num_samples[1], seq_length, seed, False, data_cache_path=data_cache_path) @@ -103,6 +105,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_data_prefix is not None: test_dataset = build_dataset("test", test_data_prefix, data_impl, + splits_string, train_valid_test_num_samples[2], seq_length, seed, False, data_cache_path=data_cache_path) @@ -142,8 +145,8 @@ def build_dataset(index, name): if splits[index + 1] > splits[index]: documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) - dataset = GPTDataset(name, data_prefix, - documents, indexed_dataset, + dataset = GPTDataset(name, data_prefix, documents, indexed_dataset, + splits_string, train_valid_test_num_samples[index], seq_length, seed, return_doc_ids, @@ -157,14 +160,15 @@ def build_dataset(index, name): return (train_dataset, valid_dataset, test_dataset) -def build_dataset(dataset_name, data_prefix, data_impl, num_samples, - seq_length, seed, skip_warmup, *, +def build_dataset(dataset_name, data_prefix, data_impl, + splits_string, num_samples, + seq_length, seed, skip_warmup, + *, data_cache_path=None): dataset = None if len(data_prefix) == 1: - dataset = _build_dataset(dataset_name, - data_prefix[0], data_impl, - num_samples, seq_length, + dataset = _build_dataset(dataset_name, data_prefix[0], data_impl, + splits_string, num_samples, seq_length, seed, skip_warmup, data_cache_path=data_cache_path) else: @@ -177,8 +181,8 @@ def build_dataset(dataset_name, data_prefix, data_impl, num_samples, # Build individual datasets. datasets = [] for i in range(len(prefixes)): - ds = _build_dataset(dataset_name, prefixes[i], - data_impl, dataset_num_samples[i], + ds = _build_dataset(dataset_name, prefixes[i], data_impl, + splits_string, dataset_num_samples[i], seq_length, seed, skip_warmup, data_cache_path=data_cache_path) if ds: @@ -191,8 +195,9 @@ def build_dataset(dataset_name, data_prefix, data_impl, num_samples, return dataset -def _build_dataset(dataset_name, data_prefix, data_impl, - num_samples, seq_length, seed, skip_warmup, *, +def _build_dataset(dataset_name, data_prefix, data_impl, splits_string, + num_samples, seq_length, seed, skip_warmup, + *, data_cache_path=None): """ Build dataset. This method is called when individual @@ -213,9 +218,8 @@ def _build_dataset(dataset_name, data_prefix, data_impl, documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) - dataset = GPTDataset(dataset_name, data_prefix, - documents, indexed_dataset, - num_samples, seq_length, seed, + dataset = GPTDataset(dataset_name, data_prefix, documents, indexed_dataset, + splits_string, num_samples, seq_length, seed, data_cache_path=data_cache_path) return dataset @@ -239,8 +243,8 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): class GPTDataset(torch.utils.data.Dataset): - def __init__(self, name, data_prefix, documents, - indexed_dataset, num_samples, seq_length, seed, + def __init__(self, name, data_prefix, documents, indexed_dataset, + splits_string, num_samples, seq_length, seed, return_doc_ids=False, *, data_cache_path=None): @@ -253,10 +257,10 @@ def __init__(self, name, data_prefix, documents, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. - self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc = \ + self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \ _build_index_mappings(self.name, data_prefix, documents, self.indexed_dataset.sizes, - num_samples, seq_length, seed, + splits_string, num_samples, seq_length, seed, data_cache_path=data_cache_path) @@ -304,7 +308,8 @@ def __getitem__(self, idx): def _build_index_mappings(name, data_prefix, documents, sizes, - num_samples, seq_length, seed, *, + splits_string, num_samples, seq_length, seed, + *, data_cache_path): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. @@ -326,6 +331,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes, desc += f"Number of samples {num_samples}\n" desc += f"Sequence length {seq_length}\n" desc += f"Random seed {seed}\n" + desc += f"Split {splits_string}\n" desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest() desc_filename = desc_hash + ".dsc" doc_idx_filename = desc_hash + '_doc_idx.npy' @@ -473,7 +479,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes, sample_idx.shape[0])) print_rank_0(' total number of epochs: {}'.format(num_epochs)) - return doc_idx, sample_idx, shuffle_idx, desc + return doc_idx, sample_idx, shuffle_idx, desc, desc_hash def _num_tokens(documents, sizes): diff --git a/megatron/training.py b/megatron/training.py index 9a5190b4a71..1fdb668cee5 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -888,9 +888,35 @@ def cyclic_iter(iter): yield x +def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): + """Build pretraining datasets.""" + + args = get_args() + + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * \ + args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size] + print_rank_0(' > datasets target sizes (minimum size):') + print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) + print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) + print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) + + # Build the datasets. + return build_train_valid_test_datasets_provider(train_val_test_num_samples) + + def build_train_valid_test_data_loaders( build_train_valid_test_datasets_provider): - """XXX""" + """Build pretraining data loaders.""" + args = get_args() (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) @@ -910,25 +936,9 @@ def build_train_valid_test_data_loaders( # Data loader only on rank 0 of each model parallel group. if mpu.get_tensor_model_parallel_rank() == 0: - # Number of train/valid/test samples. - if args.train_samples: - train_samples = args.train_samples - else: - train_samples = args.train_iters * args.global_batch_size - eval_iters = (args.train_iters // args.eval_interval + 1) * \ - args.eval_iters - test_iters = args.eval_iters - train_val_test_num_samples = [train_samples, - eval_iters * args.global_batch_size, - test_iters * args.global_batch_size] - print_rank_0(' > datasets target sizes (minimum size):') - print_rank_0(' train: {}'.format(train_val_test_num_samples[0])) - print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) - print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) - - # Build the datasets. - train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider( - train_val_test_num_samples) + # Build datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + build_train_valid_test_datasets_provider) # Build dataloders. train_dataloader = build_pretraining_data_loader( @@ -960,6 +970,7 @@ def build_train_valid_test_data_loaders( def build_train_valid_test_data_iterators( build_train_valid_test_datasets_provider): + """Build pretraining data iterators.""" args = get_args() diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 48cd7eedaf9..9792009da14 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Pretrain GPT""" diff --git a/tools/retro/main.py b/tools/retro/main.py index 3cebdc8ab72..f7850087c81 100644 --- a/tools/retro/main.py +++ b/tools/retro/main.py @@ -55,15 +55,40 @@ def add_retro_args(parser): "a separate file.") # GPT args. + group.add_argument('--retro-gpt-seed', type=int, default=1234, + help='Random seed used for python, numpy, ' + 'pytorch, and cuda.') + group.add_argument('--retro-gpt-data-impl', type=str, default='infer', + choices=['lazy', 'cached', 'mmap', 'infer'], + help='Implementation of indexed datasets.') + group.add_argument('--retro-gpt-data-path', nargs='*', required=True, + help='Path to the training dataset. Accepted format:' + '1) a single data path, 2) multiple datasets in the' + 'form: dataset1-weight dataset1-path dataset2-weight ' + 'dataset2-path ... It is used with --split when a ' + 'single dataset used for all three: train, valid ' + 'and test. It is exclusive to the other ' + '--*-data-path args') + group.add_argument('--retro-gpt-split', type=str, default='969,30,1', + help='Comma-separated list of proportions for training,' + ' validation, and test split. For example the split ' + '`90,5,5` will use 90%% of data for training, 5%% for ' + 'validation and 5%% for test.') + group.add_argument('--retro-gpt-mmap-warmup', action='store_true', + help='Warm up mmap files.') + group.add_argument("--retro-gpt-eval-interval", type=int, required=True, + help="GPT evaluation interval.") + group.add_argument("--retro-gpt-eval-iters", type=int, required=True, + help="GPT evaluation iterations.") group.add_argument("--retro-gpt-tokenizer-type", required=True, help="GPT tokenizer type.") group.add_argument("--retro-gpt-vocab-file", help="GPT vocab file.") group.add_argument("--retro-gpt-merge-file", help="GPT merge file.") group.add_argument("--retro-gpt-tokenizer-model", help="GPT tokenizer model file.") - group.add_argument("--retro-gpt-seq-length", type=int, default=2048, + group.add_argument("--retro-gpt-seq-length", type=int, required=True, help="GPT sequence length.") - group.add_argument("--retro-gpt-global-batch-size", type=int, default=2048, + group.add_argument("--retro-gpt-global-batch-size", type=int, required=True, help="GPT global batch size.") group.add_argument("--retro-gpt-chunk-length", type=int, default=64, help="GPT chunk length.") diff --git a/tools/retro/query/chunk_dataset.py b/tools/retro/query/chunk_dataset.py index f9cc4d51205..841788fe804 100644 --- a/tools/retro/query/chunk_dataset.py +++ b/tools/retro/query/chunk_dataset.py @@ -4,15 +4,16 @@ import torch from megatron import get_retro_args, print_rank_0 -from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.data.gpt_dataset import build_train_valid_test_datasets \ + as build_gpt_train_valid_test_datasets from megatron.training import ( - build_train_valid_test_data_loaders, + build_train_valid_test_datasets as build_pretraining_train_valid_test_datasets, update_train_iters, ) from tools.retro.db.utils import get_indexed_dataset_infos from tools.retro.utils import get_num_chunks_per_sample -from .utils import get_query_workdir +from .utils import get_neighbor_dirname, get_query_workdir class ChunkDataset(torch.utils.data.Dataset): @@ -86,14 +87,14 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): print_rank_0('> building train, validation, and test datasets ' 'for GPT ...') - train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - data_prefix=args.data_path, - data_impl=args.data_impl, - splits_string=args.split, + train_ds, valid_ds, test_ds = build_gpt_train_valid_test_datasets( + data_prefix=args.retro_gpt_data_path, + data_impl=args.retro_gpt_data_impl, + splits_string=args.retro_gpt_split, train_valid_test_num_samples=train_val_test_num_samples, seq_length=args.retro_gpt_seq_length, - seed=args.seed, - skip_warmup=(not args.mmap_warmup), + seed=args.retro_gpt_seed, + skip_warmup=(not args.retro_gpt_mmap_warmup), return_doc_ids=args.retro_return_doc_ids) print_rank_0("> finished creating pretrained GPT datasets ...") @@ -115,28 +116,23 @@ def get_chunk_dataset_map(): verify_indexed_dataset_order() # Datasets. - print_rank_0(" > data loader.") - train_data_loader, valid_data_loader, test_data_loader \ - = build_train_valid_test_data_loaders( - train_valid_test_datasets_provider) - - data_loader_map = { - "train" : train_data_loader, - "valid" : valid_data_loader, - "test" : test_data_loader, + print_rank_0(" > datasets.") + train_ds, valid_ds, test_ds = build_pretraining_train_valid_test_datasets( + train_valid_test_datasets_provider) + + sample_dataset_map = { + "train" : train_ds, + "valid" : valid_ds, + "test" : test_ds, } # Info dict. - workdir = get_query_workdir() - dataset_map = { + chunk_dataset_map = { key : { - "neighbor_dir" : os.path.join( - workdir, - os.path.basename(loader.dataset.datasets[0].index_prefix), - ), - "data" : ChunkDataset(loader.dataset, args.retro_gpt_chunk_length), + "neighbor_dir" : get_neighbor_dirname(key, sample_ds), + "data" : ChunkDataset(sample_ds, args.retro_gpt_chunk_length), } - for key, loader in data_loader_map.items() if loader + for key, sample_ds in sample_dataset_map.items() if sample_ds } - return dataset_map + return chunk_dataset_map diff --git a/tools/retro/query/retro_dataset.py b/tools/retro/query/retro_dataset.py index e89a47007a4..0879d5d5fca 100644 --- a/tools/retro/query/retro_dataset.py +++ b/tools/retro/query/retro_dataset.py @@ -10,6 +10,7 @@ from tools.retro.external_libs import h5py from .chunk_dataset import get_chunk_dataset_map +from .utils import get_neighbor_dirname class RetroDataset(torch.utils.data.Dataset): @@ -120,11 +121,10 @@ def get_retro_datasets(verify_sizes=True): retro_args.retro_block_size) # Verify dataset prefixes. - sample_prefix = chunk_dataset.sample_dataset.datasets[0].index_prefix - neighbor_prefix = os.path.basename(neighbor_dir) - assert sample_prefix == neighbor_prefix, \ + expected_dir = get_neighbor_dirname(data_key, chunk_dataset.sample_dataset) + assert expected_dir == neighbor_dir, \ "inconsistent dataset source; '%s' vs. '%s'." % \ - (sample_prefix, neighbor_prefix) + (expected_dir, neighbor_dir) # Verify num chunks. n_sample_chunks = len(chunk_dataset) diff --git a/tools/retro/query/utils.py b/tools/retro/query/utils.py index a4ea2a5ca17..f6557abf1fa 100644 --- a/tools/retro/query/utils.py +++ b/tools/retro/query/utils.py @@ -1,5 +1,6 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import hashlib import os from megatron import get_retro_args @@ -8,3 +9,9 @@ def get_query_workdir(): args = get_retro_args() return os.path.join(args.retro_workdir, "query") + + +def get_neighbor_dirname(key, dataset): + hashes = ",".join([ d.desc_hash for d in dataset.datasets ]) + hash = hashlib.md5(hashes.encode()).hexdigest() + return os.path.join(get_query_workdir(), os.path.basename(f"{key}_{hash}")) From 8b6ceeb55b8acf27ccbf6bd8643d747928a5bebd Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Tue, 27 Jun 2023 17:42:31 +0000 Subject: [PATCH 15/15] pass on data_cache_pass in build_dataset_group --- megatron/data/gpt_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 75535208676..4a572a95518 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -126,7 +126,8 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, data_impl, train_valid_test_num_samples, seq_length, seed, skip_warmup, - dataset_group_name, train_valid_test) + dataset_group_name, train_valid_test, + data_cache_path=data_cache_path) return dataset # Blending dataset. else: @@ -150,7 +151,8 @@ def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, datasets_train_valid_test_num_samples[i], seq_length, seed, skip_warmup, - dataset_group_name, train_valid_test) + dataset_group_name, train_valid_test, + data_cache_path=data_cache_path) # ds can be none if the dataset is so small that not a single document # is present in the split.