diff --git a/docs/source/_static/img/torchtune_workspace.png b/docs/source/_static/img/torchtune_workspace.png
new file mode 100644
index 0000000000..4a94a53b72
Binary files /dev/null and b/docs/source/_static/img/torchtune_workspace.png differ
diff --git a/docs/source/examples/wandb_logging.rst b/docs/source/examples/wandb_logging.rst
new file mode 100644
index 0000000000..1faa9e1c26
--- /dev/null
+++ b/docs/source/examples/wandb_logging.rst
@@ -0,0 +1,85 @@
+.. _wandb_logging:
+
+===========================
+Logging to Weights & Biases
+===========================
+
+.. customcarditem::
+   :header: Logging to Weights & Biases
+   :card_description: Log metrics and model checkpoints to W&B
+   :image: _static/img/torchtune_workspace.png
+   :link: examples/wandb_logging.html
+   :tags: logging,wandb
+
+
+TorchTune supports logging your training runs to `Weights & Biases <https://wandb.ai>`_.
+
+.. note::
+
+    You will need to install the ``wandb`` package to use this feature.
+    You can install it via pip:
+
+    .. code-block:: bash
+
+        pip install wandb
+
+    Then you need to log in with your API key using the W&B CLI:
+
+    .. code-block:: bash
+
+        wandb login
+
+
+Metric Logger
+-------------
+
+The only change you need to make is to add the metric logger to your config. Weights & Biases will then log your metrics and the resolved config (see below for logging model checkpoints as well).
+
+.. code-block:: yaml
+
+    # enable logging to the built-in WandBLogger
+    metric_logger:
+      _component_: torchtune.utils.metric_logging.WandBLogger
+      # the W&B project to log to
+      project: torchtune
+
+
+We automatically grab the config from the recipe you are running and log it to W&B. You can find it in the W&B Overview tab, and the actual file in the ``Files`` tab.
+
+.. note::
+
+    Click on this sample `project to see the W&B workspace <https://wandb.ai/capecape/torchtune>`_.
+    The config used to train the models can be found `here <https://wandb.ai/capecape/torchtune/runs/6053ofw0/files/torchtune_config_j67sb73v.yaml>`_.
+
+Logging Model Checkpoints to W&B
+--------------------------------
+
+You can also log the model checkpoints to W&B by modifying the ``save_checkpoint`` method of the desired recipe.
+
+A suggested approach would be something like this:
+
+.. code-block:: python
+
+    def save_checkpoint(self, epoch: int) -> None:
+        ...
+        ## Let's save the checkpoint to W&B
+        ## depending on the Checkpointer Class the file will be named differently
+        ## Here is an example for the full_finetune case
+        checkpoint_file = Path.joinpath(
+            self._checkpointer._output_dir, f"torchtune_model_{epoch}"
+        ).with_suffix(".pt")
+        wandb_at = wandb.Artifact(
+            name=f"torchtune_model_{epoch}",
+            type="model",
+            # description of the model checkpoint
+            description="Model checkpoint",
+            # you can add whatever metadata you want as a dict
+            metadata={
+                utils.SEED_KEY: self.seed,
+                utils.EPOCHS_KEY: self.epochs_run,
+                utils.TOTAL_EPOCHS_KEY: self.total_epochs,
+                utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
+            }
+        )
+        wandb_at.add_file(checkpoint_file)
+        wandb.log_artifact(wandb_at)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b9fdef3c6b..9e16acd5d0 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -107,6 +107,7 @@ TorchTune tutorials.
    examples/checkpointer
    examples/configs
    examples/recipe_deepdive
+   examples/wandb_logging
 
 .. toctree::
    :glob:
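To round out the tutorial above, here is a minimal, illustrative sketch of pulling a logged checkpoint artifact back down with the public ``wandb`` API; the entity, project, and artifact names are placeholders and should be replaced with your own:

.. code-block:: python

    import wandb

    # Placeholder entity/project/artifact names -- substitute the ones from your run.
    api = wandb.Api()
    artifact = api.artifact("my-entity/torchtune/torchtune_model_0:latest", type="model")

    # Downloads the artifact contents (e.g. torchtune_model_0.pt) and returns the local directory.
    local_dir = artifact.download()
    print(f"Checkpoint downloaded to {local_dir}")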
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 75e513defd..28b882ea85 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -147,7 +147,11 @@ def setup(self, cfg: DictConfig) -> None:
         Sets up the recipe state correctly. This includes setting recipe attributes
         based on the ``resume_from_checkpoint`` flag.
         """
-        self._metric_logger = config.instantiate(cfg.metric_logger)
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
 
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
@@ -267,11 +271,8 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init", device=self._device
-                )
-            )
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -451,15 +452,18 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )
 
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
 
     def cleanup(self) -> None:
-        self._metric_logger.close()
+        if self._is_rank_zero:
+            self._metric_logger.close()
         torch.distributed.destroy_process_group()
diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index 3658fa8205..ed530694e5 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -150,6 +150,9 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
 
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
         # ``_setup_model`` handles initialization and loading the state dict. This method
@@ -231,11 +234,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             model = utils.wrap_compile(model)
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model
 
     def _setup_optimizer(
@@ -414,9 +415,13 @@ def train(self) -> None:
                     self.total_training_steps += 1
 
                 # Log peak memory for iteration
-                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
+                if (
+                    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
+                    and self._device == torch.device("cuda")
+                ):
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )
 
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
diff --git a/recipes/gemma_full_finetune_distributed.py b/recipes/gemma_full_finetune_distributed.py
index cd918a9ebc..5114fefb9c 100644
--- a/recipes/gemma_full_finetune_distributed.py
+++ b/recipes/gemma_full_finetune_distributed.py
@@ -146,7 +146,11 @@ def setup(self, cfg: DictConfig) -> None:
         Sets up the recipe state correctly. This includes setting recipe attributes
         based on the ``resume_from_checkpoint`` flag.
         """
-        self._metric_logger = config.instantiate(cfg.metric_logger)
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
 
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
@@ -263,12 +267,8 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init", device=self._device
-                )
-            )
-
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         # synchronize before training begins
         torch.distributed.barrier()
 
@@ -458,15 +458,18 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                    )
 
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
 
     def cleanup(self) -> None:
-        self._metric_logger.close()
+        if self._is_rank_zero:
+            self._metric_logger.close()
         torch.distributed.destroy_process_group()
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index 2c501f069b..d89fabb1ce 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -148,6 +148,9 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
 
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
         self._model = self._setup_model(
@@ -252,11 +255,9 @@ def _setup_model(
         )
         log.info(f"Model is initialized with precision {self._dtype}.")
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model
 
     def _setup_optimizer(
@@ -490,9 +491,14 @@ def train(self) -> None:
                 # Update the number of steps when the weights are updated
                 self.total_training_steps += 1
             # Log peak memory for iteration
-            if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
-                log.info(
-                    utils.memory_stats_log("Memory Stats:", device=self._device)
+            if (
+                self.total_training_steps % self._log_peak_memory_every_n_steps == 0
+                and self._device == torch.device("cuda")
+            ):
+                # Log peak memory for iteration
+                memory_stats = utils.memory_stats_log(device=self._device)
+                self._metric_logger.log_dict(
+                    memory_stats, step=self.total_training_steps
                 )
         self.epochs_run += 1
         self.save_checkpoint(epoch=curr_epoch)
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index ec89ee424a..0b6c76e825 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -168,6 +168,9 @@ def setup(self, cfg: DictConfig) -> None:
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
 
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
         self._model = self._setup_model(
@@ -324,11 +327,8 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init:", device=self._device
-                )
-            )
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -542,8 +542,10 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
                     )
 
             self.epochs_run += 1
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index 197b166655..cb41293d5b 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -146,6 +146,10 @@ def setup(self, cfg: DictConfig) -> None:
         model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
""" self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + self._model_compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) @@ -263,11 +267,9 @@ def _setup_model( if compile_model: log.info("Compiling model with torch.compile...") model = utils.wrap_compile(model) - log.info( - utils.memory_stats_log( - "Memory Stats after model init:", device=self._device - ) - ) + if self._device == torch.device("cuda"): + memory_stats = utils.memory_stats_log(device=self._device) + log.info(f"Memory Stats after model init:\n{memory_stats}") return model def _setup_optimizer( @@ -446,9 +448,12 @@ def train(self) -> None: if ( self.total_training_steps % self._log_peak_memory_every_n_steps == 0 + and self._device == torch.device("cuda") ): - log.info( - utils.memory_stats_log("Memory Stats:", device=self._device) + # Log peak memory for iteration + memory_stats = utils.memory_stats_log(device=self._device) + self._metric_logger.log_dict( + memory_stats, step=self.total_training_steps ) self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) diff --git a/torchtune/utils/memory.py b/torchtune/utils/memory.py index 5713c73e98..a767ea51c6 100644 --- a/torchtune/utils/memory.py +++ b/torchtune/utils/memory.py @@ -160,37 +160,39 @@ def optim_step(param) -> None: p.register_post_accumulate_grad_hook(optim_step) -def memory_stats_log( - prefix: str, device: torch.device, reset_stats: bool = True -) -> None: +def memory_stats_log(device: torch.device, reset_stats: bool = True) -> dict: """ - Print a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will + Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will also reset CUDA's peak memory tracking. This is useful to get data around relative use of peak memory (i.e. peak memory during model init, during forward, etc) and optimize memory for individual sections of training. Args: - prefix (str): Prefix to prepend to the printed summary. device (torch.device): Device to get memory summary for. Only CUDA devices are supported. reset_stats (bool): Whether to reset CUDA's peak memory tracking. Returns: - None + Dict[str, float]: A dictionary containing the peak memory active, peak memory allocated, + and peak memory reserved. This dict is useful for logging memory stats. + + Raises: + ValueError: If the passed in device is not CUDA. 
""" if device.type != "cuda": - return + raise ValueError( + f"Logging memory stats is only supported on CUDA devices, got {device}" + ) + peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9 peak_mem_alloc = torch.cuda.max_memory_allocated(device) / 1e9 peak_mem_reserved = torch.cuda.max_memory_reserved(device) / 1e9 - ret = f""" - {prefix}: - GPU peak memory allocation: {peak_mem_alloc:.2f} GB - GPU peak memory reserved: {peak_mem_reserved:.2f} GB - GPU peak memory active: {peak_memory_active:.2f} GB - """ - if reset_stats: torch.cuda.reset_peak_memory_stats(device) - return ret + memory_stats = { + "peak_memory_active": peak_memory_active, + "peak_memory_alloc": peak_mem_alloc, + "peak_memory_reserved": peak_mem_reserved, + } + return memory_stats diff --git a/torchtune/utils/metric_logging.py b/torchtune/utils/metric_logging.py index 4366a54a7a..59b862303f 100644 --- a/torchtune/utils/metric_logging.py +++ b/torchtune/utils/metric_logging.py @@ -11,13 +11,17 @@ from typing import Mapping, Optional, Union from numpy import ndarray +from omegaconf import DictConfig, OmegaConf from torch import Tensor +from torchtune.utils import get_logger from torchtune.utils._distributed import get_world_size_and_rank from typing_extensions import Protocol Scalar = Union[Tensor, ndarray, int, float] +log = get_logger("DEBUG") + class MetricLoggerInterface(Protocol): """Abstract metric logger.""" @@ -37,6 +41,14 @@ def log( """ pass + def log_config(self, config: DictConfig) -> None: + """Logs the config + + Args: + config (DictConfig): config to log + """ + pass + def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: """Log multiple scalar values. @@ -147,7 +159,7 @@ class WandBLogger(MetricLoggerInterface): def __init__( self, - project: str, + project: str = "torchtune", entity: Optional[str] = None, group: Optional[str] = None, **kwargs, @@ -160,26 +172,63 @@ def __init__( "Alternatively, use the ``StdoutLogger``, which can be specified by setting metric_logger_type='stdout'." ) from e self._wandb = wandb - self._wandb.init( - project=project, - entity=entity, - group=group, - reinit=True, - resume="allow", - config=kwargs, - ) + + _, self.rank = get_world_size_and_rank() + + if self.rank == 0: + self._wandb.init( + project=project, + entity=entity, + group=group, + reinit=True, + resume="allow", + **kwargs, + ) + + def log_config(self, config: DictConfig) -> None: + """Saves the config locally and also logs the config to W&B. The config is + stored in the same directory as the checkpoint. 
+        You can see an example of the logged config to W&B at:
+        https://wandb.ai/capecape/torchtune/runs/6053ofw0/files/torchtune_config_j67sb73v.yaml
+
+        Args:
+            config (DictConfig): config to log
+        """
+        if self._wandb.run:
+            resolved = OmegaConf.to_container(config, resolve=True)
+            self._wandb.config.update(resolved)
+
+            output_config_fname = Path(
+                os.path.join(
+                    config.checkpointer.checkpoint_dir,
+                    f"torchtune_config_{self._wandb.run.id}.yaml",
+                )
+            )
+            OmegaConf.save(config, output_config_fname)
+            try:
+                log.info(f"Logging {output_config_fname} to W&B under Files")
+                self._wandb.save(
+                    output_config_fname, base_path=output_config_fname.parent
+                )
+
+            except Exception as e:
+                log.warning(f"Error saving {output_config_fname} to W&B.\nError: \n{e}")
 
     def log(self, name: str, data: Scalar, step: int) -> None:
-        self._wandb.log({name: data}, step=step)
+        if self._wandb.run:
+            self._wandb.log({name: data}, step=step)
 
     def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
-        self._wandb.log(payload, step=step)
+        if self._wandb.run:
+            self._wandb.log(payload, step=step)
 
     def __del__(self) -> None:
-        self._wandb.finish()
+        if self._wandb.run:
+            self._wandb.finish()
 
     def close(self) -> None:
-        self._wandb.finish()
+        if self._wandb.run:
+            self._wandb.finish()
 
 
 class TensorBoardLogger(MetricLoggerInterface):
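Taken together, the pieces above can be exercised outside of a full recipe run. The following is a minimal, illustrative sketch: the project name and ``checkpoint_dir`` are placeholders, and it assumes a CUDA device plus a prior ``wandb login``. It instantiates the ``WandBLogger`` from a config fragment, logs the config, and logs the dict now returned by ``memory_stats_log``:

.. code-block:: python

    import torch
    from omegaconf import OmegaConf

    from torchtune import config, utils

    # Illustrative config fragment; any recipe config with these sections would work.
    cfg = OmegaConf.create(
        {
            "metric_logger": {
                "_component_": "torchtune.utils.metric_logging.WandBLogger",
                "project": "torchtune",
            },
            # log_config() writes the resolved config into this (existing) directory
            "checkpointer": {"checkpoint_dir": "/tmp/torchtune"},
        }
    )

    logger = config.instantiate(cfg.metric_logger)
    logger.log_config(cfg)  # saves the config next to the checkpoints and uploads it to W&B

    # memory_stats_log now returns a plain dict, so it can be sent to any metric logger
    if torch.cuda.is_available():
        memory_stats = utils.memory_stats_log(device=torch.device("cuda"))
        logger.log_dict(memory_stats, step=0)

    logger.close()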