diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py
index 809885e75c79..58ba81a4ddac 100644
--- a/nemo/lightning/data.py
+++ b/nemo/lightning/data.py
@@ -85,10 +85,13 @@ def add_megatron_sampler(
     rampup_batch_size: Optional[List[int]] = None,
     consumed_samples: int = 0,
     dataloader_type: Literal["single", "cyclic"] = "single",
+    drop_last: bool = True,
+    pad_samples_to_global_batch_size: bool = False,
     # data_sharding: bool = False
 ) -> DataLoader:
     from megatron.core import parallel_state
 
+    ## TODO: expose drop_last and pad_samples_to_global_batch_size args
     if dataloader_type == 'single':
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataloader.dataset),
@@ -98,8 +101,8 @@ def add_megatron_sampler(
             rampup_batch_size=rampup_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            drop_last=getattr(dataloader, "_drop_last", False),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
+            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
         )
     elif dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
@@ -108,7 +111,7 @@ def add_megatron_sampler(
             micro_batch_size=micro_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
             # data_sharding=data_sharding
         )
     else:
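
Below is a minimal, hedged usage sketch (not part of the PR) showing how a caller might pass the new drop_last and pad_samples_to_global_batch_size arguments after this change. The toy dataset, the DataLoader setup, and the global_batch_size keyword are assumptions not visible in the diff hunks, and megatron.core.parallel_state is assumed to be initialized before the call, since add_megatron_sampler queries the data-parallel rank and world size.

import torch
from torch.utils.data import DataLoader, TensorDataset

from nemo.lightning.data import add_megatron_sampler

# Hypothetical toy dataset standing in for a real pretraining dataset.
dataset = TensorDataset(torch.arange(1024))
base_loader = DataLoader(dataset, batch_size=1)

# Assumes megatron.core.parallel_state has already been initialized.
loader = add_megatron_sampler(
    base_loader,
    micro_batch_size=4,
    global_batch_size=32,  # assumed keyword; not shown in the diff hunks
    consumed_samples=0,
    dataloader_type="single",
    drop_last=True,  # new argument from this PR
    pad_samples_to_global_batch_size=False,  # new argument from this PR
)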