From e8304d6ee76dc09e916dee756423020e51063b4c Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Date: Wed, 25 Sep 2024 10:15:30 -0700
Subject: [PATCH] Romeyn/sampler (#10525)

* Introducing MegatronStep to make things more extensible

Signed-off-by: Alexandros Koumparoulis

* Improve megatron callbacks

Signed-off-by: Marc Romeijn
Signed-off-by: Alexandros Koumparoulis

* Some small fixes

Signed-off-by: Marc Romeijn
Signed-off-by: Alexandros Koumparoulis

* Apply isort and black reformatting

Signed-off-by: akoumpa
Signed-off-by: Alexandros Koumparoulis

* remove debug code

Signed-off-by: Alexandros Koumparoulis

* add forward_only to forward backward func

Signed-off-by: Alexandros Koumparoulis

* add global-batch-sampler support to MegatronStep

Signed-off-by: Alexandros Koumparoulis

* Apply isort and black reformatting

Signed-off-by: akoumpa

---------

Signed-off-by: Alexandros Koumparoulis
Signed-off-by: Marc Romeijn
Signed-off-by: akoumpa
Co-authored-by: Marc Romeijn
Co-authored-by: akoumpa
Co-authored-by: sichu

---
 nemo/lightning/megatron_parallel.py           | 715 +++++++++++++-----
 .../lightning/pytorch/plugins/data_sampler.py |  31 +-
 .../pytorch/strategies/megatron_strategy.py   |  29 +-
 3 files changed, 550 insertions(+), 225 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 60f090d6318f..096c7728d4a1 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -18,6 +18,7 @@
 import inspect
 import queue
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
@@ -36,6 +37,7 @@
     runtime_checkable,
 )

+import pytorch_lightning as pl
 import torch
 import torch.distributed
 from megatron.core import parallel_state
@@ -48,6 +50,7 @@

 DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor])
 ModelT = TypeVar("ModelT", bound=nn.Module)
+T = TypeVar('T')


 @runtime_checkable
@@ -207,7 +210,7 @@ def forward(
         data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]],
         forward_only: bool = True,
         data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
-        forward_step: Optional[Callable[[nn.Module, DataT], Tensor]] = None,
+        forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
         loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
@@ -238,97 +241,61 @@ def forward(
         """
         _forward_step = forward_step or self.forward_step
         _loss_reduction = loss_reduction or self.loss_reduction
-        _micro_batch_size: int = micro_batch_size or self.infer_micro_batch_size(data)
-        _seq_length: int = seq_length or self.infer_seq_length(data)
-        _num_microbatches: int = num_microbatches or self.infer_num_microbatches(data)
-
-        pipeline = self.pipeline
-
-        # FIXME: cleanup the following code block which is here for backwards compatibility with nemo1. The "batch"
-        # sampler is a nemo1 sampler. It requires some custom code here to use (if use_global_batch_sampler).
-        # by default we shouldn't use this "batch" sampler probably.
-        if getattr(self.trainer, "datamodule", None) is not None:
-            use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch'
-        elif getattr(self.trainer, "predict_dataloaders", None) is not None:
-            from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (  # noqa: I001
-                MegatronPretrainingBatchSampler,
-            )
-
-            # The batch_sampler gets injected into the dataloader by the data_sampler. When doing predict without a
-            # datamodule we can look inside the dataloader's batch_sampler to see if it is the nemo1 style sampler
-            # that we need to handle specially below.
-            use_global_batch_sampler = isinstance(
-                self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler
-            )
-        else:
-            raise ValueError("Unsure how to check for nemo1 global_batch_sampler status. TODO maybe default to False?")
-        if use_global_batch_sampler:
-            from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
-
-            # The current way of using a batch sampler + split to micro iterator results in
-            # extraneous padding, and is only implemented to ensure bit-exactness with NeMo 1.
-            # This part in NeMo 1 was written when megatron fwd_bwd_function did not support unequal
-            # sequence lengths, but it does now. Hence this part should be revisited in the future.
-            batch = next(data)
-            if isinstance(batch, tuple) and len(batch) == 3:
-                batch = batch[0]
-            data = get_iterator_k_split(batch, _num_microbatches, True)
-
-        data_iterator: List[Iterator[DataT]] = self.to_data_iterator_list(data)
-        context = self._build_context({**locals()})
+        _forward_context = {}

         if wrap_forward_step:
             _data_step = data_step or self.data_step
             forward_step_func = self.wrapped_forward_step(
-                _forward_step,
+                forward_step=_forward_step,
                 data_step=_data_step,
                 loss_reduction=_loss_reduction,
-                context=context,
+                context=_forward_context,
             )
         else:
             forward_step_func = _forward_step

-        self.callbacks.event("on_megatron_step_start", **context)
-        self.callbacks.event("on_megatron_microbatches_start", **context)
-
-        microbatch_outputs = self.forward_backward_func(
-            forward_step_func=forward_step_func,
-            data_iterator=data_iterator,
-            model=pipeline,
+        step = MegatronStep.infer(
+            self,
+            data,
+            forward_step_func,
             forward_only=forward_only,
-            micro_batch_size=_micro_batch_size,
-            seq_length=_seq_length,
-            num_microbatches=_num_microbatches,
+            micro_batch_size=micro_batch_size,
+            num_microbatches=num_microbatches,
+            seq_length=seq_length,
         )
+        _forward_context["step"] = step
+        step = self.callbacks.transform_event("on_megatron_step_start", step)

-        context["microbatch_outputs"] = microbatch_outputs
-
-        self.callbacks.event("on_megatron_microbatches_end", **context)
+        self.callbacks.event("on_megatron_microbatches_start", step=step)
+        microbatch_outputs = step()
+        self.callbacks.event("on_megatron_microbatches_end", step=step, microbatch_outputs=microbatch_outputs)

         if microbatch_outputs:
-            self.callbacks.event("on_megatron_reduce_microbatches_start", **context)
+            self.callbacks.event(
+                "on_megatron_reduce_microbatches_start", step=step, microbatch_outputs=microbatch_outputs
+            )

             if isinstance(_loss_reduction, _ModuleStepFunction):
                 _loss_reduction = _loss_reduction(self[0])

-            loss_mean = _loss_reduction.reduce(microbatch_outputs)
-            context["loss_mean"] = loss_mean
-            self.callbacks.event("on_megatron_reduce_microbatches_end", **context)
+            reduced = _loss_reduction.reduce(microbatch_outputs)
+            self.callbacks.event(
+                "on_megatron_reduce_microbatches_end",
+                step=step,
+                loss_reduction=_loss_reduction,
+                microbatch_outputs=microbatch_outputs,
+                reduced=reduced,
+            )
         else:
             # we're not on the last pipeline stage so no losses
-            loss_mean = torch.tensor(0.0, device=torch.cuda.current_device())
+            reduced = torch.tensor(0.0, device=torch.cuda.current_device())

-        self.callbacks.event("on_megatron_log_step_end", **context)
-        self.callbacks.event("on_megatron_step_end", **context)
+        self.callbacks.event("on_megatron_step_end", step=step, microbatch_outputs=microbatch_outputs, reduced=reduced)

-        return loss_mean
+        return reduced

     def wrapped_forward_step(
-        self,
-        forward_step,
-        loss_reduction,
-        context,
-        data_step,
+        self, forward_step, loss_reduction, data_step, context
     ) -> Callable[[nn.Module, DataT], Tuple[torch.Tensor, "MegatronCallbackProtocol"]]:
         """The method wraps the forward step function and returns a callable.

@@ -355,6 +322,7 @@ def wrapped_forward_step_func(dataloader_iter, model):
             _data_step = data_step

             batch = _data_step(dataloader_iter)
+            step = context["step"]

             if isinstance(loss_reduction, _ModuleStepFunction):
                 forward_callback = loss_reduction(model)
@@ -366,10 +334,12 @@ def wrapped_forward_step_func(dataloader_iter, model):
             else:
                 _forward_step = forward_step

-            _context = {**context, "batch": batch}
-            _context["forward_callback"] = forward_callback
-
-            self.callbacks.event("on_megatron_microbatch_start", **_context)
+            self.callbacks.event(
+                "on_megatron_microbatch_start",
+                step=step,
+                batch=batch,
+                forward_callback=forward_callback,
+            )

             if self.precision_plugin and parallel_state.is_pipeline_first_stage():
                 batch = self.precision_plugin.convert_input(batch)
@@ -388,106 +358,18 @@ def wrapped_forward_step_func(dataloader_iter, model):
             if self.precision_plugin and parallel_state.is_pipeline_last_stage():
                 output_tensor = self.precision_plugin.convert_output(output_tensor)

+            self.callbacks.event(
+                "on_megatron_microbatch_end",
+                step=step,
+                batch=batch,
+                output=output_tensor,
+                forward_callback=forward_callback,
+            )
+
             return output_tensor, forward_callback

         return wrapped_forward_step_func

-    def to_data_iterator_list(
-        self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]
-    ) -> List[Iterator[DataT]]:
-        """
-        Converts the provided data into a list of iterators.
-
-        This method is used to convert the input data into a list of iterators that can be used
-        for data parallelism in the Megatron model. The input data can be a single data item,
-        an iterator, or a list of iterators.
-
-        Args:
-            data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): The input data to be
-                converted into a list of iterators. This can be a single data item, an iterator,
-                or a list of iterators.
-
-        Returns
-        -------
-            List[Iterator[DataT]]: A list of iterators created from the input data.
-        """
-        if isinstance(data, Iterator):
-            return _make_data_iterator_list(self.pipeline, data)
-        elif isinstance(data, list) and all(isinstance(item, Iterator) for item in data):
-            # If data is already a list of iterators, return it as is
-            return cast(List[Iterator[DataT]], data)
-
-        # For a single data item or any other type, wrap it in an iterator and return as a list
-        return cast(List[Iterator[DataT]], [iter([data])])
-
-    def infer_micro_batch_size(self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]) -> int:
-        """
-        Infers the micro batch size from the provided data.
-
-        This method attempts to infer the micro batch size by checking for specific attributes
-        in the data object. If the data object has a `micro_batch_size` attribute, it is returned.
-        If the data object has a `data_config` attribute with a `micro_batch_size` attribute,
-        it is returned. Otherwise, the method attempts to infer the micro batch size from the
-        first dimension of the data tensor, if the data is a tensor. If the data is a dictionary,
-        the method is called recursively on the first value of the dictionary. If the data is a
-        list or tuple with at least one element, the method is called recursively on the first
-        element. If none of these conditions are met, a ValueError is raised.
-
-        Args:
-            data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): The data to infer the
-                micro batch size from.
-
-        Returns
-        -------
-            int: The inferred micro batch size.
-
-        Raises
-        ------
-            ValueError: If the micro batch size cannot be inferred from the data.
-        """
-        if hasattr(data, "micro_batch_size"):
-            return data.micro_batch_size
-        if hasattr(data, "data_config"):
-            return data.data_config.micro_batch_size
-
-        if isinstance(data, Tensor):
-            return data.size(0)
-        elif isinstance(data, dict):
-            return self.infer_micro_batch_size(next(iter(data.values())))
-        elif isinstance(data, (list, tuple)) and len(data) > 0:
-            _tensor: Tensor = data[0]
-            return self.infer_micro_batch_size(_tensor)
-
-        raise ValueError("Cannot infer `micro_batch_size` from data, please specify it manually")
-
-    def infer_seq_length(self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]) -> int:
-        if hasattr(data, "seq_length"):
-            return data.seq_length
-        if hasattr(data, "data_config"):
-            return data.data_config.seq_length
-
-        if isinstance(data, Tensor):
-            # TODO: Check if at least 2 dims
-            return data.size(1)
-        elif isinstance(data, dict):
-            return self.infer_seq_length(next(iter(data.values())))
-        elif isinstance(data, (list, tuple)) and len(data) > 0:
-            _tensor: Tensor = data[0]
-            return self.infer_seq_length(_tensor)
-
-        raise ValueError("Cannot infer `seq_length` from data, please specify it manually")
-
-    def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]) -> int:
-        if hasattr(data, "num_microbatches"):
-            return data.num_microbatches
-        if hasattr(data, "data_config"):
-            return data.data_config.num_microbatches
-
-        if isinstance(data, (dict, tuple, list, Tensor)):
-            return 1
-
-        raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")
-
     def init_model_parallel(self):
         from megatron.core import parallel_state
         from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
@@ -564,27 +446,6 @@ def init_ddp(self):
             module.config.no_sync_func = no_sync_func
             module.config.grad_sync_func = grad_sync_func

-    def _build_context(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        if "self" in context:
-            del context["self"]
-        context["pl_module"] = self
-        if hasattr(self, "trainer"):
-            context["trainer"] = self.trainer
-
-        for val in [
-            "data_step",
-            "forward_step",
-            "loss_reduction",
-            "micro_batch_size",
-            "seq_length",
-            "num_microbatches",
-        ]:
-            if "_" + val in context:
-                context[val] = context["_" + val]
-                del context["_" + val]
-
-        return context
-
     def _setup_module(self, function, **kwargs) -> None:
         if hasattr(function, "setup"):
             setup_args = inspect.getfullargspec(function.setup).args
@@ -646,12 +507,6 @@ def pipeline(self) -> Union[ModelT, List[ModelT]]:
     def module(self) -> ModelT:
         return self[0]

-    @property
-    def forward_backward_func(self) -> "MegatronStepProtocol":
-        from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
-
-        return get_forward_backward_func()
-
     @override
     def __getattr__(self, item: Any) -> Any:
         try:
@@ -860,6 +715,39 @@ def event(self, name: str, *args, **kwargs) -> None:
                 filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
                 callback_method(*filtered_args, **filtered_kwargs)

+    def transform_event(self, name: str, obj: T, **kwargs) -> T:
+        """
+        Triggers an event that allows callbacks to transform and return an object.
+
+        This method applies a series of potential transformations to the input object
+        by calling registered callbacks. Each callback has the opportunity to modify
+        and return a new version of the object.
+
+        Parameters
+        ----------
+        name : str
+            The name of the event to trigger.
+        obj : T
+            The object to be potentially transformed by callbacks.
+        **kwargs : Any
+            Additional keyword arguments to pass to the callbacks.
+
+        Returns
+        -------
+        T
+            The potentially transformed object.
+        """
+        for callback in self.callbacks.get(name, []):
+            callback_method = getattr(callback, name, None)
+            if callable(callback_method):
+                result = callback_method(obj, **kwargs)
+
+                # Update obj if the callback returned a value of the same type
+                if result is not None and isinstance(result, type(obj)):
+                    obj = result
+
+        return obj
+
     def __add__(self, other) -> "CallbackConnector":
         """
         Adds another CallbackConnector's callbacks to this one.
@@ -945,22 +833,445 @@ def __contains__(self, callback_object) -> bool:

         return False

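Because `transform_event` is what dispatches `on_megatron_step_start`, a registered callback can replace the object it receives rather than merely observe it. A minimal sketch of that pattern; the `CapMicrobatches` class and the cap of 8 are hypothetical illustrations, not part of this change:

import dataclasses


class CapMicrobatches:
    """Hypothetical callback: caps num_microbatches via the transform hook."""

    def on_megatron_step_start(self, step: "MegatronStep") -> "MegatronStep":
        # transform_event keeps the return value only when it is non-None and of
        # the same type as the input, so returning a modified copy takes effect.
        if step.num_microbatches is not None and step.num_microbatches > 8:
            return dataclasses.replace(step, num_microbatches=8)
        return step
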
+@dataclass
+class MegatronStep(Generic[ModelT, DataT]):
+    """
+    Represents a single step in the Megatron model's training or inference process.
+
+    This class encapsulates all the necessary information and logic for executing
+    a single step (forward pass, and optionally backward pass) in the Megatron model.
+    It handles data preparation, model execution, and provides utilities for inferring
+    batch sizes and sequence lengths.
+
+    Attributes:
+        pipeline (MegatronParallel[ModelT]): The Megatron parallel model pipeline.
+        data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): Input data for the step.
+        forward_step_func (Callable): Function to perform the forward step.
+        forward_only (bool): If True, only perform forward pass (no backward pass).
+        micro_batch_size (Optional[int]): Size of each micro-batch.
+        seq_length (Optional[int]): Sequence length for the current step.
+        num_microbatches (Optional[int]): Number of micro-batches in this step.
+
+    Type Parameters:
+        ModelT: The type of the model being used.
+        DataT: The type of the input data.
+    """
+
+    pipeline: MegatronParallel[ModelT]
+    data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]
+    forward_step_func: Callable
+    forward_only: bool
+    micro_batch_size: Optional[int] = None
+    seq_length: Optional[int] = None
+    num_microbatches: Optional[int] = None
+
+    @classmethod
+    def infer(
+        cls,
+        pipeline: MegatronParallel[ModelT],
+        data: DataT,
+        forward_step_func: Callable,
+        forward_only: bool,
+        micro_batch_size: Optional[int] = None,
+        seq_length: Optional[int] = None,
+        num_microbatches: Optional[int] = None,
+    ) -> "MegatronStep[ModelT, DataT]":
+        """
+        Creates a MegatronStep instance, inferring missing parameters if possible.
+
+        This method attempts to infer the micro_batch_size, seq_length, and num_microbatches
+        from the provided data if they are not explicitly specified.
+
+        Args:
+            pipeline (MegatronParallel[ModelT]): The Megatron parallel model pipeline.
+            data (DataT): Input data for the step.
+            forward_step_func (Callable): Function to perform the forward step.
+            forward_only (bool): If True, only perform forward pass (no backward pass).
+            micro_batch_size (Optional[int]): Size of each micro-batch.
+            seq_length (Optional[int]): Sequence length for the current step.
+            num_microbatches (Optional[int]): Number of micro-batches in this step.
+
+        Returns:
+            MegatronStep[ModelT, DataT]: An instance of MegatronStep with inferred parameters.
+        """
+        return cls(
+            pipeline=pipeline,
+            data=data,
+            forward_step_func=forward_step_func,
+            forward_only=forward_only,
+            micro_batch_size=micro_batch_size or cls.infer_micro_batch_size(data),
+            seq_length=seq_length or cls.infer_seq_length(data),
+            num_microbatches=num_microbatches or cls.infer_num_microbatches(data),
+        )
+
+    def __call__(self) -> List[Any]:
+        """
+        Executes the Megatron step.
+
+        This method performs the forward (and optionally backward) pass using the
+        configured forward_backward_func. It ensures all necessary parameters are set
+        before execution.
+
+        Returns:
+            List[Any]: The output of the forward_backward_func, typically containing
+                loss values and other relevant information.
+
+        Raises:
+            ValueError: If any of num_microbatches, seq_length, or micro_batch_size is not set.
+        """
+        if self.num_microbatches is None:
+            raise ValueError("num_microbatches is not set")
+
+        if self.seq_length is None:
+            raise ValueError("seq_length is not set")
+
+        if self.micro_batch_size is None:
+            raise ValueError("micro_batch_size is not set")
+
+        return self.forward_backward_func(
+            forward_step_func=self.forward_step_func,
+            data_iterator=self.data_iterator,
+            model=self.model,
+            num_microbatches=self.num_microbatches,
+            seq_length=self.seq_length,
+            micro_batch_size=self.micro_batch_size,
+            forward_only=self.forward_only,
+        )
+
+    def to_data_iterator_list(
+        self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]]
+    ) -> List[Iterator[DataT]]:
+        """
+        Converts the provided data into a list of iterators.
+
+        This method is used to convert the input data into a list of iterators that can be used
+        for data parallelism in the Megatron model. The input data can be a single data item,
+        an iterator, or a list of iterators.
+
+        Args:
+            data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): The input data to be
+                converted into a list of iterators.
+
+        Returns:
+            List[Iterator[DataT]]: A list of iterators created from the input data.
+        """
+        if isinstance(data, Iterator):
+            return _make_data_iterator_list(self.pipeline, data)
+        elif isinstance(data, list) and all(isinstance(item, Iterator) for item in data):
+            # If data is already a list of iterators, return it as is
+            return cast(List[Iterator[DataT]], data)
+
+        # For a single data item or any other type, wrap it in an iterator and return as a list
+        return cast(List[Iterator[DataT]], [iter([data])])
+
+    @classmethod
+    def infer_micro_batch_size(cls, data: DataT) -> Optional[int]:
+        """
+        Infers the micro-batch size from the input data.
+
+        This method attempts to determine the micro-batch size by examining the first
+        dimension of the input data. It handles various data types including Tensors,
+        dictionaries, lists, and tuples.
+
+        Args:
+            data (DataT): The input data from which to infer the micro-batch size.
+
+        Returns:
+            Optional[int]: The inferred micro-batch size, or None if it cannot be determined.
+ """ + if isinstance(data, Tensor): + return data.size(0) + elif isinstance(data, dict): + return cls.infer_micro_batch_size(next(iter(data.values()))) + elif isinstance(data, (list, tuple)) and len(data) > 0: + _tensor: Tensor = data[0] + return cls.infer_micro_batch_size(_tensor) + + return None + + @classmethod + def infer_seq_length(cls, data: DataT) -> Optional[int]: + """ + Infers the sequence length from the input data. + + This method attempts to determine the sequence length by examining the second + dimension of the input data. It handles various data types including Tensors, + dictionaries, lists, and tuples. + + Args: + data (DataT): The input data from which to infer the sequence length. + + Returns: + Optional[int]: The inferred sequence length, or None if it cannot be determined. + """ + if isinstance(data, Tensor): + # TODO: Check if at least 2 dims + return data.size(1) + elif isinstance(data, dict): + return cls.infer_seq_length(next(iter(data.values()))) + elif isinstance(data, (list, tuple)) and len(data) > 0: + _tensor: Tensor = data[0] + return cls.infer_seq_length(_tensor) + + return None + + @classmethod + def infer_num_microbatches(cls, data: DataT) -> Optional[int]: + """ + Infers the number of micro-batches from the input data. + + Currently, this method assumes a single micro-batch for common data types. + It may need to be extended for more complex data structures or use cases. + + Args: + data (DataT): The input data from which to infer the number of micro-batches. + + Returns: + Optional[int]: The inferred number of micro-batches, or None if it cannot be determined. + """ + if isinstance(data, (dict, tuple, list, Tensor)): + return 1 + + return None + + @property + def model(self) -> Union[ModelT, List[ModelT]]: + """ + Retrieves the model or list of models from the pipeline. + + Returns: + Union[ModelT, List[ModelT]]: The model or list of models in the pipeline. + """ + return self.pipeline.pipeline + + @property + def pl_module(self) -> pl.LightningModule: + """ + Retrieves the PyTorch Lightning module from the pipeline. + + Returns: + pl.LightningModule: The PyTorch Lightning module. + """ + return self.pipeline.module + + @property + def trainer(self) -> pl.Trainer: + """ + Retrieves the PyTorch Lightning trainer from the pipeline. + + Returns: + pl.Trainer: The PyTorch Lightning trainer. + """ + return self.pipeline.trainer + + @functools.cached_property + def forward_backward_func(self) -> "MegatronStepProtocol": + """ + Retrieves the forward-backward function for the Megatron model. + + This property uses Megatron's scheduling to get the appropriate + forward-backward function based on the current configuration. + + Returns: + MegatronStepProtocol: The function to perform forward and backward passes. + """ + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + return get_forward_backward_func() + + @functools.cached_property + def data_iterator(self) -> List[Iterator[DataT]]: + """ + Cached property that converts the provided data into a list of iterators. + + This property ensures that the data is converted to the required format + only once and then cached for subsequent uses. + + Returns: + List[Iterator[DataT]]: A list of iterators created from the input data. 
+ """ + if self.has_global_batch_sampler: + batch = next(self.data) + if isinstance(batch, tuple) and len(batch) == 3: + batch = batch[0] + from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split + + data = get_iterator_k_split(batch, self.num_microbatches, True) + else: + data = self.data + return self.to_data_iterator_list(data) + + @functools.cached_property + def has_global_batch_sampler(self) -> bool: + # FIXME: cleanup the following code is here for backwards compatibility with nemo1. + # The "batch" sampler is a nemo1 sampler. It requires some custom code here to use + # (if use_global_batch_sampler), by default we shouldn't use this "batch" sampler probably. + if getattr(self.trainer, "datamodule", None) is not None: + use_global_batch_sampler = self.trainer.datamodule.data_sampler.dataloader_type == 'batch' + elif getattr(self.trainer, "predict_dataloaders", None) is not None: + from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( # noqa: I001 + MegatronPretrainingBatchSampler, + ) + + # The batch_sampler gets injected into the dataloader by the data_sampler. When doing + # predict without a datamodule we can look inside the dataloader's batch_sampler to see + # if it is the nemo1 style sampler that we need to handle specially below. + use_global_batch_sampler = isinstance( + self.trainer.predict_dataloaders.batch_sampler, MegatronPretrainingBatchSampler + ) + else: + use_global_batch_sampler = False + return use_global_batch_sampler + + class CallbackMethods: - def on_megatron_step_start(self, *args, **kwargs) -> None: ... + """ + Defines callback methods for various stages of the Megatron model's execution. + + This class outlines the structure for callbacks that can be implemented to hook into + different phases of the Megatron model's training or inference process. Each method + represents a specific point in the execution where custom logic can be inserted. + """ + + def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep: + """ + Called at the beginning of each Megatron step. - def on_megatron_microbatch_start(self, *args, **kwargs) -> None: ... + This method is invoked before any processing of the step begins. It allows for + any necessary setup or initialization for the step. - def on_megatron_microbatch_callback(self, *args, **kwargs) -> None: ... + Args: + step (MegatronStep): The MegatronStep object representing the current step. - def on_megatron_microbatch_end(self, *args, **kwargs) -> None: ... + Returns: + MegatronStep: The potentially modified MegatronStep object. + """ + ... - def on_megatron_reduce_microbatches_start(self, *args, **kwargs) -> None: ... + def on_megatron_microbatches_start(self, step: MegatronStep) -> None: + """ + Called before processing of microbatches begins. - def on_megatron_reduce_microbatches_end(self, *args, **kwargs) -> None: ... + This method is invoked just before the model starts processing the microbatches + within a step. It can be used for any preparations needed before microbatch processing. - def on_megatron_log_step_end(self, *args, **kwargs) -> None: ... + Args: + step (MegatronStep): The MegatronStep object representing the current step. + """ + ... + + def on_megatron_microbatch_start( + self, + step: MegatronStep, + batch: DataT, + forward_callback: "MegatronLossReduction", + ) -> None: + """ + Called at the start of processing each microbatch. + + This method is invoked before the forward pass of each microbatch. 
+
+
 class CallbackMethods:
-    def on_megatron_step_start(self, *args, **kwargs) -> None: ...
+    """
+    Defines callback methods for various stages of the Megatron model's execution.
+
+    This class outlines the structure for callbacks that can be implemented to hook into
+    different phases of the Megatron model's training or inference process. Each method
+    represents a specific point in the execution where custom logic can be inserted.
+    """
+
+    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        """
+        Called at the beginning of each Megatron step.

-    def on_megatron_microbatch_start(self, *args, **kwargs) -> None: ...
+        This method is invoked before any processing of the step begins. It allows for
+        any necessary setup or initialization for the step.

-    def on_megatron_microbatch_callback(self, *args, **kwargs) -> None: ...
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.

-    def on_megatron_microbatch_end(self, *args, **kwargs) -> None: ...
+        Returns:
+            MegatronStep: The potentially modified MegatronStep object.
+        """
+        ...

-    def on_megatron_reduce_microbatches_start(self, *args, **kwargs) -> None: ...
+    def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
+        """
+        Called before processing of microbatches begins.

-    def on_megatron_reduce_microbatches_end(self, *args, **kwargs) -> None: ...
+        This method is invoked just before the model starts processing the microbatches
+        within a step. It can be used for any preparations needed before microbatch processing.

-    def on_megatron_log_step_end(self, *args, **kwargs) -> None: ...
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+        """
+        ...
+
+    def on_megatron_microbatch_start(
+        self,
+        step: MegatronStep,
+        batch: DataT,
+        forward_callback: "MegatronLossReduction",
+    ) -> None:
+        """
+        Called at the start of processing each microbatch.
+
+        This method is invoked before the forward pass of each microbatch. It provides
+        access to the current batch data and the loss reduction callback.
+
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+            batch (DataT): The current microbatch of data being processed.
+            forward_callback (MegatronLossReduction): The callback for loss reduction.
+        """
+        ...

-    def on_megatron_step_end(self, *args, **kwargs) -> None: ...
+    def on_megatron_microbatch_end(
+        self,
+        step: MegatronStep,
+        batch: DataT,
+        forward_callback: "MegatronLossReduction",
+        output: Any,
+    ) -> None:
+        """
+        Called at the end of processing each microbatch.
+
+        This method is invoked after the forward pass of each microbatch. It provides
+        access to the processed batch, the loss reduction callback, and the output of the forward pass.
+
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+            batch (DataT): The microbatch of data that was processed.
+            forward_callback (MegatronLossReduction): The callback for loss reduction.
+            output (Any): The output from the forward pass for this microbatch.
+        """
+        ...
+
+    def on_megatron_microbatches_end(self, step: MegatronStep, microbatch_outputs: List[Any]) -> None:
+        """
+        Called after all microbatches in a step have been processed.
+
+        This method is invoked once all microbatches within a step have been processed.
+        It provides access to the outputs from all microbatches.
+
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+            microbatch_outputs (List[Any]): A list of outputs from all processed microbatches.
+        """
+        ...
+
+    def on_megatron_reduce_microbatches_start(
+        self,
+        step: MegatronStep,
+        microbatch_outputs: List[Any],
+    ) -> None:
+        """
+        Called before the reduction of microbatch outputs begins.
+
+        This method is invoked just before the model starts reducing (e.g., averaging)
+        the outputs from all microbatches. It can be used for any preparations needed
+        before the reduction process.
+
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+            microbatch_outputs (List[Any]): A list of outputs from all processed microbatches.
+        """
+        ...
+
+    def on_megatron_reduce_microbatches_end(
+        self,
+        step: MegatronStep,
+        microbatch_outputs: List[Any],
+        loss_reduction: "MegatronLossReduction",
+        reduced: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    ) -> None:
+        """
+        Called after the reduction of microbatch outputs is complete.
+
+        This method is invoked after the model has finished reducing the outputs from
+        all microbatches. It provides access to the original microbatch outputs,
+        the loss reduction object, and the final reduced output.
+
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+            microbatch_outputs (List[Any]): A list of outputs from all processed microbatches.
+            loss_reduction (MegatronLossReduction): The object used for loss reduction.
+            reduced (Union[torch.Tensor, Dict[str, torch.Tensor]]): The final reduced output.
+        """
+        ...
+
+    def on_megatron_step_end(
+        self,
+        step: MegatronStep,
+        microbatch_outputs: List[Any],
+        reduced: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
+    ) -> None:
+        """
+        Called at the end of each Megatron step.
+
+        This method is invoked after all processing for a step is complete. It provides
+        access to the outputs from all microbatches and the final reduced output (if available).
+
+        Args:
+            step (MegatronStep): The MegatronStep object representing the current step.
+            microbatch_outputs (List[Any]): A list of outputs from all processed microbatches.
+            reduced (Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]): The final reduced
+                output, if available. This may be None for certain configurations or pipeline stages.
+        """
+        ...


 ReductionT = TypeVar("ReductionT")

diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py
index 060ec7915ec0..4fadae8dc722 100644
--- a/nemo/lightning/pytorch/plugins/data_sampler.py
+++ b/nemo/lightning/pytorch/plugins/data_sampler.py
@@ -12,12 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import dataclasses
 import logging
-from typing import Any, Dict, List, Literal, Optional
+from typing import List, Literal, Optional

 import pytorch_lightning as pl
 from torch.utils.data import DataLoader

+from nemo.lightning.megatron_parallel import MegatronStep
+

 class DataSampler:
     def connect(self, trainer: pl.Trainer):
@@ -91,16 +94,28 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
         return int(consumed_samples)

     # Megatron callbacks
-    def on_megatron_step_start(self, trainer: pl.Trainer) -> None:
+
+    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        return dataclasses.replace(
+            step,
+            seq_length=self.seq_len,
+            micro_batch_size=self.micro_batch_size,
+            num_microbatches=self.num_microbatches,
+        )
+
+    def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
         # do validation and save the checkpoint when gbs is changed
         if (
             self.rampup_batch_size is not None
             and self.prev_global_batch_size != self.current_global_batch_size
             and self.prev_global_batch_size
         ):
-            trainer.should_stop = True
+            step.trainer.should_stop = True
+
+    def on_megatron_step_end(self, step: MegatronStep) -> None:
+        trainer = step.trainer
+        pl_module = step.pl_module

-    def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         try:
             from megatron.core.num_microbatches_calculator import update_num_microbatches

@@ -136,14 +151,6 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul
         )
         self.if_first_step = 1

-    @property
-    def megatron_data_kwargs(self) -> Dict[str, Any]:
-        return {
-            "seq_length": self.seq_len,
-            "micro_batch_size": self.micro_batch_size,
-            "num_microbatches": self.num_microbatches,
-        }
-
     @property
     def num_microbatches(self) -> int:
         try:
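Since `MegatronStep` is a dataclass, the `dataclasses.replace` call in `on_megatron_step_start` above returns a modified copy instead of mutating the step in place. A small sketch of those semantics, assuming `step` is an existing `MegatronStep` and 2048 is an arbitrary example value:

import dataclasses

new_step = dataclasses.replace(step, seq_length=2048)

assert new_step is not step          # replace() builds a new MegatronStep
assert new_step.seq_length == 2048   # only the named field is overridden
assert new_step.data is step.data    # remaining fields are carried over
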
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 841bec6ab731..6c0d7c8f6b04 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -411,7 +411,9 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
             self.setup_optimizers(trainer)

         self.model = self.megatron_parallel
-        self.model.callbacks.add(getattr(trainer, "callbacks"))
+        trainer_callbacks = getattr(trainer, "callbacks", None)
+        if trainer_callbacks:
+            self.model.callbacks.add(*trainer_callbacks)

         if self.data_sampler:
             self.model.callbacks.add(self.data_sampler)
@@ -480,6 +482,17 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP

             out = self.model(dataloader_iter, forward_only=False, *args, **kwargs)

+            if torch.is_tensor(out):
+                reduced_train_loss = out
+            else:
+                if not isinstance(out, dict):
+                    raise ValueError(f"Expected dict or tensor for reduced_train_loss, got {type(out)}")
+
+                if "loss" not in out:
+                    raise ValueError(f"Expected 'loss' in output dict, got {out.keys()}")
+
+                reduced_train_loss = out["loss"]
+
             self.lightning_module.log(
                 "global_step",
                 self.trainer.global_step,
@@ -511,8 +524,10 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
             if self.log_train_loss:
                 # p2p now, broadcast later at ckpt. only with pp, some ranks will log 0.0
                 # WHICH IS OK because we broadcast later at checkpoint time
-                _strategy_lib._sync_from_last_pipeline_stage(out, broadcast=False)
-                self.lightning_module.log("reduced_train_loss", out, prog_bar=True, batch_size=1, sync_dist=False)
+                _strategy_lib._sync_from_last_pipeline_stage(reduced_train_loss, broadcast=False)
+                self.lightning_module.log(
+                    "reduced_train_loss", reduced_train_loss, prog_bar=True, batch_size=1, sync_dist=False
+                )

         return out

@@ -601,7 +616,6 @@ def _update_step_kwargs(self, dataloader_iter, kwargs, step_name: str):
             kwargs["forward_step"] = self._get_forward_step(step_name)
         if "loss_reduction" not in kwargs:
             kwargs["loss_reduction"] = self._get_loss_reduction(step_name)
-        kwargs.update(self._data_config_kwargs(dataloader_iter))

         return kwargs

@@ -781,13 +795,6 @@ def _get_loss_reduction(self, step_type: str) -> Optional[_ModuleStepFunction]:

         return None

-    def _data_config_kwargs(self, dataloader_iter) -> Dict[str, Any]:
-        if not hasattr(dataloader_iter, "data_config") and self.data_sampler:
-            if hasattr(self.data_sampler, "megatron_data_kwargs"):
-                return self.data_sampler.megatron_data_kwargs
-
-        return {}
-
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         from nemo.utils import AppState
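Taken together, a downstream callback against the new step-based hooks might look like the following sketch; the `StepLogger` class is hypothetical, and registration mirrors how the strategy above adds the data sampler through `self.model.callbacks.add(...)`:

import logging

from nemo.lightning.megatron_parallel import MegatronStep


class StepLogger:
    """Hypothetical callback built on the hooks introduced in this PR."""

    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
        # Dispatched via transform_event: the returned step is the one executed,
        # so fields could also be overridden here with dataclasses.replace(step, ...).
        logging.info("starting step: seq_length=%s", step.seq_length)
        return step

    def on_megatron_step_end(self, step, microbatch_outputs, reduced=None) -> None:
        # On the last pipeline stage, reduced holds the loss; elsewhere it is a zero tensor.
        logging.info("finished step: %d microbatch outputs, reduced=%s", len(microbatch_outputs), reduced)


# Registration, e.g. wherever the MegatronParallel instance is owned:
# megatron_parallel.callbacks.add(StepLogger())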