diff --git a/recipes/configs/alpaca_llama2_full_finetune.yaml b/recipes/configs/alpaca_llama2_full_finetune_distributed.yaml similarity index 78% rename from recipes/configs/alpaca_llama2_full_finetune.yaml rename to recipes/configs/alpaca_llama2_full_finetune_distributed.yaml index d3396b1eeb..561084d21d 100644 --- a/recipes/configs/alpaca_llama2_full_finetune.yaml +++ b/recipes/configs/alpaca_llama2_full_finetune_distributed.yaml @@ -1,7 +1,7 @@ -# Config for FullFinetuneRecipe in full_finetune.py +# Config for FullFinetuneRecipe in full_finetune_distributed.py # # To launch, run the following command from root: -# tune --nnodes 1 --nproc_per_node 1 full_finetune --config alpaca_llama2_full_finetune model_checkpoint= ... +# tune --nnodes 1 --nproc_per_node 1 full_finetune_distributed --config alpaca_llama2_full_finetune_distributed model_checkpoint= ... # Tokenizer tokenizer: @@ -38,19 +38,24 @@ loss: _component_: torch.nn.CrossEntropyLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 -log_every_n_steps: null -run_generation: null -# Distributed +# Training env device: cuda -dtype: fp32 + +# Distributed enable_fsdp: True -enable_activation_checkpointing: True cpu_offload: False +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + # Logging metric_logger: _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: null diff --git a/recipes/configs/alpaca_llama2_full_finetune_single_device.yaml b/recipes/configs/alpaca_llama2_full_finetune_single_device.yaml new file mode 100644 index 0000000000..432b7c650c --- /dev/null +++ b/recipes/configs/alpaca_llama2_full_finetune_single_device.yaml @@ -0,0 +1,57 @@ +# Config for FullFinetuneRecipe in full_finetune_single_device.py +# +# To launch, run the following command from root: +# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device --config alpaca_llama2_full_finetune_single_device model_checkpoint= ... 
+ +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama2.llama2_7b + +checkpointer: + _component_: torchtune.utils.FullModelMetaCheckpointer + checkpoint_dir: /tmp/llama2 + checkpoint_files: [consolidated.00.pth] + recipe_checkpoint: null + output_dir: /tmp/llama2 + model_type: LLAMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.SGD + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune +log_every_n_steps: null diff --git a/recipes/full_finetune.py b/recipes/full_finetune_distributed.py similarity index 86% rename from recipes/full_finetune.py rename to recipes/full_finetune_distributed.py index 517a1195c6..cf2ca1efe5 100644 --- a/recipes/full_finetune.py +++ b/recipes/full_finetune_distributed.py @@ -14,7 +14,6 @@ from omegaconf import DictConfig from torch import nn -from torch.cuda.amp import GradScaler from torch.distributed import init_process_group from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim import Optimizer @@ -37,7 +36,7 @@ class FullFinetuneRecipe(FTRecipeInterface): This recipe supports: - FSDP and activation checkpointing. This is enabled by default but can be configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags. - - Mixed precision training - fp32, fp16 and bf16 are supported. + - Full bf16 training via setting the ``dtype`` flag to bf16. - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed). - Resuming from checkpoints saved using the ``save_checkpoint`` functionality. - Logging to terminal. WandB and TensorBoard are currently not supported. @@ -51,21 +50,31 @@ class FullFinetuneRecipe(FTRecipeInterface): The following configs can be used to run this recipe: >>> tune ls - RECIPE CONFIG - full_finetune alpaca_llama2_full_finetune + RECIPE CONFIG + full_finetune_distributed alpaca_llama2_full_finetune_distributed Args: cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. """ def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) self._dtype = utils.get_dtype(dtype=cfg.dtype) - + self._training_precision = utils.get_dtype(dtype=cfg.dtype) + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._training_precision == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) # logging attributes self._output_dir = cfg.output_dir self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 + self._log_peak_memory_every_n_steps = 100 # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this @@ -153,9 +162,9 @@ def setup(self, cfg: DictConfig) -> None: # checkpoint. 
Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=ckpt_dict[utils.OPT_KEY] - if self._resume_from_checkpoint - else None, + opt_state_dict=( + ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), ) self._loss_fn = config.instantiate(cfg.loss) @@ -170,14 +179,6 @@ def setup(self, cfg: DictConfig) -> None: batch_size=cfg.batch_size, ) - # training setup - self._autocast = utils.get_autocast(self._dtype, self._device) - self._grad_scaler = None - if self._dtype == torch.float16: - self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp) - else: - self._grad_scaler = GradScaler(enabled=False) - # Finally update the recipe state which can only be correctly set after all of the # other components have been initialized and updated. # @@ -207,7 +208,7 @@ def _setup_model( ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for running tests on CPUs. """ - with self._device: + with utils.set_default_dtype(self._training_precision), self._device: model = config.instantiate(cfg_model) model = ( @@ -227,9 +228,13 @@ def _setup_model( ) model.load_state_dict(model_state_dict) - + # Validate model was loaded in with the expected dtype. + utils.validate_expected_param_dtype(model, dtype=self._training_precision) if self._is_rank_zero: - log.info("Model is initialized.") + log.info(f"Model is initialized with precision {self._training_precision}.") + utils.memory_stats_log( + "Memory Stats after model init:", device=self._device + ) return model def _setup_optimizer( @@ -339,7 +344,6 @@ def train(self) -> None: # zero out the gradients before starting training self._optimizer.zero_grad() - # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -360,16 +364,13 @@ def train(self) -> None: input_ids, labels = batch input_ids = input_ids.to(self._device) labels = labels.to(self._device) - - with self._autocast: - logits = self._model(input_ids) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) - # Compute loss - loss = self._loss_fn(logits, labels) - + logits = self._model(input_ids) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) # Note: We're always logging the loss before normalizing it # Check if this is the norm or not if self.total_training_steps % self._log_every_n_steps == 0: @@ -383,22 +384,24 @@ def train(self) -> None: step=self.total_training_steps, ) - # Does loss normalization need to happen within autocast context? 
loss = loss / self._gradient_accumulation_steps - self._grad_scaler.scale(loss).backward() + loss.backward() if self._should_update_weights(idx): - self._grad_scaler.step(self._optimizer) - self._grad_scaler.update() + self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) - # Update the number of steps when the weights are updated self.total_training_steps += 1 + # Log peak memory for iteration + if self.total_training_steps % self._log_peak_memory_every_n_steps == 0: + utils.memory_stats_log("Memory Stats:", device=self._device) + self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) def cleanup(self) -> None: self._metric_logger.close() + torch.distributed.destroy_process_group() @config.parse @@ -407,11 +410,16 @@ def recipe_main(cfg: DictConfig) -> None: Entry point for the recipe. Configurable parameters are read in the following order: - - Parameters specified in ``alpaca_llama2_full_finetune.yaml`` + - Parameters specified in ``alpaca_llama2_full_finetune_distributed.yaml`` - Overwritten by arguments from the command-line """ - if utils.is_distributed(): - init_process_group(backend="nccl") + if not utils.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") recipe = FullFinetuneRecipe(cfg=cfg) recipe.setup(cfg=cfg) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py new file mode 100644 index 0000000000..e2c7fdc5bb --- /dev/null +++ b/recipes/full_finetune_single_device.py @@ -0,0 +1,376 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +from functools import partial +from typing import Any, Dict, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, modules, utils + +from torchtune.recipe_interfaces import FTRecipeInterface + +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class FullFinetuneRecipe(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2. + + This recipe supports: + - Activation checkpointing. This is enabled by default but can be + configured using the ``enable_activation_checkpointing`` flags. + - Full bf16 training via setting the ``dtype`` flag to bf16. + - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed). + - Resuming from checkpoints saved using the ``save_checkpoint`` functionality. + - Logging to terminal, WandB, or TensorBoard. + + Assumptions: + - Training is launched with the Tune CLI (recommended) which uses TorchRun under the + hood. Setting up the env variables is handled by TorchRun. + - Training happens on CUDA (CPU training is not supported) + - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported. + - Datasets are Map-style and data fits in memory (not streamed). 
+ + The following configs can be used to run this recipe: + >>> tune ls + RECIPE CONFIG + full_finetune_single_device alpaca_llama2_full_finetune_single_device + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + self._training_precision = utils.get_dtype(dtype=cfg.dtype) + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._training_precision == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 + self._log_peak_memory_every_n_steps = 100 + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.total_training_steps = 0 + + def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + # If seed, total_epoch or max_steps_per_epoch don't match, + # warn the user and overwrite + try: + if ( + self.seed != ckpt_dict[utils.SEED_KEY] + or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY] + ): + warn( + message="""Configured value for seed, epochs or max_steps_per_epoch + does not match the value stored in checkpoint.""" + ) + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + except KeyError as e: + raise KeyError from e( + "Checkpoint does not contain the required keys needed for updating recipe state." + "Are you sure you passed in the right recipe checkpoint?" + ) + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # ``_setup_model`` handles initialization and loading the state dict. 
This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.total_training_steps = self.epochs_run * self._steps_per_epoch + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + """ + Set up the model including enabling activation checkpointing. + """ + with utils.set_default_dtype(self._training_precision), self._device: + model = config.instantiate(cfg_model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + model.load_state_dict(model_state_dict) + + # Validate model was loaded in with the expected dtype. + utils.validate_expected_param_dtype(model, dtype=self._training_precision) + log.info(f"Model is initialized with precision {self._training_precision}.") + utils.memory_stats_log("Memory Stats after model init:", device=self._device) + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + """ + Set up the optimizer. This method also handles loading the optimizer state_dict, if specified. + """ + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. 
+ """ + ds = config.instantiate( + cfg_dataset, + tokenizer=self._tokenizer, + ) + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, # TODO support loss without ignore_index + ), + ) + + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()} + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + ckpt_dict.update( + { + utils.OPT_KEY: self._optimizer.state_dict(), + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def _should_update_weights(self, current_iteration: int) -> bool: + """ + Determines whether the weights should be updated on the current iteration or not. + True is returned either if we've accumulated gradients for enough steps or if this + is the last step in the epoch. + """ + should_update_weights = ( + current_iteration + 1 + ) % self._gradient_accumulation_steps == 0 + return should_update_weights + + def train(self) -> None: + """ + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. 
+ """ + # zero out the gradients before starting training + self._optimizer.zero_grad() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + for idx, batch in enumerate(pbar := tqdm(self._dataloader)): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + input_ids, labels = batch + input_ids = input_ids.to(self._device) + labels = labels.to(self._device) + logits = self._model(input_ids) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + # Note: We're always logging the loss before normalizing it + # Check if this is the norm or not + pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}") + + if self.total_training_steps % self._log_every_n_steps == 0: + self._metric_logger.log_dict( + { + "loss": loss.item(), + "lr": self._optimizer.param_groups[0]["lr"], + "gpu_resources": torch.cuda.memory_allocated(), + }, + step=self.total_training_steps, + ) + + loss = loss / self._gradient_accumulation_steps + loss.backward() + if self._should_update_weights(idx): + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.total_training_steps += 1 + + # Log peak memory for iteration + if self.total_training_steps % self._log_peak_memory_every_n_steps == 0: + utils.memory_stats_log("Memory Stats:", device=self._device) + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. 
+ + Configurable parameters are read in the following order: + - Parameters specified in ``alpaca_llama2_full_finetune_single_device.yaml`` + - Overwritten by arguments from the command-line + """ + recipe = FullFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 9a07e8dcd3..3a81b2a659 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -317,7 +317,11 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerDecoderLayer} ) if self._is_rank_zero: - log.info(utils.memory_stats_log("Memory Stats after model init:")) + log.info( + utils.memory_stats_log( + "Memory Stats after model init:", device=self._device + ) + ) return model def _setup_optimizer( @@ -508,7 +512,9 @@ def train(self) -> None: self._optimizer.step() self._lr_scheduler.step() if self.total_training_steps % 100 == 0 and self._is_rank_zero: - log.info(utils.memory_stats_log("Memory Stats:")) + log.info( + utils.memory_stats_log("Memory Stats:", device=self._device) + ) self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index f312c82657..f540ea2a00 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -38,7 +38,6 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): This recipe supports: - Activation checkpointing. This is enabled by default but is configurable. - - Mixed precision training via `torch.autocast` - fp32, fp16 and bf16 are supported. - Full bf16 training for supported HW architectures. We currently check bf16 support via the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via the "full_bf16" configuration flag. diff --git a/tests/recipes/test_full_finetune.py b/tests/recipes/test_full_finetune.py index 30a20aa878..b019280e35 100644 --- a/tests/recipes/test_full_finetune.py +++ b/tests/recipes/test_full_finetune.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import contextlib import json import logging import os @@ -29,7 +30,7 @@ llama2_tiny_test_ckpt, validate_loss_values, ) -from tests.test_utils import get_assets_path +from tests.test_utils import get_assets_path, single_box_init from torchtune import models @@ -94,7 +95,8 @@ def fetch_checkpointer(self, ckpt): if ckpt == "small_test_ckpt_meta": return "FullModelMetaCheckpointer" - def test_loss(self, capsys, pytestconfig, tmpdir, monkeypatch): + @pytest.mark.parametrize("single_device", [True, False]) + def test_loss(self, single_device, capsys, pytestconfig, tmpdir, monkeypatch): large_scale = pytestconfig.getoption("--large-scale") ckpts = ( ["llama2.llama2_7b"] @@ -122,8 +124,12 @@ def test_loss(self, capsys, pytestconfig, tmpdir, monkeypatch): with config_file.open("w") as f: json.dump(config, f) + if single_device: + recipe_cmd = "full_finetune_single_device" + else: + recipe_cmd = "full_finetune_distributed" cmd = f""" - tune full_finetune + tune {recipe_cmd} --config {_CONFIG_PATH} \ output_dir={tmpdir} \ model=torchtune.models.{ckpt} \ @@ -131,12 +137,18 @@ def test_loss(self, capsys, pytestconfig, tmpdir, monkeypatch): checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA2 + checkpointer.model_type=LLAMA2 \ + log_every_n_steps=1 """.split() monkeypatch.setattr(sys, "argv", cmd) with pytest.raises(SystemExit): - runpy.run_path(TUNE_PATH, run_name="__main__") + with ( + single_box_init(init_pg=False) + if not single_device + else contextlib.nullcontext() + ): + runpy.run_path(TUNE_PATH, run_name="__main__") loss_values = fetch_loss_values(capsys.readouterr().err) validate_loss_values(loss_values, expected_loss_values) @@ -171,7 +183,7 @@ def test_training_state_on_resume(self, capsys, tmpdir, monkeypatch): # Train cmd_1 = f""" - tune full_finetune + tune full_finetune_single_device --config {_CONFIG_PATH} \ output_dir={tmpdir} \ model=torchtune.models.{model_ckpt} \ @@ -181,6 +193,7 @@ def test_training_state_on_resume(self, capsys, tmpdir, monkeypatch): checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA2 \ epochs=4 \ + log_every_n_steps=1 \ """.split() monkeypatch.setattr(sys, "argv", cmd_1) @@ -196,7 +209,7 @@ def test_training_state_on_resume(self, capsys, tmpdir, monkeypatch): # Resume training cmd_2 = f""" - tune full_finetune + tune full_finetune_single_device --config {_CONFIG_PATH} \ output_dir={tmpdir} \ model=torchtune.models.{model_ckpt} \ @@ -210,6 +223,7 @@ def test_training_state_on_resume(self, capsys, tmpdir, monkeypatch): resume_from_checkpoint=True \ max_steps_per_epoch=None \ seed=0 \ + log_every_n_steps=1 \ """.split() monkeypatch.setattr(sys, "argv", cmd_2) @@ -240,7 +254,7 @@ def test_gradient_accumulation( ckpt_dir = ckpt_path.parent cmd = f""" - tune full_finetune \ + tune full_finetune_single_device \ --config {_CONFIG_PATH} \ model=torchtune.models.{model_ckpt} \ checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ @@ -253,6 +267,7 @@ def test_gradient_accumulation( epochs=1 \ max_steps_per_epoch=1 \ output_dir={tmpdir} \ + log_every_n_steps=1 \ """.split() monkeypatch.setattr(sys, "argv", cmd) @@ -268,7 +283,7 @@ def test_gradient_accumulation( ) # Update the cmd with new values for gradient accumulation cmd_2 = f""" - tune full_finetune \ + tune full_finetune_single_device \ --config {_CONFIG_PATH} \ model=torchtune.models.{model_ckpt} \ checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \ @@ -282,6 
+297,7 @@ def test_gradient_accumulation( epochs=1 \ max_steps_per_epoch=1 \ output_dir={tmpdir} \ + log_every_n_steps=1 \ """.split() monkeypatch.setattr(sys, "argv", cmd_2) diff --git a/tests/torchtune/_cli/test_cp.py b/tests/torchtune/_cli/test_cp.py index a18d4615dc..7f396ad118 100644 --- a/tests/torchtune/_cli/test_cp.py +++ b/tests/torchtune/_cli/test_cp.py @@ -23,7 +23,7 @@ def test_copy_successful(self, capsys, monkeypatch, tmpdir, already_exists): if already_exists: dest.touch() - args = f"tune cp alpaca_llama2_full_finetune.yaml {dest}".split() + args = f"tune cp alpaca_llama2_full_finetune_single_device.yaml {dest}".split() monkeypatch.setattr(sys, "argv", args) runpy.run_path(TUNE_PATH, run_name="__main__") @@ -41,7 +41,7 @@ def test_copy_skips_when_dest_already_exists_and_no_clobber_is_true( existing_file = tmpdir_path / "existing_file.yaml" existing_file.touch() - args = f"tune cp alpaca_llama2_full_finetune.yaml {existing_file} -n".split() + args = f"tune cp alpaca_llama2_full_finetune_single_device.yaml {existing_file} -n".split() monkeypatch.setattr(sys, "argv", args) runpy.run_path(TUNE_PATH, run_name="__main__") @@ -67,8 +67,8 @@ def test_copy_skips_when_dest_already_exists_and_no_clobber_is_true( "error: Invalid file name: non_existent_config.yaml. Try `tune ls` to see all available files to copy.", ), ( - "tune cp full_finetune.py /home/mr_bean/full_finetune.py", - "error: Cannot create regular file: '/home/mr_bean/full_finetune.py'. No such file or directory.", + "tune cp full_finetune_single_device.py /home/mr_bean/full_finetune_single_device.py", + "error: Cannot create regular file: '/home/mr_bean/full_finetune_single_device.py'. No such file or directory.", ), ( "tune cp", diff --git a/torchtune/__init__.py b/torchtune/__init__.py index 83a7a2ecaf..8ad3809678 100644 --- a/torchtune/__init__.py +++ b/torchtune/__init__.py @@ -7,13 +7,17 @@ from torchtune import datasets, models, modules, utils _RECIPE_LIST = [ - "full_finetune.py", + "full_finetune_single_device.py", + "full_finetune_distributed.py", "alpaca_generate.py", "lora_finetune_single_device.py", "lora_finetune_distributed.py", ] _CONFIG_LISTS = { - "full_finetune.py": ["alpaca_llama2_full_finetune.yaml"], + "full_finetune_single_device.py": [ + "alpaca_llama2_full_finetune_single_device.yaml" + ], + "full_finetune_distributed.py": ["alpaca_llama2_full_finetune_distributed.yaml"], "lora_finetune_single_device.py": [ "alpaca_llama2_lora_finetune_single_device.yaml" ], diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index 63469fc89e..ff4bec8c00 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -53,6 +53,7 @@ "transform_opt_state_dict", "validate_checkpoint", "get_autocast", + "memory_stats_log", "get_device", "get_dtype", "wrap_fsdp", diff --git a/torchtune/utils/memory.py b/torchtune/utils/memory.py index f113034510..283c6b9ef5 100644 --- a/torchtune/utils/memory.py +++ b/torchtune/utils/memory.py @@ -29,10 +29,31 @@ def set_activation_checkpointing( apply_activation_checkpointing(model, auto_wrap_policy=wrap_policy, **kwargs) -def memory_stats_log(msg: str) -> str: - return f""" - Memory Stats {msg}: - Memory Allocated: {torch.cuda.memory_allocated() / 1000**3:.2f} GB - Memory Reserved: {torch.cuda.memory_reserved() / 1000**3:.2f} GB - Peak Memory: {torch.cuda.max_memory_allocated() / 1000**3:.2f} GB +def memory_stats_log( + prefix: str, device: torch.device, reset_stats: bool = True +) -> None: + """ + Print a memory summary for the passed in device. 
If ``reset_stats`` is ``True``, this will + also reset CUDA's peak memory tracking. This is useful to get data around relative use of peak + memory (i.e. peak memory during model init, during forward, etc) and optimize memory for + individual sections of training. + + Args: + prefix (str): Prefix to prepend to the printed summary. + device (torch.device): Device to get memory summary for. Only CUDA devices are supported. + reset_stats (bool): Whether to reset CUDA's peak memory tracking. + + Returns: + None """ + if device.type != "cuda": + return + peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) + print( + f"{prefix}, GPU peak memory allocation: {torch.cuda.max_memory_allocated(device) / 1e9}GB, " + f"GPU peak memory reserved: {torch.cuda.max_memory_reserved(device) / 1e9}GB, " + f"GPU peak memory active: {peak_memory_active / 1e9}GB", + flush=True, + ) + if reset_stats: + torch.cuda.reset_peak_memory_stats(device)
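
As a rough, self-contained sketch (not part of the patch) of the initialization pattern the reworked recipes share — default-dtype model construction, dtype validation, and the new ``memory_stats_log`` signature — assuming the ``torchtune.utils`` helpers behave as their call sites above suggest; ``nn.Linear`` is a hypothetical stand-in for the instantiated Llama2 model:

    # Hypothetical stand-alone sketch; nn.Linear stands in for config.instantiate(cfg_model).
    import torch
    from torch import nn
    from torchtune import utils

    device = utils.get_device(device="cuda")   # assumes a CUDA device is available
    dtype = utils.get_dtype(dtype="bf16")

    # Build the module directly in bf16 on the target device, mirroring _setup_model,
    # then confirm every parameter actually landed in the expected dtype.
    with utils.set_default_dtype(dtype), device:
        model = nn.Linear(16, 16)
    utils.validate_expected_param_dtype(model, dtype=dtype)

    # memory_stats_log now prints (rather than returns) peak allocated/reserved/active
    # memory for a CUDA device and, by default, resets the peak-memory counters; it is
    # a no-op on non-CUDA devices.
    utils.memory_stats_log("Memory Stats after model init:", device=device)

Because the helper now returns ``None`` and prints directly, the distributed and single-device full-finetune recipes above call it on its own rather than passing its result to a logger.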