Add Checkpointer Tutorial + Docs #674
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/674
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 4 Unrelated Failures
As of commit 7d9d200 with merge base 6e9ea22:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from c66645f to 22f0b46
FullModelHFCheckpointer
FullModelMetaCheckpointer
Wait I thought we were doing away with FullModel?
Oh sorry, decided to punt the change to a separate PR since that will be a fairly widespread change.
.. _understand_checkpointer:

==============================
Understanding the Checkpointer
Do we wanna add one of those "What you will learn"/"Prerequisites" headers (similar to e.g. our "Finetune your First LLM" tutorial)? Might be helpful to (a) standardize our format a bit, and (b) make sure readers have the appropriate context going in
Also maybe I am nitpicking too much, but we could also just call this tutorial "Checkpoints in TorchTune". To me "Understanding X" indicates that X is inherently difficult to understand
Both good points. I had the "What you will learn" and then removed it. Will bring it back. Good point on the name
TorchTune is designed to be "state-dict invariant".

- At the input, TorchTune accepts checkpoints from multiple sources in multiple formats.
  For Llama2 this includes both the HF Hub and the Meta Llama website. Model users don't
nit: prefer to write to the reader (i.e. "Model users don't have to worry about" -> "You don't have to worry about")
TorchTune is designed to be "state-dict invariant".

- At the input, TorchTune accepts checkpoints from multiple sources in multiple formats.
So one possible suggestion: do we wanna explicitly state the problem that this is solving up front? Like give an example of the usual `.load_state_dict()` flow and how it doesn't work when the model and weights are coming from different places, then show how the checkpointer allows us to just call `load_checkpoint()` without having to worry about it. (Btw I think the way you've laid it out is good too, this'd just be a slightly different exposition.)
Great idea! Let me try to put an example of this
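For readers following along, here is a minimal sketch of the kind of example being discussed. The checkpoint directory and shard names follow the `/tmp/Llama-2-7b-hf` layout used elsewhere in this PR, so treat the exact paths and filenames as assumptions:

```python
import torch
from torchtune.models.llama2 import llama2_7b
from torchtune.utils import FullModelHFCheckpointer

model = llama2_7b()

# The usual flow: load an HF-format state dict directly into a TorchTune model.
# The key names don't match (e.g. "model.layers.0.self_attn.q_proj.weight" on
# the HF side vs "layers.0.attn.q_proj.weight" in TorchTune), so strict loading fails.
hf_state_dict = torch.load(
    "/tmp/Llama-2-7b-hf/pytorch_model-00001-of-00002.bin", map_location="cpu"
)
# model.load_state_dict(hf_state_dict)  # RuntimeError: missing/unexpected keys

# The checkpointer flow: keys are converted into TorchTune's format on load,
# so load_state_dict just works.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-2-7b-hf",
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type="LLAMA2",
    output_dir="/tmp/Llama-2-7b-hf",
)
checkpoint_dict = checkpointer.load_checkpoint()
model.load_state_dict(checkpoint_dict["model"])  # keys now line up
```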
into any recipe - training, evaluation or generation. Each checkpointer supports a
set of models and scenarios making these easy to understand, debug and extend.

TorchTune is designed to be "state-dict invariant".
Idk if it's too elementary, but could just give a sentence or two on state dicts and how they can differ (different keys, reshaped tensors, etc)
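Something as small as this would probably do it (a sketch; the transformer key names are illustrative of the HF-vs-TorchTune difference discussed in this PR):

```python
import torch.nn as nn

# A state dict maps string keys to weight tensors.
model = nn.Linear(4, 2)
print(model.state_dict().keys())  # odict_keys(['weight', 'bias'])

# Two checkpoints of the same architecture can disagree on naming and layout:
#   "model.layers.0.self_attn.q_proj.weight"  # HF-style key
#   "layers.0.attn.q_proj.weight"             # TorchTune-style key
# and individual tensors may additionally be reshaped or permuted between formats.
```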
The model weights at the end of a completed training
run are written out to file. The checkpointer ensures that the output checkpoint
files have the same keys as the input checkpoint file used to begin training. The
checkpointer also ensures that the keys are paritioned across the same number of
Suggested change:
- checkpointer also ensures that the keys are paritioned across the same number of
+ checkpointer also ensures that the keys are partitioned across the same number of
model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True
I feel like this ends kinda abruptly. Would either add some bit about offramps or even just a generic conclusion to wrap things up. Given that you're ending with LoRA, an example of generation with the merged checkpoint in e.g. HF format could be a nice way to tie everything together.
Good point!
@@ -26,7 +26,7 @@
 # Tokenizer
 tokenizer:
   _component_: torchtune.models.llama2.llama2_tokenizer
-  path: /tmp/llama2/tokenizer.model
+  path: /tmp/Llama-2-7b-hf/tokenizer.model
😍
read directly from the "config.json" file. This helps ensure we either load the weights
correctly or error out in case of discrepancy between the HF checkpoint file and TorchTune's
model implementations.
- HF checkpoint names usually oredered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure
Suggested change:
- - HF checkpoint names usually oredered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure
+ - HF checkpoint names usually ordered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure
@@ -600,25 +576,11 @@ def save_checkpoint(
 checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe
 state. The output state dicts have the following formats:
Prob should remove this last bit too
Thanks for this! Couple of comments, but LG overall
:nosignatures:

FullModelHFCheckpointer
FullModelMetaCheckpointer
The tutorial discusses 3 checkpointers, but the torchtune one is left out here. Is the thinking that users don't really have to use the torchtune one? If so, then shall we just remove it from the tutorial, since we're presenting the user with info that they're not really going to use?
# model_type which specifies how to convert the state_dict
# into a format which TorchTune understands
model_type: LLAMA2
model_type is pretty confusing to me and I'm not sure if it's going to be super clear for users. For example, what's the supported list of `model_type`s? If I want to work with mistral models, how do I find the appropriate `model_type`? Do I ever need to define my own?
How does the model type being LLAMA2 or MISTRAL change how the state dict is converted? I thought that is more defined by the checkpointer class itself (Meta, HF, torchtune)?
Agreed overall with the point about usage in the current code base. ModelType is more of a forward-facing concept for now (discussed in the original checkpointer PRs in case you're interested). It was relatively useful for the Gemma case, where we were able to point the user to just add that branch without intruding into core logic. It will become more important as we add more models.

For the case of Llama2 vs Mistral, we had this discussion on one of the PRs as well. Personally, it's more confusing to have a mistral config with model type being Llama2. I expect many more users to look at configs than to dive into the code itself, so going to leave it as is. But I agree that currently this doesn't impact code. This will change once we add more mistral models.

> Do I ever need to define my own?

This will go into a "how do I add a new model" tutorial. It's on my list, but likely won't make it for the launch. More than happy for either of you to take this tutorial on.
_component_: torchtune.utils.FullModelHFCheckpointer

# directory with the checkpoint files
# this should match the output_dir above
nit: there is no `output_dir` above?
good catch!
# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>

# checkpoint files. These refer to intermediate checkpoints
Noted that these are the intermediate checkpoint files generated, though it might be clearer to explicitly mention what users need to change when loading in from a mid-training checkpoint, i.e. explicitly mention "replace the original checkpoint files with the ones saved during mid-training". Might be obvious, but changing config stuff is pretty easy to forget.
.. code-block:: yaml

checkpointer:
Don't I need `adapter_checkpoint` somewhere? Where does `adapter_0` get read? Or is it a default setting? If so, it would be valuable to specify that, since it's not clear where the actual adapter weights learnt during training are getting read here.
So we don't use adapter weights during generation since we always use merged checkpoints for post-training ops. Probably good to clarify this somewhere, not sure if I'll do this here though. Maybe in the post-training tutorial which I'll work on next.
Let's take a close look at these different formats.

Very simply put, the format of a checkpoint is dictated by the state_dict and how this is stored
in files on disk. If the string identifer of the keys in the stored checkpoints don't match up
Suggested change:
- in files on disk. If the string identifer of the keys in the stored checkpoints don't match up
+ in files on disk. Each weight is associated with a string key that identifies it in the state dict. If the string identifer of the keys in the stored checkpoints don't match up
This is the format supported by the official Llama2 implementation. When you download the llama2 7B model
from the `meta-llama website <https://llama.meta.com/llama-downloads>`_, you'll get access to a single
``.pt`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load``
Suggested change:
- ``.pt`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load``
+ ``.pth`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load``
>>> print(len(state_dict.keys()))
292

The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. the
Suggested change:
- The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. the
+ The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. The
292

The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. the
model definition for this state_dict expects and embedding layer with 32000 items each having a
Suggested change:
- model definition for this state_dict expects and embedding layer with 32000 items each having a
+ model definition for this state_dict expects an embedding layer with 32000 vectors, each having a
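For anyone skimming the thread, the inspection the tutorial text describes looks roughly like this (a sketch; the filename is the one Meta's download typically produces for the 7B model, and the shapes are for Llama2 7B):

```python
import torch

# Meta-format Llama2 7B ships as a single consolidated checkpoint file.
state_dict = torch.load("consolidated.00.pth", map_location="cpu", weights_only=True)

print(len(state_dict.keys()))                     # 292
print(state_dict["tok_embeddings.weight"].shape)  # torch.Size([32000, 4096])
```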
**Meta Format**

This is the format supported by the official Llama2 implementation. When you download the llama2 7B model
What does format specifically mean here? Does it just mean state dict keys?
Does the section above on checkpoint formats help answer this question?
Ah yes, yes it does
|

**MetaCheckpointer**
nit: is there any new information here that's not already conveyed by the HF checkpointer example above?
Mostly self-contained info for this checkpointer?
If checkpointing in the middle of training, the output checkpoint needs to store additional
information to ensure that subsequent training runs can be correctly restarted. In addition to
the model checkpoint files, we output a ``recipe_state.pt`` file for intermediate
Why don't we always save recipe state? Too much storage space? What if I finish training but later decide that I want to continue again from there?
It's a good use case, but seems like an edge case for now? Happy to revisit this as well based on user info. The problem is that recipe_state can be quite a big checkpoint since we capture optimizer state as well. So we end up with ~13.5GB checkpoint/epoch for weights and then ~27GB checkpoint/epoch for optimizer state. This can be quite a lot, especially when users don't have much disk space.
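For context, based on the keys the recipe's `save_checkpoint` assembles (see the recipe snippet further down this page), the intermediate recipe state can be inspected like any other checkpoint. The exact key strings shown are an assumption:

```python
import torch

# recipe_state.pt bundles what's needed to restart a run mid-training:
# seed, epoch counters, and (the bulk of the file) optimizer state.
recipe_state = torch.load("recipe_state.pt", map_location="cpu")
print(recipe_state.keys())
# e.g. dict_keys(['seed', 'epochs_run', 'total_epochs', 'max_steps_per_epoch', 'optimizer'])
```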
# if we're restarting a previous run, we need to specify
# the file with the checkpoint state
recipe_checkpoint: recipe_state.pt
Can we just look for this file in the checkpoint folder, and if it's present then we know we need to resume training and update recipe state? Making the user specify it leaves room for forgetting to update the config and is more error-prone; ideally we just handle it for them.
I'm not a huge fan of changing behavior based on files tbh. Locks us into specific file names or formats. I think it's fine if the user forgets and we just have meaningful errors?
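For what it's worth, the auto-detection being debated here would amount to something like this (a sketch of the proposal, not TorchTune behavior):

```python
from pathlib import Path


def should_resume(checkpoint_dir: str) -> bool:
    # Infer resume behavior from the presence of a recipe-state file,
    # instead of requiring an explicit resume_from_checkpoint flag.
    return (Path(checkpoint_dir) / "recipe_state.pt").exists()
```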
Checkpointing for LoRA
----------------------

In TorchTune, we output both the adapter weights and the full model "merged" weights
Add a link to the LoRA tutorial so users understand what adapter checkpoints are.
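As background for readers landing here without the LoRA context: the "merged" weights come from folding each adapter back into its base weight via the standard LoRA formula. A sketch of the math (not TorchTune's actual merge code):

```python
import torch


def merge_lora_weight(
    base: torch.Tensor,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    alpha: float,
    rank: int,
) -> torch.Tensor:
    """Merged weight: W' = W + (alpha / rank) * B @ A."""
    return base + (alpha / rank) * (lora_b @ lora_a)


# Toy shapes: a 16x16 projection with rank-4 adapters.
W = torch.randn(16, 16)
A = torch.randn(4, 16)   # lora_a: (rank, in_dim)
B = torch.randn(16, 4)   # lora_b: (out_dim, rank)
merged = merge_lora_weight(W, A, B, alpha=8.0, rank=4)
assert merged.shape == W.shape
```

Merging means post-training tools (generation, eval) can consume the checkpoint like any full-model checkpoint, with no LoRA-aware code.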
A couple nits but overall this looks good to me!
Very simply put, the format of a checkpoint is dictated by the state_dict and how this is stored
in files on disk. Each weight is associated with a string key that identifies it in the state dict.
If the string identifer of the keys in the stored checkpoints don't match up
Suggested change:
- If the string identifer of the keys in the stored checkpoints don't match up
+ If the string identifier of the keys in the stored checkpoints don't match up
checkpoint load and save.The TorchTune checkpointer makes this less error-prone by managing state dicts
for you. TorchTune is designed to be "state-dict invariant".

- When loading,, TorchTune accepts checkpoints from multiple sources in multiple formats.
Suggested change:
- - When loading,, TorchTune accepts checkpoints from multiple sources in multiple formats.
+ - When loading, TorchTune accepts checkpoints from multiple sources in multiple formats.
tune run generate --config generate

# output from the generation
[generate.py:68] Model is initialized with precision torch.bfloat16.
[generate.py:92] Welcome to the 'Alternative' Treatments and Therapies Forum
nit: without understanding the generation recipe I have no idea what's going on here (even with understanding it I kinda have no idea what's going on here 😅)
oh you mean in the logs - hmm, let me take a closer look
Hello! I am wondering about the best practice to enable saving checkpoints to Weights & Biases using the checkpointers. I was thinking of subclassing the Checkpointer classes, but it seems that I would need to subclass every one of them. My current solution seems very convoluted and brittle:

# save_checkpoint method in full_finetune_single_device.py
def save_checkpoint(self, epoch: int) -> None:
    ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()}
    # if training is in-progress, checkpoint the optimizer state as well
    if epoch + 1 < self.total_epochs:
        ckpt_dict.update(
            {
                utils.SEED_KEY: self.seed,
                utils.EPOCHS_KEY: self.epochs_run,
                utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
            }
        )
        if not self._optimizer_in_bwd:
            ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
        else:
            ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
    self._checkpointer.save_checkpoint(
        ckpt_dict,
        epoch=epoch,
        intermediate_checkpoint=(epoch + 1 < self.total_epochs),
    )
    ## Let's save the checkpoint to W&B
    ## depending on the Checkpointer Class the file will be named differently
    ## An example for the full_finetune case
    checkpoint_file = Path.joinpath(
        self._checkpointer._output_dir, f"torchtune_model_{epoch}"
    ).with_suffix(".pt")
    wandb_at = wandb.Artifact(
        name=f"torchtune_model_{epoch}",
        type="model",
        description="Model checkpoint",
        metadata={
            utils.SEED_KEY: self.seed,
            utils.EPOCHS_KEY: self.epochs_run,
            utils.TOTAL_EPOCHS_KEY: self.total_epochs,
            utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
        },
    )
    wandb_at.add_file(checkpoint_file)
    wandb.log_artifact(wandb_at)
@tcapelle thanks for taking a look! Agreed, subclassing these is not a great idea like you pointed out. We also try to minimize implementation inheritance in this code base, so that would not work out from a design-principles perspective.

I'm trying to understand the use case a bit better. Do you want to add functionality to save the final checkpoint to WANDB? What would be the benefit of doing that? Do users get some visualization tools for the checkpoint? Or would the metadata be enough?

If it's just the metadata, then we can add this to the recipe, but we will need to think a bit about how we expose this since not every user would have WANDB installed. If it's the checkpoint, then it probably needs to happen after the conversion is done, in which case this should be a utility which runs after the recipe and picks this up from the output folder.

But yeah, if you can describe the use case in more detail, that would be helpful. Maybe do this on an issue where we can have a directed discussion?
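To illustrate the "utility that runs after the recipe" idea floated above, something like this sketch would pick up whatever converted files land in the output folder (a hypothetical helper, not a TorchTune API; assumes `wandb.init()` has been called):

```python
import os

import wandb


def log_checkpoint_dir_to_wandb(output_dir: str, epoch: int) -> None:
    """Upload the checkpoint files a recipe wrote, after format conversion."""
    artifact = wandb.Artifact(name=f"torchtune_model_{epoch}", type="model")
    for fname in sorted(os.listdir(output_dir)):
        if fname.endswith((".pt", ".pth", ".bin")):
            artifact.add_file(os.path.join(output_dir, fname))
    wandb.log_artifact(artifact)
```

Because it only looks at the output directory, it stays independent of which checkpointer class produced the files.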
Force-pushed from b320d91 to 7d9d200
Thanks for the detailed answer! I am mostly interested in saving the output model checkpoint to W&B, mostly for our professional users who will want to keep everything integrated. This of course should be disabled by default. I imagine users would want to save to S3, GCP, and other file storage systems when integrating torchtune into their pipelines.
Context
We've gotten several rounds of questions on our checkpointer. In this PR I add a tutorial and render the APIs on our docs.
Changelog
Test plan
Tutorial rendered correctly
Docs rendered correctly