
Optionally disable logging in the data sampler to support predict_step #10127

Merged: 11 commits, Aug 21, 2024
nemo/lightning/megatron_parallel.py (18 additions, 1 deletion)

@@ -231,7 +231,24 @@ def forward(

         pipeline = self.pipeline

-        use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
+        # FIXME: clean up the following code block, which is here for backwards compatibility with nemo1. The "batch"
+        # sampler is a nemo1 sampler. It requires some custom code here to use (if use_global_batch_sampler).
+        # By default we probably should not use this "batch" sampler.
+        if getattr(self.trainer, "datamodule", None) is not None:
+            use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
+        elif getattr(self.trainer, "predict_dataloaders", None) is not None:
+            from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (  # noqa: I001
+                MegatronPretrainingBatchSampler,
+            )
+
+            # The batch_sampler gets injected into the dataloader by the data_sampler. When doing predict without a
+            # datamodule, we can look inside the dataloader's batch_sampler to see if it is the nemo1-style sampler
+            # that we need to handle specially below.
+            use_global_batch_sampler = isinstance(
+                self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler
+            )
+        else:
+            raise ValueError("Unsure how to check for nemo1 global_batch_sampler status. TODO maybe default to False?")
         if use_global_batch_sampler:
             from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
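For context on the branch added above, here is a standalone sketch of the same three-way detection, using stand-in objects so it runs without NeMo or Megatron installed. FakeBatchSampler, uses_nemo1_batch_sampler, and the SimpleNamespace trainers are illustrative stand-ins, not part of the PR.

from types import SimpleNamespace


class FakeBatchSampler:
    """Stand-in for MegatronPretrainingBatchSampler, for illustration only."""


def uses_nemo1_batch_sampler(trainer) -> bool:
    # Same branching as the forward() hunk above.
    if getattr(trainer, "datamodule", None) is not None:
        return trainer.datamodule.data_sampler.dataloader_type == "batch"
    if getattr(trainer, "predict_dataloaders", None) is not None:
        return isinstance(trainer.predict_dataloaders.batch_sampler, FakeBatchSampler)
    raise ValueError("Unsure how to check for nemo1 global_batch_sampler status.")


# Fit/validate path: a datamodule is attached, so its data_sampler decides.
fit_trainer = SimpleNamespace(
    datamodule=SimpleNamespace(data_sampler=SimpleNamespace(dataloader_type="batch")),
)
assert uses_nemo1_batch_sampler(fit_trainer)

# Predict path: no datamodule, so the dataloader's batch_sampler is inspected.
predict_trainer = SimpleNamespace(
    datamodule=None,
    predict_dataloaders=SimpleNamespace(batch_sampler=FakeBatchSampler()),
)
assert uses_nemo1_batch_sampler(predict_trainer)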
nemo/lightning/pytorch/plugins/data_sampler.py (18 additions, 12 deletions)

@@ -26,8 +26,10 @@ def __init__(
         dataloader_type: Literal["single", "cyclic", "batch"] = "single",
         init_consumed_samples: int = 0,
         init_global_step: int = 0,
+        output_log: bool = True,
     ):
         self.seq_len = seq_len
+        self.output_log = output_log
         self.micro_batch_size = micro_batch_size
         self.global_batch_size = global_batch_size
         self.rampup_batch_size = rampup_batch_size
@@ -95,25 +97,29 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule
         self.prev_global_batch_size = self.current_global_batch_size

         consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
-        pl_module.log(
-            'consumed_samples',
-            consumed_samples,
-            prog_bar=True,
-            batch_size=1,
-        )
+        if self.output_log:
+            # You may need to turn off logging, for example when doing trainer.predict(model, data)
+            pl_module.log(
+                'consumed_samples',
+                consumed_samples,
+                prog_bar=True,
+                batch_size=1,
+            )

         self.prev_consumed_samples = consumed_samples

         update_num_microbatches(
             consumed_samples=consumed_samples,
             consistency_check=False,
         )
-        pl_module.log(
-            "global_batch_size",
-            self.current_global_batch_size,
-            prog_bar=True,
-            batch_size=1,
-        )
+        if self.output_log:
+            # You may need to turn off logging, for example when doing trainer.predict(model, data)
+            pl_module.log(
+                "global_batch_size",
+                self.current_global_batch_size,
+                prog_bar=True,
+                batch_size=1,
+            )
         self.if_first_step = 1

     @property
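Taken together, the two files let trainer.predict run without a datamodule and without the sampler trying to log. A usage sketch follows, assuming the sampler class in data_sampler.py is MegatronDataSampler (as in NeMo 2); the seq_len and batch-size values are placeholders.

from nemo.lightning.pytorch.plugins.data_sampler import MegatronDataSampler

# output_log=False (the new flag) skips the pl_module.log(...) calls above,
# which matters when running trainer.predict(model, dataloader), where logging
# consumed_samples / global_batch_size is unwanted.
data_sampler = MegatronDataSampler(
    seq_len=2048,          # placeholder value
    micro_batch_size=1,    # placeholder value
    global_batch_size=8,   # placeholder value
    output_log=False,      # new in this PR; defaults to True
)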