Commit: Optionally disable logging in the data sampler to support predict_step (NVIDIA#10127)
* Resolve merge conflicts with consumed sample logging

Signed-off-by: John St John <jstjohn@nvidia.com>

* Add test file that captures the predict step error

Signed-off-by: John St John <jstjohn@nvidia.com>

* Add fixme comment around proper checkpoint nemo2 handling

Signed-off-by: John St John <jstjohn@nvidia.com>

* Skip megatron training test on CPU nodes

Signed-off-by: John St John <jstjohn@nvidia.com>

* Move output_log to last arg for compatibility

Signed-off-by: John St John <jstjohn@nvidia.com>

* Try setting the default root dir in predict to avoid writing artifacts to the cwd

Signed-off-by: John St John <jstjohn@nvidia.com>

* Handle the new check for batch samplers to enable predict_step

Signed-off-by: John St John <jstjohn@nvidia.com>

* Only reset the global microbatch, not entire parallel state

Signed-off-by: John St John <jstjohn@nvidia.com>

* Destroy the right sets of state in test of lightning trainer

Signed-off-by: John St John <jstjohn@nvidia.com>

* Fix typo and rename state resetting functions

Signed-off-by: John St John <jstjohn@nvidia.com>

* Run the test in a subprocess to avoid contaminating global state (see the sketch after the sign-offs below)

Signed-off-by: John St John <jstjohn@nvidia.com>

---------

Signed-off-by: John St John <jstjohn@nvidia.com>
Signed-off-by: adityavavre <aditya.vavre@gmail.com>
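
The subprocess item above reflects a common pattern when testing Megatron-based code: megatron.core keeps parallel state in module-level globals, so a crashed or partially torn-down test can poison later ones. A minimal, hypothetical sketch of the pattern follows (this is not the commit's actual test; all names are illustrative):

import multiprocessing


def _run_training():
    # Import heavyweight dependencies inside the child so all global state,
    # including megatron.core parallel state, stays process-local.
    # ... set up a small model/trainer and run fit or predict here ...
    pass


def test_training_in_isolated_process():
    # "spawn" starts a fresh interpreter instead of inheriting parent state.
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=_run_training)
    proc.start()
    proc.join(timeout=600)
    assert proc.exitcode == 0, "training subprocess failed"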
jstjohn authored and adityavavre committed Sep 15, 2024
1 parent 78b38ef commit b1f5c17
Showing 3 changed files with 634 additions and 13 deletions.
19 changes: 18 additions & 1 deletion nemo/lightning/megatron_parallel.py
@@ -231,7 +231,24 @@ def forward(
 
         pipeline = self.pipeline
 
-        use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
+        # FIXME: clean up the following code block, which is here for backwards compatibility with NeMo 1. The
+        # "batch" sampler is a NeMo 1 sampler that requires custom handling below (if use_global_batch_sampler);
+        # by default this "batch" sampler probably should not be used.
+        if getattr(self.trainer, "datamodule", None) is not None:
+            use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
+        elif getattr(self.trainer, "predict_dataloaders", None) is not None:
+            from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (  # noqa: I001
+                MegatronPretrainingBatchSampler,
+            )
+
+            # The batch_sampler gets injected into the dataloader by the data_sampler. When doing predict without a
+            # datamodule, we can look inside the dataloader's batch_sampler to see if it is the NeMo 1 style sampler
+            # that we need to handle specially below.
+            use_global_batch_sampler = isinstance(
+                self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler
+            )
+        else:
+            raise ValueError("Unsure how to check for nemo1 global_batch_sampler status. TODO maybe default to False?")
         if use_global_batch_sampler:
             from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
 
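For context, the new elif branch fires when predict is called with a bare dataloader instead of a datamodule, so trainer.datamodule is None and the dataloader's injected batch_sampler is inspected instead. A hedged sketch of a call that would take that path (the dataset, model, and trainer are placeholders, and the sampler's argument names follow NeMo's MegatronPretrainingBatchSampler but should be treated as assumptions):

import torch
from torch.utils.data import DataLoader, Dataset

from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
    MegatronPretrainingBatchSampler,
)


class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 32

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.zeros(8)


dataset = ToyDataset()
# The NeMo 1 style "batch" sampler that the isinstance check above detects.
batch_sampler = MegatronPretrainingBatchSampler(
    total_samples=len(dataset),
    consumed_samples=0,
    micro_batch_size=2,
    global_batch_size=8,
    data_parallel_rank=0,
    data_parallel_size=1,
    drop_last=True,
)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# trainer.predict(model, dataloaders=dataloader)  # would take the elif branch above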
30 changes: 18 additions & 12 deletions nemo/lightning/pytorch/plugins/data_sampler.py
@@ -26,8 +26,10 @@ def __init__(
         dataloader_type: Literal["single", "cyclic", "batch"] = "single",
         init_consumed_samples: int = 0,
         init_global_step: int = 0,
+        output_log: bool = True,
     ):
         self.seq_len = seq_len
+        self.output_log = output_log
         self.micro_batch_size = micro_batch_size
         self.global_batch_size = global_batch_size
         self.rampup_batch_size = rampup_batch_size
@@ -95,25 +97,29 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         self.prev_global_batch_size = self.current_global_batch_size
 
         consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
-        pl_module.log(
-            'consumed_samples',
-            consumed_samples,
-            prog_bar=True,
-            batch_size=1,
-        )
+        if self.output_log:
+            # You may need to turn off logging, for example when doing trainer.predict(model, data)
+            pl_module.log(
+                'consumed_samples',
+                consumed_samples,
+                prog_bar=True,
+                batch_size=1,
+            )
 
         self.prev_consumed_samples = consumed_samples
 
         update_num_microbatches(
             consumed_samples=consumed_samples,
             consistency_check=False,
         )
-        pl_module.log(
-            "global_batch_size",
-            self.current_global_batch_size,
-            prog_bar=True,
-            batch_size=1,
-        )
+        if self.output_log:
+            # You may need to turn off logging, for example when doing trainer.predict(model, data)
+            pl_module.log(
+                "global_batch_size",
+                self.current_global_batch_size,
+                prog_bar=True,
+                batch_size=1,
+            )
         self.if_first_step = 1
 
     @property
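Taken together with the megatron_parallel.py change, the intended use is to construct the sampler with logging disabled when running inference. A hedged sketch (it assumes MegatronDataSampler is the class defined in this file and is exported from nemo.lightning, and that the model and data module exist elsewhere):

from nemo import lightning as nl

# Sketch: disable consumed-sample logging for trainer.predict(), since
# pl_module.log() raises inside Lightning's predict loop. The batch-size
# values here are placeholders.
data_sampler = nl.MegatronDataSampler(
    seq_len=2048,
    micro_batch_size=2,
    global_batch_size=8,
    output_log=False,  # predict_step cannot call pl_module.log()
)

A NeMo 2 data module typically carries this sampler as its data_sampler attribute; with output_log=False, trainer.predict(model, data) no longer trips over the consumed_samples and global_batch_size logging calls in on_megatron_step_end.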
598 changes: 598 additions & 0 deletions — new test file capturing the predict_step error (diff not loaded)
