From 6d94980bd846354d9fe7a2523171f052d3a3949e Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 11 Jul 2024 17:39:11 -0700 Subject: [PATCH] match nemo 1's default behavior for drop_last and pad_samples_to_global_batch_size Signed-off-by: ashors1 --- nemo/lightning/data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index d83f5ba3b728..d28a2d7dfe04 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -85,10 +85,13 @@ def add_megatron_sampler( rampup_batch_size: Optional[List[int]] = None, consumed_samples: int = 0, dataloader_type: Literal["single", "cyclic"] = "single", + drop_last: bool = True, + pad_samples_to_global_batch_size: bool = False, # data_sharding: bool = False ) -> DataLoader: from megatron.core import parallel_state + ## TODO: expose drop_last and pad_samples_to_global_batch_size args if dataloader_type == 'single': batch_sampler = MegatronPretrainingSampler( total_samples=len(dataloader.dataset), @@ -98,8 +101,8 @@ def add_megatron_sampler( rampup_batch_size=rampup_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), - drop_last=getattr(dataloader, "_drop_last", False), - pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False), + drop_last=drop_last, + pad_samples_to_global_batch_size=pad_samples_to_global_batch_size, ) elif dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( @@ -108,7 +111,7 @@ def add_megatron_sampler( micro_batch_size=micro_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), - pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False), + drop_last=drop_last, # data_sharding=data_sharding ) else: