Full finetune < 16GB (pytorch#527)
rohan-varma authored and Thomas Capelle committed Apr 5, 2024
1 parent 97b994d commit 08fae5e
Showing 8 changed files with 374 additions and 21 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -55,12 +55,15 @@ experience different peak memory utilization based on changes made in configurat
| 1 x RTX 4090 | QLoRA | [qlora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml) | Llama-7B | 9.29 GB * |
| 2 x RTX 4090 | LoRA | [lora_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_lora.yaml) | Llama-7B | 14.17 GB * |
| 1 x RTX 4090 | LoRA | [lora_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_lora_single_device.yaml) | Llama-7B | 17.18 GB * |
| 1 x A6000 | Full finetune | [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device.yaml) | Llama-7B | 27.15 GB * |
| 1 x A6000 | Full finetune | [full_finetune_single_device](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device_low_memory.yaml) | Llama-7B | 15.97 GB * ^ |
| 4 x RTX 4090 | Full finetune | [full_finetune_distributed](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full.yaml) | Llama-7B | 12.01 GB * |


NOTE: * indicates an estimated metric based on experiments conducted on A100 GPUs with GPU memory artificially limited using [torch.cuda.set_per_process_memory_fraction API](https://pytorch.org/docs/stable/generated/torch.cuda.set_per_process_memory_fraction.html). Peak memory per GPU is as reported by `torch.cuda.max_memory_reserved()`. Please file an issue if you are not able to reproduce these results when running TorchTune on certain hardware.
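For reference, this measurement pattern can be sketched with plain PyTorch APIs. The snippet below is a minimal sketch, not the benchmark script: the 16 GB cap and the placement of the training step are illustrative assumptions (the cap assumes an A100-class card with more than 16 GB of memory).

```python
import torch

# Emulate a smaller card by capping this process's share of GPU memory
# (illustrative 16 GB cap on an A100; the table above used other limits).
total_bytes = torch.cuda.get_device_properties(0).total_memory
torch.cuda.set_per_process_memory_fraction(16 * 1024**3 / total_bytes, device=0)

torch.cuda.reset_peak_memory_stats()
# ... run the finetuning step(s) to be measured here ...
peak_gb = torch.cuda.max_memory_reserved() / 1024**3
print(f"Peak reserved GPU memory: {peak_gb:.2f} GB")
```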

NOTE: ^ indicates the required use of third-party dependencies that are not installed with torchtune by default. In particular, for the most memory efficient full finetuning [configuration](https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_full_single_device_low_memory.yaml), [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is required and can be installed via `pip install bitsandbytes`, after which the configuration
can be run successfully.
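For illustration, the optimizer that the low-memory configuration selects via its `optimizer._component_` field would be constructed roughly as follows once bitsandbytes is installed. This is a sketch with a placeholder model, not the recipe code.

```python
import bitsandbytes as bnb
import torch

# Placeholder module standing in for the Llama2 7B model built by the recipe.
model = torch.nn.Linear(4096, 4096)

# bitsandbytes' paged AdamW keeps optimizer state in paged memory that can be
# evicted to CPU when GPU memory runs low, which, together with
# optimizer-in-backward and activation checkpointing, enables the < 16 GB run.
optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=2e-5)
```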

&nbsp;

---
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full_single_device.yaml
@@ -56,6 +56,7 @@ loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: False


# Training environment
76 changes: 76 additions & 0 deletions recipes/configs/llama2/7B_full_single_device_low_memory.yaml
@@ -0,0 +1,76 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama2 7B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download --repo-id meta-llama/Llama-2-7b \
# --hf-token <HF_TOKEN> \
# --output-dir /tmp/llama2
#
# To launch on a single device, run the following command from root:
# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \
# --config llama2/7B_full_single_device_low_memory \
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device \
# --config llama2/7B_full_single_device_low_memory \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.


# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [consolidated.00.pth]
  recipe_checkpoint: null
  output_dir: /tmp/llama2
  model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 2e-5
optimizer_in_bwd: True
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
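The `_component_` entries in this config are resolved into objects by torchtune's config utilities (the recipe below calls `config.instantiate`). The helper name in this sketch is made up for illustration and only approximates that dotted-path instantiation pattern; the real `config.instantiate` may differ in details.

```python
import importlib
from typing import Any


def instantiate_component(dotted_path: str, *args: Any, **kwargs: Any) -> Any:
    # Import "a.b.C" and call it with the remaining config fields as kwargs.
    module_path, _, attr_name = dotted_path.rpartition(".")
    component = getattr(importlib.import_module(module_path), attr_name)
    return component(*args, **kwargs)


# Conceptually, the optimizer block above becomes:
# optimizer = instantiate_component(
#     "bitsandbytes.optim.PagedAdamW", model.parameters(), lr=2e-5
# )
```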
84 changes: 67 additions & 17 deletions recipes/full_finetune_single_device.py
@@ -44,6 +44,8 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
hood. Setting up the env variables is handled by TorchRun.
- Training happens on CUDA (CPU training is not supported)
- Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
- User can only use ONE of gradient accumulation or optimizer in backward. These features
currently do not work together.
- Datasets are Map-style and data fits in memory (not streamed).
The following configs can be used to run this recipe:
@@ -55,8 +57,9 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
cfg (DictConfig): OmegaConf object parsed from yaml file
Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`.
"""

def __init__(self, cfg: DictConfig) -> None:
@@ -65,7 +68,7 @@ def __init__(self, cfg: DictConfig) -> None:
# Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
# enabled necessary features such as gradient scaling.
if self._dtype == torch.float16:
raise ValueError(
raise RuntimeError(
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

@@ -84,7 +87,14 @@ def __init__(self, cfg: DictConfig) -> None:
# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

self._optimizer_in_bwd = cfg.optimizer_in_bwd
# TODO: find a better place / way to perform validation of args that don't yet
# compose with each other.
if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
raise RuntimeError(
"Gradient accumulation is not supported with optimizer in bwd."
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
)
# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = utils.set_seed(seed=cfg.seed)
@@ -158,6 +168,7 @@ def setup(self, cfg: DictConfig) -> None:
# checkpoint. Transforming the opt state dict is handled by this method
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=cfg.optimizer_in_bwd,
opt_state_dict=(
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
),
@@ -221,18 +232,46 @@ def _setup_model(
return model

def _setup_optimizer(
self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
) -> Optimizer:
self,
cfg_optimizer: DictConfig,
optimizer_in_bwd: bool = False,
opt_state_dict: Optional[Dict[str, Any]] = None,
) -> Optional[Optimizer]:
"""
Set up the optimizer. This method also handles loading the optimizer state_dict, if specified.
"""
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)

log.info("Optimizer is initialized.")
return optimizer
if optimizer_in_bwd:
# Maintain a dict of optims for every parameter.
optim_dict = {
p: config.instantiate(cfg_optimizer, [p])
for p in self._model.parameters()
}
# Register optimizer step hooks on the model to run optimizer in backward.
utils.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict)
# Create a wrapper for checkpoint save/load of optimizer states when running in backward.
self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper(
model=self._model, optim_dict=optim_dict
)
# Load optimizer states. If optimizer states are being restored in an optimizer in backward
# run, these need to have been saved with the same setting. Cannot restore from runs that did not
# use optimizer in backward.
if opt_state_dict is not None:
try:
self._optim_ckpt_wrapper.load_state_dict(opt_state_dict)
except BaseException as e:
raise RuntimeError(
"Failed loading in-backward optimizer checkpoints."
"Please make sure run being restored from was using in-backward optimizer."
) from e
log.info("In-backward optimizers are set up.")
return None
else:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)
log.info("Optimizer is initialized.")
return optimizer

def _setup_data(
self,
@@ -281,13 +320,16 @@ def save_checkpoint(self, epoch: int) -> None:
if epoch + 1 < self.total_epochs:
ckpt_dict.update(
{
utils.OPT_KEY: self._optimizer.state_dict(),
utils.SEED_KEY: self.seed,
utils.EPOCHS_KEY: self.epochs_run,
utils.TOTAL_EPOCHS_KEY: self.total_epochs,
utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
}
)
if not self._optimizer_in_bwd:
ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
else:
ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
@@ -311,8 +353,8 @@ def train(self) -> None:
``max_steps_per_epoch``.
"""
# zero out the gradients before starting training
self._optimizer.zero_grad()

if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
@@ -344,20 +386,28 @@ def train(self) -> None:
self._metric_logger.log_dict(
{
"loss": loss.item(),
"lr": self._optimizer.param_groups[0]["lr"],
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": (
self._optim_ckpt_wrapper.get_optim_key("lr")
if self._optimizer_in_bwd
else self._optimizer.param_groups[0]["lr"]
),
"gpu_resources": torch.cuda.memory_allocated(),
},
step=self.total_training_steps,
)

loss = loss / self._gradient_accumulation_steps
loss.backward()
if self._should_update_weights(idx):
if not self._optimizer_in_bwd and self._should_update_weights(idx):
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.total_training_steps += 1
elif self._optimizer_in_bwd:
self.total_training_steps += 1

# Log peak memory for iteration
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
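The optimizer-in-backward path added in this recipe can be illustrated with plain PyTorch. The sketch below only approximates the idea behind `utils.register_optim_in_bwd_hooks` and is not the torchtune implementation; it assumes PyTorch >= 2.1 for `register_post_accumulate_grad_hook`.

```python
import torch

model = torch.nn.Linear(10, 1)

# One optimizer per parameter, mirroring the optim_dict built in _setup_optimizer.
optim_dict = {p: torch.optim.AdamW([p], lr=2e-5) for p in model.parameters()}


def _optim_step_hook(param: torch.Tensor) -> None:
    # Step and free each gradient as soon as it has been accumulated, so the
    # full set of gradients never has to be resident at once.
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)


for p in model.parameters():
    p.register_post_accumulate_grad_hook(_optim_step_hook)

# A training iteration is then just forward + backward; there is no separate
# optimizer.step() call, which is why gradient accumulation cannot be combined
# with this mode.
loss = model(torch.randn(2, 10)).sum()
loss.backward()
```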
7 changes: 5 additions & 2 deletions tests/recipes/test_full_finetune_single_device.py
@@ -51,15 +51,18 @@ def _fetch_expected_loss_values(self):
return [10.5074, 10.5563, 10.5152, 10.4851]

@pytest.mark.integration_test
def test_loss(self, tmpdir, monkeypatch):
@pytest.mark.parametrize(
"config", ["full_single_device_low_memory", "full_single_device"]
)
def test_loss(self, config, tmpdir, monkeypatch):
ckpt = "small_test_ckpt_meta"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

cmd = f"""
tune full_finetune_single_device
--config llama2/7B_full_single_device \
--config llama2/7B_{config} \
output_dir={tmpdir} \
checkpointer._component_=torchtune.utils.FullModelMetaCheckpointer
checkpointer.checkpoint_dir='{ckpt_dir}' \
91 changes: 91 additions & 0 deletions tests/torchtune/utils/test_optim_utils.py
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from torchtune.utils import create_optim_in_bwd_wrapper, register_optim_in_bwd_hooks


def _run_dummy_step(model, wrapper):
    with torch.no_grad():
        for p in model.parameters():
            p.grad = torch.rand_like(p)
        for v in wrapper.optim_map.values():
            v.step()
            v.zero_grad()


def _validate_dicts(d1, d2):
    if len(d1) != len(d2):
        return False
    for k, v in d1.items():
        if k not in d2:
            return False
        if isinstance(v, dict):
            # Recurse into nested state dicts and keep checking the remaining
            # keys rather than returning on the first nested dict.
            if not _validate_dicts(v, d2[k]):
                return False
        elif isinstance(v, torch.Tensor):
            if not torch.allclose(v, d2[k]):
                return False
        elif v != d2[k]:
            return False
    return True


@pytest.fixture
def model():
    return torch.nn.Linear(10, 1)


@pytest.fixture
def optim_dict(model):
    return {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()}


@pytest.fixture
def wrapper(model, optim_dict):
    return create_optim_in_bwd_wrapper(model, optim_dict)


class TestOptimInBackward:
    def test_state_dict_save_load(self, model, wrapper):
        # Run a dummy step to create optimizer states
        _run_dummy_step(model, wrapper)

        sd = wrapper.state_dict()
        new_optim_dict = create_optim_in_bwd_wrapper(
            model, {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()}
        )
        assert not _validate_dicts(sd, new_optim_dict.state_dict())
        new_optim_dict.load_state_dict(sd)
        assert _validate_dicts(sd, new_optim_dict.state_dict())

    def test_missing_unexpected_param_load_raises(self, model, wrapper):
        # Run a dummy step to create optimizer states
        _run_dummy_step(model, wrapper)
        sd = wrapper.state_dict()
        new_optim_dict = create_optim_in_bwd_wrapper(
            model, {p: torch.optim.AdamW([p], lr=0.01) for p in model.parameters()}
        )
        with pytest.raises(RuntimeError, match="Expected to load optimizer state"):
            sd.pop(next(iter(sd.keys())))
            new_optim_dict.load_state_dict(sd)

        sd = wrapper.state_dict()
        sd["new_key"] = 1234
        with pytest.raises(RuntimeError, match="unexpected param"):
            new_optim_dict.load_state_dict(sd)


class TestRegisterOptimHooks:
    def test_register_optim_in_bwd_hooks(self, model, optim_dict):
        register_optim_in_bwd_hooks(model, optim_dict)
        # Ensure backward() updates the parameters and sets grads to None
        orig_params = [p.clone().detach() for p in model.parameters()]
        model(torch.rand(2, 10)).sum().backward()
        for p, orig_p in zip(model.parameters(), orig_params):
            assert not p.grad
            assert not torch.allclose(p, orig_p)