Handle the new check for batch samplers to enable predict_step
Signed-off-by: John St John <jstjohn@nvidia.com>
jstjohn committed Aug 16, 2024
1 parent 650157b · commit 88d9df9
Showing 1 changed file with 18 additions and 1 deletion.
nemo/lightning/megatron_parallel.py (19 changes: 18 additions & 1 deletion)
@@ -231,7 +231,24 @@ def forward(

         pipeline = self.pipeline

-        use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
+        # FIXME: cleanup the following code block which is here for backwards compatibility with nemo1. The "batch"
+        # sampler is a nemo1 sampler. It requires some custom code here to use (if use_global_batch_sampler).
+        # by default we shouldn't use this "batch" sampler probably.
+        if getattr(self.trainer, "datamodule", None) is not None:
+            use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
+        elif getattr(self.trainer, "predict_dataloaders", None) is not None:
+            from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (  # noqa: I001
+                MegatronPretrainingBatchSampler,
+            )
+
+            # The batch_sampler gets injected into the dataloader by the data_sampler. When doing predict without a
+            # datamodule we can look inside the dataloader's batch_sampler to see if it is the nemo1 style sampler
+            # that we need to handle specially below.
+            use_global_batch_sampler = isinstance(
+                self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler
+            )
+        else:
+            raise ValueError("Unsure how to check for nemo1 global_batch_sampler status. TODO maybe default to False?")
         if use_global_batch_sampler:
             from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split

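Usage note: the elif branch above covers calling trainer.predict() with a bare dataloader instead of a datamodule. Below is a minimal, hypothetical sketch of that path; the dataset, model, and trainer objects are placeholders, and the MegatronPretrainingBatchSampler constructor arguments are assumptions (single data-parallel rank, small batch sizes), not taken from this commit.

from torch.utils.data import DataLoader

from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
    MegatronPretrainingBatchSampler,
)

# Placeholder dataset; in practice this comes from the user's setup.
batch_sampler = MegatronPretrainingBatchSampler(
    total_samples=len(my_dataset),
    consumed_samples=0,
    micro_batch_size=2,
    global_batch_size=2,
    data_parallel_rank=0,
    data_parallel_size=1,
    drop_last=False,
)
predict_dl = DataLoader(my_dataset, batch_sampler=batch_sampler)

# No datamodule is attached here, so forward() falls through to the
# predict_dataloaders isinstance() check added in this commit and sets
# use_global_batch_sampler = True.
predictions = trainer.predict(my_model, dataloaders=predict_dl)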
