Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Checkpointer Tutorial + Docs #674

Merged
merged 12 commits into from
Apr 11, 2024
Merged

Add Checkpointer Tutorial + Docs #674

merged 12 commits into from
Apr 11, 2024

Conversation

kartikayk
Copy link
Contributor

Context

We've gotten several round of questions on our checkpointer. In this PR I add a tutorial and render the APIs on our docs

Changelog

  • Checkpointer Tutorial
  • Checkpointer Docs
  • Make HFCheckpointer the default in all llama2 configs

Test plan

  • Ran all of the modified configs

Tutorial rendered correctly

Screenshot 2024-04-09 at 3 50 18 PM

Docs rendered correctly

Screenshot 2024-04-09 at 3 50 10 PM

Copy link

pytorch-bot bot commented Apr 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/674

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 4 Unrelated Failures

As of commit 7d9d200 with merge base 6e9ea22 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 9, 2024
Comment on lines +17 to +18
FullModelHFCheckpointer
FullModelMetaCheckpointer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I thought we were doing away with FullModel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, decided to punt the change to a separate PR since that will be a fairly widespread change.

.. _understand_checkpointer:

==============================
Understanding the Checkpointer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we wanna add one of those "What you will learn"/"Prerequisites" headers (similar to e.g. our "Finetune your First LLM" tutorial)? Might be helpful to (a) standardize our format a bit, and (b) make sure readers have the appropriate context going in

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also maybe I am nitpicking too much, but we could also just call this tutorial "Checkpoints in TorchTune". To me "Understanding X" indicates that X is inherently difficult to understand

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both good points. I had the "What you will learn" and then removed it. Will bring it back. Good point on the name

TorchTune is designed to be "state-dict invariant".

- At the input, TorchTune accepts checkpoints from multiple sources in multiple formats.
For Llama2 this includes both the HF Hub and the Meta Llama website. Model users don't
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: prefer to write to the reader (i.e. "Model users don't have to worry about" -> "You don't have to worry about")


TorchTune is designed to be "state-dict invariant".

- At the input, TorchTune accepts checkpoints from multiple sources in multiple formats.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So one possible suggestion: do we wanna explicitly state the problem that this is solving up front? Like give an example of the usual .load_state_dict() flow and how it doesn't work when the model and weights are coming from different places, then show how the checkpointer allows us to just call load_checkpoint() without having to worry about it. (Btw I think the way you've laid it out is good too, this'd just be a slightly different exposition.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea! Let me try to put an example of this

into any recipe - training, evaluation or generation. Each checkpointer supports a
set of models and scenarios making these easy to understand, debug and extend.

TorchTune is designed to be "state-dict invariant".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idk if it's too elementary, but could just give a sentence or two on state dicts and how they can differ (different keys, reshaped tensors, etc)

The model weights at the end of a completed training
run are written out to file. The checkpointer ensures that the output checkpoint
files have the same keys as the input checkpoint file used to begin training. The
checkpointer also ensures that the keys are paritioned across the same number of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
checkpointer also ensures that the keys are paritioned across the same number of
checkpointer also ensures that the keys are partitioned across the same number of

model_type: LLAMA2

# set to True if restarting training
resume_from_checkpoint: True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this ends kinda abruptly. Would either add some bit about offramps or even just a generic conclusion to wrap things up. Given that you're ending with LoRA, an example of generation with the merged checkpoint in e.g. HF format could be a nice way to tie everything together.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point!

@@ -26,7 +26,7 @@
# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/llama2/tokenizer.model
path: /tmp/Llama-2-7b-hf/tokenizer.model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😍

read directly from the "config.json" file. This helps ensure we either load the weights
correctly or error out in case of discrepancy between the HF checkpoint file and TorchTune's
model implementations.
- HF checkpoint names usually oredered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- HF checkpoint names usually oredered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure
- HF checkpoint names usually ordered by ID (eg: 0001_of_0003, 0002_of_0003, etc.) To ensure

@@ -600,25 +576,11 @@ def save_checkpoint(
checkpoint file ``recipe_state.pt`` is created in ``_output_dir`` which contains the recipe
state. The output state dicts have the following formats:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prob should remove this last bit too

Copy link
Member

@rohan-varma rohan-varma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! Couple of comments, but LG overall

:nosignatures:

FullModelHFCheckpointer
FullModelMetaCheckpointer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tutorial discusses 3 checkpointers, but the torchtune one is left out here. Is the thinking that users don't really have to use the torchtune one? If so, then shall we just remove it from tutorial, as we're presenting user with info that they're not really going to use?


# model_type which specifies how to convert the state_dict
# into a format which TorchTune understands
model_type: LLAMA2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_type is pretty confusing to me and not sure if its going to be super clear for users. For example, what's the supported list of model_type? If I want to work with mistral models, how do I find the appropriate model_type? Do I ever need to define my own?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does model type being LLAMA2 or MISTRAL change how the state duct is converted, I thought that is more defined by the check pointer class itself? (Meta, HF, torchtune)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed overall with the point about usage in the current code base. ModelType is more of a forward facing concept for now (discussed in the original checkpointer PRs in case you're interested). It was relatively useful for the Gemma case where we were able to point the user to just add that branch without intruding into core logic. Will become more imp as we add more models.

