Improve Wandb experience (#660)
Co-authored-by: Thomas Capelle <thomas.capelle@steady-sun.com>
Co-authored-by: Kartikay Khandelwal <47255723+kartikayk@users.noreply.github.com>
Co-authored-by: yechenzhi <136920488@qq.com>
Co-authored-by: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com>
Co-authored-by: Rohan Varma <rvarm1@fb.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Botao Chen <markchen1015@meta.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Mengtao Yuan <mengtaoyuan1@gmail.com>
Co-authored-by: solitude-alive <44771751+solitude-alive@users.noreply.github.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: RdoubleA <rafiayub@fb.com>
14 people authored Apr 15, 2024
1 parent 053d0ae commit 5402b29
Showing 11 changed files with 239 additions and 77 deletions.
Binary file added docs/source/_static/img/torchtune_workspace.png
85 changes: 85 additions & 0 deletions docs/source/examples/wandb_logging.rst
@@ -0,0 +1,85 @@
.. _wandb_logging:

===========================
Logging to Weights & Biases
===========================

.. customcarditem::
   :header: Logging to Weights & Biases
   :card_description: Log metrics and model checkpoints to W&B
   :image: _static/img/torchtune_workspace.png
   :link: examples/wandb_logging.html
   :tags: logging,wandb


Torchtune supports logging your training runs to `Weights & Biases <https://wandb.ai>`_.

.. note::

   You will need to install the ``wandb`` package to use this feature.
   You can install it via pip:

   .. code-block:: bash

      pip install wandb

   Then you need to log in with your API key using the W&B CLI:

   .. code-block:: bash

      wandb login
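
If you are running in a non-interactive environment (for example a scheduled job), W&B also reads the standard ``WANDB_API_KEY`` environment variable as an alternative to ``wandb login``; this is generic W&B behavior, not something torchtune-specific:

.. code-block:: bash

   # alternative to `wandb login` for non-interactive jobs (POSIX shell assumed)
   export WANDB_API_KEY=<your-api-key>
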
Metric Logger
-------------

The only change you need to make is to add the metric logger to your config. Weights & Biases will then track your training metrics for you; logging model checkpoints requires the small change described below.

.. code-block:: yaml

   # enable logging to the built-in WandBLogger
   metric_logger:
     _component_: torchtune.utils.metric_logging.WandBLogger
     # the W&B project to log to
     project: torchtune

We automatically grab the config from the recipe you are running and log it to W&B. You can find it in the W&B Overview tab and the actual file in the ``Files`` tab.
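
If you prefer not to edit the config file, the metric logger can also be set through command-line overrides when launching a recipe. The recipe and config names below are placeholders for illustration; substitute the ones you are actually running:

.. code-block:: bash

   # hypothetical recipe/config names; the key=value pairs override the config
   tune run full_finetune_single_device --config llama2/7B_full_single_device \
     metric_logger._component_=torchtune.utils.metric_logging.WandBLogger \
     metric_logger.project=torchtune
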

.. note::

   Click on this sample `project to see the W&B workspace <https://wandb.ai/capecape/torchtune>`_.
   The config used to train the models can be found `here <https://wandb.ai/capecape/torchtune/runs/6053ofw0/files/torchtune_config_j67sb73v.yaml>`_.

Logging Model Checkpoints to W&B
--------------------------------

You can also log the model checkpoints to W&B by modifying the ``save_checkpoint`` method of the recipe you are running.

A suggested approach would be something like this:

.. code-block:: python

   def save_checkpoint(self, epoch: int) -> None:
       ...
       ## Let's save the checkpoint to W&B
       ## depending on the Checkpointer Class the file will be named differently
       ## Here is an example for the full_finetune case
       checkpoint_file = Path.joinpath(
           self._checkpointer._output_dir, f"torchtune_model_{epoch}"
       ).with_suffix(".pt")
       wandb_at = wandb.Artifact(
           name=f"torchtune_model_{epoch}",
           type="model",
           # description of the model checkpoint
           description="Model checkpoint",
           # you can add whatever metadata you want as a dict
           metadata={
               utils.SEED_KEY: self.seed,
               utils.EPOCHS_KEY: self.epochs_run,
               utils.TOTAL_EPOCHS_KEY: self.total_epochs,
               utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
           }
       )
       wandb_at.add_file(checkpoint_file)
       wandb.log_artifact(wandb_at)
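
Once a checkpoint has been logged as an artifact, it can be pulled back down outside the training job through the public ``wandb`` API. The sketch below is illustrative only; the entity, project, and artifact names are placeholders for whatever your run actually produced:

.. code-block:: python

   import wandb

   # hypothetical entity/project/artifact names; substitute your own
   api = wandb.Api()
   artifact = api.artifact("my-entity/torchtune/torchtune_model_0:latest", type="model")
   checkpoint_dir = artifact.download()  # local directory containing the .pt file
   print(f"Checkpoint downloaded to {checkpoint_dir}")
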
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -102,6 +102,7 @@ torchtune tutorials.
    examples/checkpointer
    examples/configs
    examples/recipe_deepdive
+   examples/wandb_logging
 
 .. toctree::
    :glob:
22 changes: 13 additions & 9 deletions recipes/full_finetune_distributed.py
@@ -171,7 +171,11 @@ def setup(self, cfg: DictConfig) -> None:
         Sets up the recipe state correctly. This includes setting recipe attributes based
         on the ``resume_from_checkpoint`` flag.
         """
-        self._metric_logger = config.instantiate(cfg.metric_logger)
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
 
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
@@ -291,11 +295,8 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerDecoderLayer}
             )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init", device=self._device
-                )
-            )
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -475,15 +476,18 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats", device=self._device)
-                    )
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
+                    )
 
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
 
     def cleanup(self) -> None:
-        self._metric_logger.close()
+        if self._is_rank_zero:
+            self._metric_logger.close()
         torch.distributed.destroy_process_group()
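
The distributed recipes above guard every metric-logger call behind self._is_rank_zero so that only one W&B run is created per job rather than one per process. A minimal sketch of that pattern outside the recipe classes (illustrative only; assumes wandb is installed and you are logged in):

    import torch.distributed as dist

    from torchtune.utils.metric_logging import WandBLogger

    def is_rank_zero() -> bool:
        # Treat non-distributed (single-process) runs as rank zero so they still log.
        return not dist.is_initialized() or dist.get_rank() == 0

    # Only rank zero talks to W&B; other ranks skip logging entirely.
    metric_logger = WandBLogger(project="torchtune") if is_rank_zero() else None

    if metric_logger is not None:
        metric_logger.log_dict({"loss": 0.42, "lr": 2e-5}, step=0)
        metric_logger.close()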


21 changes: 13 additions & 8 deletions recipes/full_finetune_single_device.py
@@ -176,6 +176,9 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
 
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
         # ``_setup_model`` handles initialization and loading the state dict. This method
@@ -257,11 +260,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             model = utils.wrap_compile(model)
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model
 
     def _setup_optimizer(
@@ -440,9 +441,13 @@ def train(self) -> None:
                 self.total_training_steps += 1
 
                 # Log peak memory for iteration
-                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
-                    )
+                if (
+                    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
+                    and self._device == torch.device("cuda")
+                ):
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
+                    )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
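
In the single-device recipes the memory logging is additionally gated on self._device == torch.device("cuda"), since peak-memory counters are only meaningful on CUDA devices, and utils.memory_stats_log(device=...) now returns a mapping that can be passed straight to log_dict. A rough, hypothetical stand-in for such a helper, built only on public PyTorch APIs:

    import torch

    def cuda_peak_memory_stats(device: torch.device) -> dict:
        # Hypothetical stand-in for torchtune's utils.memory_stats_log:
        # collect peak-memory counters (in GiB) as a dict that log_dict() can consume.
        gib = 1024 ** 3
        return {
            "peak_memory_alloc": torch.cuda.max_memory_allocated(device) / gib,
            "peak_memory_reserved": torch.cuda.max_memory_reserved(device) / gib,
        }

    if torch.cuda.is_available():
        print(cuda_peak_memory_stats(torch.device("cuda")))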
23 changes: 13 additions & 10 deletions recipes/gemma_full_finetune_distributed.py
@@ -146,7 +146,11 @@ def setup(self, cfg: DictConfig) -> None:
         Sets up the recipe state correctly. This includes setting recipe attributes based
         on the ``resume_from_checkpoint`` flag.
         """
-        self._metric_logger = config.instantiate(cfg.metric_logger)
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
 
         ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
@@ -263,12 +267,8 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerDecoderLayer}
             )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init", device=self._device
-                )
-            )
-
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         # synchronize before training begins
         torch.distributed.barrier()
 
@@ -458,15 +458,18 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats", device=self._device)
-                    )
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
+                    )
 
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
 
     def cleanup(self) -> None:
-        self._metric_logger.close()
+        if self._is_rank_zero:
+            self._metric_logger.close()
         torch.distributed.destroy_process_group()


22 changes: 14 additions & 8 deletions recipes/lora_dpo_single_device.py
@@ -148,6 +148,9 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
 
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
         self._model = self._setup_model(
@@ -252,11 +255,9 @@ def _setup_model(
         )
 
         log.info(f"Model is initialized with precision {self._dtype}.")
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model
 
     def _setup_optimizer(
@@ -490,9 +491,14 @@ def train(self) -> None:
                 # Update the number of steps when the weights are updated
                 self.total_training_steps += 1
                 # Log peak memory for iteration
-                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
-                    )
+                if (
+                    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
+                    and self._device == torch.device("cuda")
+                ):
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
+                    )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
16 changes: 9 additions & 7 deletions recipes/lora_finetune_distributed.py
@@ -197,6 +197,9 @@ def setup(self, cfg: DictConfig) -> None:
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)
 
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
+
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
         self._model = self._setup_model(
@@ -353,11 +356,8 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerDecoderLayer}
             )
         if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init:", device=self._device
-                )
-            )
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
 
         # synchronize before training begins
         torch.distributed.barrier()
@@ -571,8 +571,10 @@ def train(self) -> None:
                     self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                     and self._is_rank_zero
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
-                    )
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
+                    )
 
         self.epochs_run += 1
19 changes: 12 additions & 7 deletions recipes/lora_finetune_single_device.py
@@ -174,6 +174,10 @@ def setup(self, cfg: DictConfig) -> None:
             model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
+
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         self._model_compile = cfg.compile
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
 
@@ -291,11 +295,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             model = utils.wrap_compile(model)
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model
 
     def _setup_optimizer(
@@ -474,9 +476,12 @@ def train(self) -> None:
                     if (
                         self.total_training_steps % self._log_peak_memory_every_n_steps
                         == 0
+                        and self._device == torch.device("cuda")
                     ):
-                        log.info(
-                            utils.memory_stats_log("Memory Stats:", device=self._device)
-                        )
+                        # Log peak memory for iteration
+                        memory_stats = utils.memory_stats_log(device=self._device)
+                        self._metric_logger.log_dict(
+                            memory_stats, step=self.total_training_steps
+                        )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
