Merge branch 'main' into add-feature-grad-norm
lindawangg committed Aug 30, 2024
2 parents 06c2792 + dfc69e2 commit 2d46255
Showing 121 changed files with 657 additions and 579 deletions.
53 changes: 53 additions & 0 deletions docs/source/api_ref_training.rst
@@ -0,0 +1,53 @@
==================
torchtune.training
==================

.. currentmodule:: torchtune.training

.. _checkpointing_label:

Checkpointing
-------------

torchtune offers checkpointers to allow seamless transitioning between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of
checkpointing, please see the :ref:`checkpointing deep-dive <understand_checkpointer>`.
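
A minimal usage sketch, assuming the constructor keywords match the YAML config fields used in the recipe configs and that ``load_checkpoint`` is the read entry point (verify both against the generated reference below):

.. code-block:: python

   from torchtune.training import FullModelHFCheckpointer, ModelType

   # Paths and shard names follow the Llama-2-13b-hf example used in the configs.
   checkpointer = FullModelHFCheckpointer(
       checkpoint_dir="/tmp/Llama-2-13b-hf",
       checkpoint_files=[
           "pytorch_model-00001-of-00003.bin",
           "pytorch_model-00002-of-00003.bin",
           "pytorch_model-00003-of-00003.bin",
       ],
       output_dir="/tmp/Llama-2-13b-hf",
       model_type=ModelType.LLAMA2,
   )

   # Assumed: load_checkpoint converts the HF-format weights into a
   # torchtune-format state dict that a recipe can load into the model.
   state_dict = checkpointer.load_checkpoint()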

.. autosummary::
:toctree: generated/
:nosignatures:

FullModelHFCheckpointer
FullModelMetaCheckpointer
FullModelTorchTuneCheckpointer
ModelType
update_state_dict_for_classifier

.. _mp_label:

Reduced Precision
------------------

Utilities for working in a reduced precision setting.
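
A short sketch, assuming ``set_default_dtype`` behaves as a context manager and that ``"bf16"`` is an accepted dtype string, as in the recipe configs:

.. code-block:: python

   import torch

   from torchtune.models.llama2 import llama2_7b
   from torchtune.training import get_dtype, set_default_dtype

   # Resolve the configured dtype string the same way recipes do.
   dtype = get_dtype(dtype="bf16", device=torch.device("cuda"))

   # Assumed context-manager usage: parameters created inside are bf16.
   with set_default_dtype(dtype):
       model = llama2_7b()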

.. autosummary::
:toctree: generated/
:nosignatures:

get_dtype
set_default_dtype
validate_expected_param_dtype
get_quantizer_mode

.. _perf_profiling_label:

Performance and Profiling
-------------------------

torchtune provides utilities to profile and debug the memory and performance
of your finetuning job.
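
A minimal sketch; only the ``enabled`` flag is taken from the recipe configs, and the remaining arguments and return value are left to the generated reference below:

.. code-block:: python

   from torchtune.training import setup_torch_profiler

   # With enabled=False the profiler is effectively a no-op, which is how the
   # shipped recipe configs use it by default.
   profiler_setup = setup_torch_profiler(enabled=False)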

.. autosummary::
:toctree: generated/
:nosignatures:

setup_torch_profiler
42 changes: 2 additions & 40 deletions docs/source/api_ref_utilities.rst
@@ -1,28 +1,9 @@
=================
===============
torchtune.utils
=================
===============

.. currentmodule:: torchtune.utils


.. _checkpointing_label:

Checkpointing
-------------

torchtune offers checkpointers to allow seamless transitioning between checkpoint formats for training and interoperability with the rest of the ecosystem. For a comprehensive overview of
checkpointing, please see the :ref:`checkpointing deep-dive <understand_checkpointer>`.

.. autosummary::
:toctree: generated/
:nosignatures:

FullModelHFCheckpointer
FullModelMetaCheckpointer
FullModelTorchTuneCheckpointer
ModelType
update_state_dict_for_classifier

.. _dist_label:

Distributed
@@ -41,22 +22,6 @@ Utilities for enabling and working with distributed training.
get_full_finetune_fsdp_wrap_policy
lora_fsdp_wrap_policy

.. _mp_label:

Reduced Precision
------------------

Utilities for working in a reduced precision setting.

.. autosummary::
:toctree: generated/
:nosignatures:

get_dtype
set_default_dtype
validate_expected_param_dtype
get_quantizer_mode

.. _ac_label:

Memory Management
@@ -74,8 +39,6 @@ Utilities to reduce memory consumption during training.
register_optim_in_bwd_hooks


.. _perf_profiling_label:

Performance and Profiling
-------------------------

@@ -88,7 +51,6 @@ of your finetuning job.

get_memory_stats
log_memory_stats
setup_torch_profiler

.. _metric_logging_label:

22 changes: 11 additions & 11 deletions docs/source/deep_dives/checkpointer.rst
@@ -135,8 +135,8 @@ torchtune supports three different
each of which supports a different checkpoint format.


:class:`HFCheckpointer <torchtune.utils.FullModelHFCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`HFCheckpointer <torchtune.training.FullModelHFCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a format which is compatible with the transformers
framework from Hugging Face. As mentioned above, this is the most popular format within the Hugging Face
@@ -167,7 +167,7 @@ The following snippet explains how the HFCheckpointer is setup in torchtune conf
checkpointer:
# checkpointer to use
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
@@ -205,8 +205,8 @@ The following snippet explains how the HFCheckpointer is setup in torchtune conf

|
:class:`MetaCheckpointer <torchtune.utils.FullModelMetaCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`MetaCheckpointer <torchtune.training.FullModelMetaCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a format which is compatible with the original meta-llama
github repository.
@@ -237,7 +237,7 @@ The following snippet explains how the MetaCheckpointer is setup in torchtune co
checkpointer:
# checkpointer to use
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
@@ -265,8 +265,8 @@ The following snippet explains how the MetaCheckpointer is setup in torchtune co
|
:class:`TorchTuneCheckpointer <torchtune.utils.FullModelTorchTuneCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:class:`TorchTuneCheckpointer <torchtune.training.FullModelTorchTuneCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a format that is compatible with torchtune's
model definition. This does not perform any state_dict conversions and is currently used either
@@ -335,7 +335,7 @@ to the config file
checkpointer:
# checkpointer to use
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: <checkpoint_dir>
@@ -381,7 +381,7 @@ looks something like this:
checkpointer:
# checkpointer to use
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir above
@@ -427,7 +427,7 @@ For this section we'll use the Llama2 13B model in HF format.
.. code-block:: python
import torch
from torchtune.utils import FullModelHFCheckpointer, ModelType
from torchtune.training import FullModelHFCheckpointer, ModelType
from torchtune.models.llama2 import llama2_13b
# Set the right directory and files
2 changes: 1 addition & 1 deletion docs/source/deep_dives/recipe_deepdive.rst
@@ -139,7 +139,7 @@ Initialize recipe state including seed, device, dtype, metric loggers, relevant
def __init__(...):
self._device = utils.get_device(device=params.device)
self._dtype = utils.get_dtype(dtype=params.dtype, device=self._device)
self._dtype = training.get_dtype(dtype=params.dtype, device=self._device)
...
Load checkpoint, update recipe state from checkpoint, initialize components and load state dicts from checkpoint
8 changes: 4 additions & 4 deletions docs/source/deep_dives/wandb_logging.rst
@@ -87,10 +87,10 @@ A suggested approach would be something like this:
description="Model checkpoint",
# you can add whatever metadata you want as a dict
metadata={
utils.SEED_KEY: self.seed,
utils.EPOCHS_KEY: self.epochs_run,
utils.TOTAL_EPOCHS_KEY: self.total_epochs,
utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
wandb_at.add_file(checkpoint_file)
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -145,4 +145,5 @@ torchtune tutorials.
api_ref_datasets
api_ref_models
api_ref_modules
api_ref_training
api_ref_utilities
4 changes: 2 additions & 2 deletions docs/source/recipes/qat_distributed.rst
@@ -57,7 +57,7 @@ strategy. Generally, the pipeline for training, quantizing, and evaluating a mod
# QAT specific args
quantizer:
_component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer
_component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
groupsize: 256
#. :ref:`Evaluate<qat_eval_label>` or `run inference <https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#generate>`_
@@ -66,7 +66,7 @@ strategy. Generally, the pipeline for training, quantizing, and evaluating a mod
.. code-block:: yaml
quantizer:
_component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
_component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
groupsize: 256
.. note::
4 changes: 2 additions & 2 deletions docs/source/tutorials/e2e_flow.rst
@@ -195,7 +195,7 @@ First, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoin
.. code-block:: yaml
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir specified during
@@ -262,7 +262,7 @@ Let's modify ``custom_generation_config.yaml`` to include the following changes.
.. code-block:: yaml
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
# directory with the checkpoint files
# this should match the output_dir specified during
4 changes: 2 additions & 2 deletions docs/source/tutorials/llama3.rst
@@ -149,7 +149,7 @@ Next, we modify ``custom_eval_config.yaml`` to include the fine-tuned checkpoint
_component_: torchtune.models.llama3.llama3_8b
checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
# directory with the checkpoint files
# this should match the output_dir specified during
@@ -203,7 +203,7 @@ Now we modify ``custom_generation_config.yaml`` to point to our checkpoint and t
_component_: torchtune.models.llama3.llama3_8b
checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
# directory with the checkpoint files
# this should match the output_dir specified during
10 changes: 5 additions & 5 deletions docs/source/tutorials/qat_finetune.rst
@@ -68,7 +68,7 @@ We can easily apply the above QAT transformations to Llama3 in torchtune for fin

.. code-block:: python
from torchtune.utils.quantization import Int8DynActInt4WeightQATQuantizer
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
from torchtune.models.llama3 import llama3_8b
model = llama3_8b()
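# (Hedged sketch) The QAT flow then typically prepares the model with
# fake-quantize ops before fine-tuning and converts it afterwards; the
# prepare/convert method names follow the torchao QAT quantizer API and the
# groupsize value mirrors the configs in this change -- treat both as assumptions.
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = qat_quantizer.prepare(model)
# ... fine-tune as usual ...
model = qat_quantizer.convert(model)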
@@ -223,7 +223,7 @@ copy and make the following modifications to the quantization config:
_component_: torchtune.models.llama3.llama3_8b
checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: <your QAT checkpoint dir>
checkpoint_files: [meta_model_0.pt]
recipe_checkpoint: null
@@ -233,7 +233,7 @@ copy and make the following modifications to the quantization config:
...
quantizer:
_component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer
_component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
groupsize: 256
The following command performs the convert step in the QAT flow, which actually
@@ -269,7 +269,7 @@ integrated in torchtune. First, copy the evaluation config and make the followin
_component_: torchtune.models.llama3.llama3_8b
checkpointer:
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
_component_: torchtune.training.FullModelTorchTuneCheckpointer
checkpoint_dir: <your quantized model checkpoint dir>
checkpoint_files: [meta_model_0-8da4w.pt]
recipe_checkpoint: null
@@ -285,7 +285,7 @@ integrated in torchtune. First, copy the evaluation config and make the followin
batch_size: 8
quantizer:
_component_: torchtune.utils.quantization.Int8DynActInt4WeightQuantizer
_component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
groupsize: 256
.. note::
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -31,7 +31,7 @@ tokenizer:

# Checkpointer
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/CodeLlama-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
4 changes: 2 additions & 2 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -32,7 +32,7 @@ tokenizer:

# Checkpointer
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/CodeLlama-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
@@ -86,7 +86,7 @@ log_peak_memory_stats: False
# Showcase the usage of PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
_component_: torchtune.utils.setup_torch_profiler
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
4 changes: 2 additions & 2 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -32,7 +32,7 @@ tokenizer:

# Checkpointer
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/CodeLlama-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00003.bin,
@@ -86,7 +86,7 @@ log_peak_memory_stats: False
# Showcase the usage of PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
_component_: torchtune.utils.setup_torch_profiler
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
2 changes: 1 addition & 1 deletion recipes/configs/dev/8B_full_experimental.yaml
@@ -35,7 +35,7 @@ model:
_component_: torchtune.models.llama3.llama3_8b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3-8B/original/
checkpoint_files: [
consolidated.00.pth
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/13B_lora_fsdp2.yaml
@@ -33,7 +33,7 @@ model:
lora_alpha: 16

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-13b-hf/
checkpoint_files: [
pytorch_model-00001-of-00003.bin,