From 4843b3009921e238688edb88e6ba6616ca141eb1 Mon Sep 17 00:00:00 2001
From: dimapihtar
Date: Fri, 12 Jul 2024 06:16:00 -0700
Subject: [PATCH 1/4] fix legacy ds padding bug

Signed-off-by: dimapihtar
---
 .../megatron/data_samplers.py                 | 26 +++++++++++++++++++
 .../language_modeling/megatron_gpt_model.py   |  4 ++-
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
index 4a8b989a7b6d..02f2fd8059af 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
@@ -100,6 +100,32 @@ def get_start_end_idx(self):
         end_idx = start_idx + self.micro_batch_size
         return start_idx, end_idx
 
+    def __iter__(self):
+        batch = []
+        # Last batch will be dropped if drop_last is not set False
+        indices = range(self.consumed_samples, self.total_samples)
+        if (not self.drop_last) and self.pad_samples_to_global_batch_size:
+            pad_samples_num = -len(indices) % self.global_batch_size
+            pad_indices = range(-1, -pad_samples_num - 1, -1)
+            indices = chain(indices, pad_indices)
+
+        for idx in indices:
+            batch.append(idx)
+            if len(batch) == self.micro_batch_times_data_parallel_size:
+                start_idx, end_idx = self.get_start_end_idx()
+                yield batch[start_idx:end_idx]
+                batch = []
+
+        # Check the last partial batch and see drop_last is set
+        if len(batch) > 0 and not self.drop_last:
+            assert (
+                not self.pad_samples_to_global_batch_size
+            ), 'with pad_samples_to_global_batch_size all batches should be complete'
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]
+
+
+class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
     def __iter__(self):
         batch = []
         # Last batch will be dropped if drop_last is not set False
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 69cd06021f50..992885d376e9 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -33,6 +33,7 @@ from nemo.collections.common.parts.utils import extend_instance
 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
     MegatronPretrainingRandomSampler,
+    MegatronCorePretrainingSampler,
     MegatronPretrainingSampler,
 )
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets
@@ -1605,8 +1606,9 @@ def build_pretraining_data_loader(
         logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
         # Megatron sampler
         if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
+            data_sampler = MegatronPretrainingSampler if self.cfg.data.get('legacy_dataset', False) else MegatronCorePretrainingSampler
             if self.cfg.data.dataloader_type == 'single':
-                batch_sampler = MegatronPretrainingSampler(
+                batch_sampler = data_sampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self.cfg.micro_batch_size,
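For illustration (not part of the patch series), the sketch below contrasts the two padding schemes that patch 1 separates: the base `MegatronPretrainingSampler` now pads the tail of the epoch with negative indices for the legacy dataset path, while the new `MegatronCorePretrainingSampler` keeps the pre-existing `None` padding. The sample counts are made up.

```python
# Minimal sketch of the two padding schemes, with illustrative values.
from itertools import chain

total_samples, consumed_samples, global_batch_size = 10, 0, 4

indices = range(consumed_samples, total_samples)
pad_samples_num = -len(indices) % global_batch_size  # 2 entries needed to fill the last global batch

# Legacy path (MegatronPretrainingSampler): pad with negative indices.
legacy_padded = list(chain(indices, range(-1, -pad_samples_num - 1, -1)))

# Megatron-Core path (MegatronCorePretrainingSampler): pad with explicit None markers.
core_padded = list(chain(indices, [None] * pad_samples_num))

print(legacy_padded)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -2]
print(core_padded)    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, None, None]
```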
From 278fd4cd5621aaec9b0da139d236816a7b13e331 Mon Sep 17 00:00:00 2001
From: dimapihtar
Date: Fri, 12 Jul 2024 13:17:19 +0000
Subject: [PATCH 2/4] Apply isort and black reformatting

Signed-off-by: dimapihtar
---
 .../nlp/models/language_modeling/megatron_gpt_model.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 992885d376e9..e4cab6cec26f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -32,8 +32,8 @@ from nemo.collections.common.parts.utils import extend_instance
 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
-    MegatronPretrainingRandomSampler,
     MegatronCorePretrainingSampler,
+    MegatronPretrainingRandomSampler,
     MegatronPretrainingSampler,
 )
 from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets
@@ -1606,7 +1606,11 @@ def build_pretraining_data_loader(
         logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
         # Megatron sampler
        if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
-            data_sampler = MegatronPretrainingSampler if self.cfg.data.get('legacy_dataset', False) else MegatronCorePretrainingSampler
+            data_sampler = (
+                MegatronPretrainingSampler
+                if self.cfg.data.get('legacy_dataset', False)
+                else MegatronCorePretrainingSampler
+            )
             if self.cfg.data.dataloader_type == 'single':
                 batch_sampler = data_sampler(
                     total_samples=len(dataset),
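As a hypothetical usage sketch (the surrounding NeMo config structure is assumed; only the `data.legacy_dataset` lookup mirrors the diff), this is how the flag consulted in `build_pretraining_data_loader` could be toggled:

```python
# Hypothetical config sketch; only cfg.data.get('legacy_dataset', False) mirrors the diff.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": {"dataloader_type": "single", "legacy_dataset": True}})

# Mirrors the sampler selection added in patch 1 and reformatted in patch 2.
sampler_name = (
    "MegatronPretrainingSampler" if cfg.data.get("legacy_dataset", False) else "MegatronCorePretrainingSampler"
)
print(sampler_name)  # MegatronPretrainingSampler
```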
From ffac48bc42f9414ef1cdd2b0b7ca9f53600e9cbe Mon Sep 17 00:00:00 2001
From: dimapihtar
Date: Fri, 12 Jul 2024 07:26:02 -0700
Subject: [PATCH 3/4] avoid code repetition

Signed-off-by: dimapihtar
---
 .../megatron/data_samplers.py                 | 30 ++++---------
 1 file changed, 6 insertions(+), 24 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
index 02f2fd8059af..4d89fa2a1243 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
@@ -100,13 +100,16 @@ def get_start_end_idx(self):
         end_idx = start_idx + self.micro_batch_size
         return start_idx, end_idx
 
+    def _get_padding_indices(self, pad_samples_num):
+        return range(-1, -pad_samples_num - 1, -1)
+
     def __iter__(self):
         batch = []
         # Last batch will be dropped if drop_last is not set False
         indices = range(self.consumed_samples, self.total_samples)
         if (not self.drop_last) and self.pad_samples_to_global_batch_size:
             pad_samples_num = -len(indices) % self.global_batch_size
-            pad_indices = range(-1, -pad_samples_num - 1, -1)
+            pad_indices = self._get_padding_indices(self, pad_samples_num)
             indices = chain(indices, pad_indices)
 
         for idx in indices:
@@ -126,29 +129,8 @@ def __iter__(self):
 
 
 class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
-    def __iter__(self):
-        batch = []
-        # Last batch will be dropped if drop_last is not set False
-        indices = range(self.consumed_samples, self.total_samples)
-        if (not self.drop_last) and self.pad_samples_to_global_batch_size:
-            pad_samples_num = -len(indices) % self.global_batch_size
-            pad_indices = [None] * pad_samples_num
-            indices = chain(indices, pad_indices)
-
-        for idx in indices:
-            batch.append(idx)
-            if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx, end_idx = self.get_start_end_idx()
-                yield batch[start_idx:end_idx]
-                batch = []
-
-        # Check the last partial batch and see drop_last is set
-        if len(batch) > 0 and not self.drop_last:
-            assert (
-                not self.pad_samples_to_global_batch_size
-            ), 'with pad_samples_to_global_batch_size all batches should be complete'
-            start_idx, end_idx = self.get_start_end_idx()
-            yield batch[start_idx:end_idx]
+    def _get_padding_indices(self, pad_samples_num):
+        return [None] * pad_samples_num
 
 
 class MegatronPretrainingRandomSampler(BaseMegatronSampler):

From b4c54f6405eae287a75d7a37832eed50d5bdafbd Mon Sep 17 00:00:00 2001
From: dimapihtar
Date: Fri, 12 Jul 2024 07:32:15 -0700
Subject: [PATCH 4/4] fix typo

Signed-off-by: dimapihtar
---
 .../nlp/data/language_modeling/megatron/data_samplers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
index 4d89fa2a1243..622e2d759266 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
@@ -109,7 +109,7 @@ def __iter__(self):
         indices = range(self.consumed_samples, self.total_samples)
         if (not self.drop_last) and self.pad_samples_to_global_batch_size:
             pad_samples_num = -len(indices) % self.global_batch_size
-            pad_indices = self._get_padding_indices(self, pad_samples_num)
+            pad_indices = self._get_padding_indices(pad_samples_num)
             indices = chain(indices, pad_indices)
 
         for idx in indices:
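Taken together, patches 3 and 4 leave a single `__iter__` in the base sampler, with the padding behavior supplied by an overridable `_get_padding_indices` hook. Below is a condensed, self-contained sketch of that structure; it is simplified (no data-parallel split or consumed-samples bookkeeping, plain classes instead of the real `BaseMegatronSampler`, illustrative batch sizes) and is not the actual NeMo implementation.

```python
# Condensed sketch of the final sampler structure: one __iter__ in the base
# class, padding behavior customized via _get_padding_indices in the subclass.
from itertools import chain


class _SamplerSketch:
    def __init__(self, total_samples, global_batch_size, micro_batch_size):
        self.total_samples = total_samples
        self.global_batch_size = global_batch_size
        self.micro_batch_size = micro_batch_size

    def _get_padding_indices(self, pad_samples_num):
        # Legacy behavior: negative padding indices.
        return range(-1, -pad_samples_num - 1, -1)

    def __iter__(self):
        pad_samples_num = -self.total_samples % self.global_batch_size
        indices = chain(range(self.total_samples), self._get_padding_indices(pad_samples_num))
        batch = []
        for idx in indices:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                yield batch
                batch = []


class _CoreSamplerSketch(_SamplerSketch):
    def _get_padding_indices(self, pad_samples_num):
        # Megatron-Core behavior: explicit None padding markers.
        return [None] * pad_samples_num


print(list(_SamplerSketch(10, 4, 4)))      # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, -1, -2]]
print(list(_CoreSamplerSketch(10, 4, 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, None, None]]
```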