fix legacy ds padding bug (NVIDIA#9716)
* fix legacy ds padding bug

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Apply isort and black reformatting

Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>

* avoid code repetition

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix typo

Signed-off-by: dimapihtar <dpihtar@gmail.com>

---------

Signed-off-by: dimapihtar <dpihtar@gmail.com>
Signed-off-by: dimapihtar <dimapihtar@users.noreply.github.com>
Co-authored-by: dimapihtar <dimapihtar@users.noreply.github.com>
Signed-off-by: Malay Nagda <malayn@malayn-mlt.client.nvidia.com>
2 people authored and Malay Nagda committed Jul 26, 2024
1 parent 199ea51 commit d158d37
Showing 2 changed files with 16 additions and 2 deletions.
@@ -100,13 +100,16 @@ def get_start_end_idx(self):
         end_idx = start_idx + self.micro_batch_size
         return start_idx, end_idx
 
+    def _get_padding_indices(self, pad_samples_num):
+        return range(-1, -pad_samples_num - 1, -1)
+
     def __iter__(self):
         batch = []
         # Last batch will be dropped if drop_last is not set False
         indices = range(self.consumed_samples, self.total_samples)
         if (not self.drop_last) and self.pad_samples_to_global_batch_size:
             pad_samples_num = -len(indices) % self.global_batch_size
-            pad_indices = [None] * pad_samples_num
+            pad_indices = self._get_padding_indices(pad_samples_num)
             indices = chain(indices, pad_indices)
 
         for idx in indices:
@@ -125,6 +128,11 @@ def __iter__(self):
             yield batch[start_idx:end_idx]
 
 
+class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
+    def _get_padding_indices(self, pad_samples_num):
+        return [None] * pad_samples_num
+
+
 class MegatronPretrainingRandomSampler(BaseMegatronSampler):
     def __init__(
         self,
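Reading the first changed file (the data samplers module referenced by the import in the second file): the base MegatronPretrainingSampler now pads its index stream with negative indices instead of None, and None padding is kept only in the new MegatronCorePretrainingSampler subclass — presumably because the legacy NeMo datasets cannot be indexed with None, while the Megatron-core datasets treat None as a pad marker. A minimal sketch of the difference, using a toy list as a stand-in for a dataset (the names and data below are illustrative only, not NeMo code):

    # Illustrative sketch; function names and the toy dataset are assumptions.

    def legacy_padding_indices(pad_samples_num):
        # Negative indices -1, -2, ..., -pad_samples_num: valid for anything
        # that supports Python-style negative indexing, so padding reuses the
        # last real samples instead of requiring special handling of None.
        return range(-1, -pad_samples_num - 1, -1)

    def core_padding_indices(pad_samples_num):
        # None placeholders: only safe for datasets whose __getitem__ accepts
        # None and substitutes a dummy/padded sample.
        return [None] * pad_samples_num

    dataset = ["s0", "s1", "s2", "s3", "s4"]

    print([dataset[i] for i in legacy_padding_indices(3)])  # ['s4', 's3', 's2']
    print(list(core_padding_indices(3)))                    # [None, None, None]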
@@ -32,6 +32,7 @@

 from nemo.collections.common.parts.utils import extend_instance
 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
+    MegatronCorePretrainingSampler,
     MegatronPretrainingRandomSampler,
     MegatronPretrainingSampler,
 )
@@ -1605,8 +1606,13 @@ def build_pretraining_data_loader(
         logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
         # Megatron sampler
         if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
+            data_sampler = (
+                MegatronPretrainingSampler
+                if self.cfg.data.get('legacy_dataset', False)
+                else MegatronCorePretrainingSampler
+            )
             if self.cfg.data.dataloader_type == 'single':
-                batch_sampler = MegatronPretrainingSampler(
+                batch_sampler = data_sampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self.cfg.micro_batch_size,
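The second changed file wires the new class into dataloader construction: the cfg.data.legacy_dataset flag (defaulting to False) picks which sampler class gets instantiated. A standalone sketch of just that selection step, with stub classes in place of the real NeMo samplers and a plain dict standing in for self.cfg.data (everything here is illustrative; only the conditional mirrors the diff):

    # Sketch of the sampler selection added above; stub classes and
    # pick_sampler_cls are assumptions for illustration.

    class MegatronPretrainingSampler:  # legacy path: negative-index padding
        pass

    class MegatronCorePretrainingSampler(MegatronPretrainingSampler):  # core path: None padding
        pass

    def pick_sampler_cls(cfg_data):
        # cfg_data stands in for self.cfg.data (an OmegaConf mapping in NeMo).
        if cfg_data.get('legacy_dataset', False):
            return MegatronPretrainingSampler
        return MegatronCorePretrainingSampler

    assert pick_sampler_cls({'legacy_dataset': True}) is MegatronPretrainingSampler
    assert pick_sampler_cls({}) is MegatronCorePretrainingSampler

The real call, as the diff shows, then instantiates the chosen class with total_samples, consumed_samples, micro_batch_size, and the remaining sampler arguments.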