For the case of Llama2 vs Mistral, we had this discussion on one of the PRs as well. Personally, its more confusing to have a mistral config with model type being Llama2. I expect many more users to look at configs than to dive into the code itself. So going to leave it as is. But I agree that currently this doesn't impact code. This will change once we add more mistral models.

Do I ever need to define my own?

This will go into a "how do I add a new model" tutorial. its on my list, but likely wont make it for the launch. More than happy for either of you to take this tutorial on.

_component_: torchtune.utils.FullModelHFCheckpointer

# directory with the checkpoint files
# this should match the output_dir above
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: there is no output_dir above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

# this should match the output_dir above
checkpoint_dir: <checkpoint_dir>

# checkpoint files. These refer to intermediate checkpoints
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted that these are the intermediate checkpoint files generated, though might be clearer to explicitly mention what users need to change when loading in from a mid-training checkpoint. i.e. explicitly mention "replace the original checkpoint files with the ones saved during mid training". Might be obvious but feel like changing config stuff is pretty easy to forget.


.. code-block:: yaml

checkpointer:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't I need adapter_checkpoint somewhere? where does adapter_0 get read? Or is it a default setting? If so, would be vaulable to specify that since its not clear where the actual adapter weights learnt during training are getting read here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we dont use adapter weights during generation since we always use merged checkpoints for post-training ops. Probably good to clarify this somewhere, not sure if I'll do this here though. Maybe in the post-training tutorial which I'll work on enxt.

Let's take a close look at these different formats.

Very simply put, the format of a checkpoint is dictated by the state_dict and how this is stored
in files on disk. If the string identifer of the keys in the stored checkpoints don't match up
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
in files on disk. If the string identifer of the keys in the stored checkpoints don't match up
in files on disk. Each weight is associated with a string key that identifies it in the state dict. If the string identifer of the keys in the stored checkpoints don't match up


This is the format supported by the official Llama2 implementation. When you download the llama2 7B model
from the `meta-llama website <https://llama.meta.com/llama-downloads>`_, you'll get access to a single
``.pt`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load``
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``.pt`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load``
``.pth`` checkpoint file. You can inspect the contents of this checkpoint easily with ``torch.load``

>>> print(len(state_dict.keys()))
292

The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. the
The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. The

292

The state_dict contains 292 keys, including an input embedding table called ``tok_embeddings``. the
model definition for this state_dict expects and embedding layer with 32000 items each having a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model definition for this state_dict expects and embedding layer with 32000 items each having a
model definition for this state_dict expects an embedding layer with 32000 vectors, each having a


**Meta Format**

This is the format supported by the official Llama2 implementation. When you download the llama2 7B model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does format specifically mean here? Does it just mean state dict keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the section above on checkpoint formats help answer this question?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, yes it does


# model_type which specifies how to convert the state_dict
# into a format which TorchTune understands
model_type: LLAMA2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does model type being LLAMA2 or MISTRAL change how the state duct is converted, I thought that is more defined by the check pointer class itself? (Meta, HF, torchtune)


|

**MetaCheckpointer**
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is there any new information here that's not already conveyed by the HF checkpointer example above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly self-contained info for this checkpointer?


If checkpointing in the middle of training, the output checkpoint needs to store additional
information to ensure that subsequent training runs can be correctly restarted. In addition to
the model checkpoint files, we output a ``recipe_state.pt`` file for intermediate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we always save recipe state? Too much storage space? What if I finish training but later decide that I want to continue again from there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a good use case, but seems like an edge case for now? Happy to revisit this as well based on user info. The problem is that recipe_state can be quite a big cpt since we capture opt state as well. So we end up with ~13.5GB checkpoint/epoch for weights and then ~27GB checkpoint/epoch for opt state. This can be quite a lot esp when users dont have much disk space


# if we're restarting a previous run, we need to specify
# the file with the checkpoint state
recipe_checkpoint: recipe_state.pt
Copy link
Contributor

@RdoubleA RdoubleA Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just look for this file in the checkpoint for and if it's present then we know we need to resume training and update recipe state? Making the user specify it leaves room for forgetting to update the config and is more error prone, ideally we just handle it for them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a huge fan of changing behavior based on files tbh. Locks us into specific file names or formats. I think it's fine if the user forgets and we just have meaningful errors?

Checkpointing for LoRA
----------------------

In TorchTune, we output both the adapter weights and the full model "merged" weights
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add link to lora tutorial so users understand what adapter checkpoints are

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple nits but overall this looks good to me!


Very simply put, the format of a checkpoint is dictated by the state_dict and how this is stored
in files on disk. Each weight is associated with a string key that identifies it in the state dict.
If the string identifer of the keys in the stored checkpoints don't match up
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If the string identifer of the keys in the stored checkpoints don't match up
If the string identifier of the keys in the stored checkpoints don't match up

checkpoint load and save.The TorchTune checkpointer makes this less error-prone by managing state dicts
for you. TorchTune is designed to be "state-dict invariant".

- When loading,, TorchTune accepts checkpoints from multiple sources in multiple formats.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- When loading,, TorchTune accepts checkpoints from multiple sources in multiple formats.
- When loading, TorchTune accepts checkpoints from multiple sources in multiple formats.

Comment on lines 446 to 450
tune run generate --config generate

# output from the generation
[generate.py:68] Model is initialized with precision torch.bfloat16.
[generate.py:92] Welcome to the 'Alternative' Treatments and Therapies Forum
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: without understanding the generation recipe I have no idea what's going on here (even with understanding it I kinda have no idea what's going on here 😅)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh you mean in the logs - hmm, let me take a closer look

@tcapelle
Copy link
Contributor

tcapelle commented Apr 11, 2024

Hello!

I am wondering on best practice to enable saving checkpoints to Weights & Biases using the checkpointers. I was thinking on subclassing the Checkpointer classes but it seems that I would need to subclass every one of them.
Can you suggest me a good way to approach this?

My current solution seems very convoluted and brittle:

# save_checkpoint method in full_finetune_single_device.py
def save_checkpoint(self, epoch: int) -> None:
    ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()}
    # if training is in-progress, checkpoint the optimizer state as well
    if epoch + 1 < self.total_epochs:
        ckpt_dict.update(
            {
                utils.SEED_KEY: self.seed,
                utils.EPOCHS_KEY: self.epochs_run,
                utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
            }
        )
        if not self._optimizer_in_bwd:
            ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
        else:
            ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
    self._checkpointer.save_checkpoint(
        ckpt_dict,
        epoch=epoch,
        intermediate_checkpoint=(epoch + 1 < self.total_epochs),
    )
    ## Let's save the checkpoint to W&B
    ## depending on the Checkpointer Class the file will be named differently
    ##  An example for the full_finetune case
    checkpoint_file = Path.joinpath(
        self._checkpointer._output_dir, f"torchtune_model_{epoch}"
    ).with_suffix(".pt")
    wandb_at = wandb.Artifact(
      name=f"torchtune_model_{epoch}",
      type="model",
      description="Model checkpoint",
      metadata={
        utils.SEED_KEY: self.seed,
        utils.EPOCHS_KEY: self.epochs_run,
        utils.TOTAL_EPOCHS_KEY: self.total_epochs,
        utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
        }
    )
    wandb_at.add_file(checkpoint_file)
    wandb.log_artifact(wandb_at)

@kartikayk
Copy link
Contributor Author

@tcapelle thanks for taking a look!

Agreed, subclassing these is not a great idea like you pointed out. We also try to minimize implementation inheritance in this code base so that would not work out from a design principles perspective.

I'm trying to understand the use case a bit better. Do you want to add functionality to save the final checkpoint to WANDB? What would be the benefit of doing that? Do users get some visualization tools for the checkpoint? Or would the meta data be enough?

If it's just the metadata, then we can add this to the recipe but will need to think a bit about how we expose this since not every user would have WANDB installed. If it's the checkpoint, then probably needs to happen after the conversion is done in which case this should be a utility which runs at the recipe and picks this up from the output folder.

But yeh if you can describe the use case in more detail, that would be helpful. Maybe do this on a issue where we can have a directed discussion?

@kartikayk kartikayk merged commit 0539976 into main Apr 11, 2024
24 of 31 checks passed
@kartikayk kartikayk deleted the checkpointer_docs branch April 11, 2024 20:14
joecummings pushed a commit that referenced this pull request Apr 11, 2024
@tcapelle
Copy link
Contributor

Thanks for the detailed answer!

I am mostly interested on saving the output model checkpoint to W&B, mostly for our professional users that will want to keep everything integrated. This of course should be disabled by default.

I imagine users would want to save to S3, GCP and other file storage systems when integrating torchtune on their pipelines.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants