Improve Wandb experience #660
(new file: the Weights & Biases logging documentation page)

@@ -0,0 +1,79 @@

.. _wandb_logging:

===========================
Logging to Weights & Biases
===========================

.. customcarditem::
    :header: Logging to Weights & Biases
    :card_description: Log metrics and model checkpoints to W&B
    :image: _static/img/torchtune_workspace.png
    :link: examples/wandb_logging.html
    :tags: logging,wandb

Torchtune supports logging your training runs to `Weights & Biases <https://wandb.ai>`_.

.. note::

    You will need to install the ``wandb`` package to use this feature.
    You can install it via pip:

    .. code-block:: bash

        pip install wandb
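.. tip::

    Before your first run, authenticate this machine with your W&B account by running:

    .. code-block:: bash

        wandb login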
Metric Logger
-------------

The only change you need to make is to add the metric logger to your config. Weights & Biases will log the metrics and model checkpoints for you.

.. code-block:: yaml

    # enable logging to the built-in WandBLogger
    metric_logger:
      _component_: torchtune.utils.metric_logging.WandBLogger
      # the W&B project to log to
      project: torchtune
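You can then launch your recipe as usual, optionally overriding config values such as the project name from the command line. Here is a sketch of such an invocation; the exact CLI form and the recipe/config names are assumptions, so check ``tune --help`` for your version:

.. code-block:: bash

    # hypothetical invocation; recipe and config names are placeholders
    tune run full_finetune_single_device --config llama2/7B_full_low_memory \
        metric_logger.project=my_project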
We automatically grab the config from the recipe you are running and log it to W&B. You can find it in the W&B overview tab, and the actual file in the ``Files`` tab.

.. note::

    Click on this sample `project to see the W&B workspace <https://wandb.ai/capecape/torchtune>`_.
    The config used to train the models can be found `here <https://wandb.ai/capecape/torchtune/runs/6053ofw0/files/torchtune_config_j67sb73v.yaml>`_.
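Under the hood, the recipes talk to the logger through two calls: ``log_config`` once at setup time and ``log_dict`` for per-step metrics (both visible in the recipe diffs below). As a rough, illustrative sketch, constructing the logger directly instead of through the config (any ``WandBLogger`` constructor arguments beyond ``project`` are assumptions here):

.. code-block:: python

    from omegaconf import OmegaConf

    from torchtune.utils.metric_logging import WandBLogger

    logger = WandBLogger(project="torchtune")

    # recipes log the resolved config once during setup()
    cfg = OmegaConf.create({"batch_size": 2, "epochs": 1})
    logger.log_config(cfg)

    # and log a dict of metrics at a given training step
    logger.log_dict({"loss": 1.23}, step=10)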
Logging Model Checkpoints to W&B
--------------------------------

You can also log the model checkpoints to W&B by modifying the desired recipe's ``save_checkpoint`` method.

A suggested approach would be something like this:

.. code-block:: python

    # assumes the recipe module already has these imports at the top:
    # from pathlib import Path
    # import wandb
    # from torchtune import utils

    def save_checkpoint(self, epoch: int) -> None:
        ...
        # Let's save the checkpoint to W&B.
        # Depending on the Checkpointer class, the file will be named differently;
        # here is an example for the full_finetune case.
        checkpoint_file = Path.joinpath(
            self._checkpointer._output_dir, f"torchtune_model_{epoch}"
        ).with_suffix(".pt")
        wandb_at = wandb.Artifact(
            name=f"torchtune_model_{epoch}",
            type="model",
            # description of the model checkpoint
            description="Model checkpoint",
            # you can add whatever metadata you want as a dict
            metadata={
                utils.SEED_KEY: self.seed,
                utils.EPOCHS_KEY: self.epochs_run,
                utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
            },
        )
        wandb_at.add_file(checkpoint_file)
        wandb.log_artifact(wandb_at)
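To retrieve a checkpoint logged this way, you can download the artifact with the public W&B API. A minimal sketch, assuming your W&B entity is ``my-entity`` and the project is ``torchtune`` (both placeholders):

.. code-block:: python

    import wandb

    api = wandb.Api()
    artifact = api.artifact("my-entity/torchtune/torchtune_model_0:latest", type="model")
    checkpoint_dir = artifact.download()  # local directory containing the .pt file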
(changes in a distributed recipe)
@@ -168,6 +168,9 @@ def setup(self, cfg: DictConfig) -> None:
         if self._is_rank_zero:
             self._metric_logger = config.instantiate(cfg.metric_logger)

+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
+
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

         self._model = self._setup_model(
@@ -323,12 +326,9 @@ def _setup_model(
         utils.set_activation_checkpointing(
             model, auto_wrap_policy={modules.TransformerDecoderLayer}
         )
-        if self._is_rank_zero:
-            log.info(
-                utils.memory_stats_log(
-                    "Memory Stats after model init:", device=self._device
-                )
-            )
+        if self._is_rank_zero and self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")

         # synchronize before training begins
         torch.distributed.barrier()
@@ -541,9 +541,12 @@ def train(self) -> None:
             if (
                 self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                 and self._is_rank_zero
+                and self._device == torch.device("cuda")
             ):
-                log.info(
-                    utils.memory_stats_log("Memory Stats:", device=self._device)
-                )
+                # Log peak memory for iteration
+                memory_stats = utils.memory_stats_log(device=self._device)
+                self._metric_logger.log_dict(
+                    memory_stats, step=self.total_training_steps
+                )

         self.epochs_run += 1

Review thread on the added `and self._device == torch.device("cuda")` check:

> Reviewer: Do we really need this one too? For distributed tests they should only run on GPU.
>
> Author: mm yeah, I can remove this check for distributed recipes.
(changes in the single-device recipe)
@@ -146,6 +146,10 @@ def setup(self, cfg: DictConfig) -> None:
             model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
+
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
         self._model_compile = cfg.compile
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
@@ -263,11 +267,9 @@ def _setup_model(
         if compile_model:
             log.info("Compiling model with torch.compile...")
             model = utils.wrap_compile(model)
-        log.info(
-            utils.memory_stats_log(
-                "Memory Stats after model init:", device=self._device
-            )
-        )
+        if self._device == torch.device("cuda"):
+            memory_stats = utils.memory_stats_log(device=self._device)
+            log.info(f"Memory Stats after model init:\n{memory_stats}")
         return model

     def _setup_optimizer(

Review thread on the added CUDA check:

> Reviewer: This is for the CPU recipe tests?
>
> Author: Yes. This actually won't throw; the log just doesn't print anything.
@@ -446,9 +448,12 @@ def train(self) -> None:
                 if (
                     self.total_training_steps % self._log_peak_memory_every_n_steps
                     == 0
+                    and self._device == torch.device("cuda")
                 ):
-                    log.info(
-                        utils.memory_stats_log("Memory Stats:", device=self._device)
-                    )
+                    # Log peak memory for iteration
+                    memory_stats = utils.memory_stats_log(device=self._device)
+                    self._metric_logger.log_dict(
+                        memory_stats, step=self.total_training_steps
+                    )
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
Review thread:

> Reviewer: Can you add a "tip" to run `wandb login` before running?
>
> Author: Good catch.