match nemo 1's default behavior for drop_last and pad_samples_to_global_batch_size (NVIDIA#9707) (NVIDIA#9753)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Anna Shors <71393111+ashors1@users.noreply.github.com>
Co-authored-by: Marc Romeyn <mromeijn@nvidia.com>
Signed-off-by: Malay Nagda <malayn@malayn-mlt.client.nvidia.com>
3 people authored and Malay Nagda committed Jul 26, 2024
1 parent ab52c8d commit c67e077
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions nemo/lightning/data.py
@@ -85,10 +85,13 @@ def add_megatron_sampler(
     rampup_batch_size: Optional[List[int]] = None,
     consumed_samples: int = 0,
     dataloader_type: Literal["single", "cyclic"] = "single",
+    drop_last: bool = True,
+    pad_samples_to_global_batch_size: bool = False,
     # data_sharding: bool = False
 ) -> DataLoader:
     from megatron.core import parallel_state
 
+    ## TODO: expose drop_last and pad_samples_to_global_batch_size args
     if dataloader_type == 'single':
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataloader.dataset),
@@ -98,8 +101,8 @@ def add_megatron_sampler(
             rampup_batch_size=rampup_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            drop_last=getattr(dataloader, "_drop_last", False),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
+            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
         )
     elif dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
@@ -108,7 +111,7 @@ def add_megatron_sampler(
             micro_batch_size=micro_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
             # data_sharding=data_sharding
         )
     else:
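
For reference, a minimal usage sketch of the updated call, assuming `add_megatron_sampler` is imported from `nemo.lightning.data` and that Megatron-Core's `parallel_state` has already been initialized. The leading parameters (`dataloader`, `micro_batch_size`, `global_batch_size`) are not shown in this diff and are inferred from the surrounding file; the dataset and batch sizes below are placeholders.

```python
# Hypothetical sketch, not taken from this commit: shows the newly exposed arguments.
# Assumes megatron.core.parallel_state has been initialized before this call.
from torch.utils.data import DataLoader

from nemo.lightning.data import add_megatron_sampler

loader = DataLoader(list(range(1024)), num_workers=0)  # placeholder dataset

# The new defaults (drop_last=True, pad_samples_to_global_batch_size=False) match NeMo 1.
# To keep the final partial batch and pad it instead, pass the flags explicitly:
loader = add_megatron_sampler(
    loader,
    micro_batch_size=2,
    global_batch_size=16,
    dataloader_type="single",
    drop_last=False,
    pad_samples_to_global_batch_size=True,
)
```

Note the behavioral change for existing callers: the old getattr-based lookup fell back to drop_last=False when the private `_drop_last` attribute was absent, whereas the new signature defaults to drop_last=True, so code that relied on keeping a partial final batch now has to request it explicitly.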
