From f20a2c413350d7f8f741c8be51af90d00375b991 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Wed, 17 Jul 2024 09:30:05 -0700
Subject: [PATCH] match nemo 1's default behavior for drop_last and pad_samples_to_global_batch_size (#9707) (#9753)

Signed-off-by: ashors1
Co-authored-by: Anna Shors <71393111+ashors1@users.noreply.github.com>
Co-authored-by: Marc Romeyn
Signed-off-by: Alexandros Koumparoulis
---
 nemo/lightning/data.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py
index 809885e75c79..58ba81a4ddac 100644
--- a/nemo/lightning/data.py
+++ b/nemo/lightning/data.py
@@ -85,10 +85,13 @@ def add_megatron_sampler(
     rampup_batch_size: Optional[List[int]] = None,
     consumed_samples: int = 0,
     dataloader_type: Literal["single", "cyclic"] = "single",
+    drop_last: bool = True,
+    pad_samples_to_global_batch_size: bool = False,
     # data_sharding: bool = False
 ) -> DataLoader:
     from megatron.core import parallel_state
 
+    ## TODO: expose drop_last and pad_samples_to_global_batch_size args
     if dataloader_type == 'single':
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataloader.dataset),
@@ -98,8 +101,8 @@
             rampup_batch_size=rampup_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            drop_last=getattr(dataloader, "_drop_last", False),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
+            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
         )
     elif dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
@@ -108,7 +111,7 @@
             micro_batch_size=micro_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
             # data_sharding=data_sharding
         )
     else:
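
For context, below is a minimal sketch of how a caller might pass the new arguments once this patch is applied. The toy DataLoader, the batch-size values, and the assumption that megatron.core parallel state has already been initialized are illustrative only; the leading parameters of add_megatron_sampler (dataloader, micro_batch_size, global_batch_size) come from the surrounding file and are not part of this diff.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from nemo.lightning.data import add_megatron_sampler

    # Toy map-style dataset standing in for a real pretraining dataset.
    dataset = TensorDataset(torch.arange(1024).unsqueeze(1))
    loader = DataLoader(dataset, batch_size=4)

    # Prerequisite (not shown): megatron.core.parallel_state must be initialized,
    # since add_megatron_sampler reads the data-parallel rank and world size from it.
    loader = add_megatron_sampler(
        loader,
        micro_batch_size=4,
        global_batch_size=32,
        dataloader_type="single",
        drop_last=True,  # new default, matching NeMo 1
        pad_samples_to_global_batch_size=False,  # new default; typically only enabled together with drop_last=False
    )

Defaulting drop_last to True and pad_samples_to_global_batch_size to False reproduces NeMo 1's behavior, instead of silently reading private _drop_last and _pad_samples_to_global_batch_size attributes off the incoming DataLoader as the code did before.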