Skip to content

Commit

Permalink
[Feat] Activation offloading for distributed lora recipe (pytorch#1645)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackmin801 authored and mori360 committed Oct 14, 2024
1 parent 207b1b1 commit c48da2a
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import sys
import time

Expand Down Expand Up @@ -34,7 +35,12 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training import (
DummyProfiler,
NoOpManager,
OffloadActivations,
PROFILER_KEY,
)

from tqdm import tqdm

Expand All @@ -53,13 +59,25 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
- Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.
- Activation Offloading. This can be controlled using the ``enable_activation_offloading``
flag. Activation offloading is a technique similar to activations checkpointing that helps
reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
checkpointing drops the activation in the forward to recompute it later in the backward,
activations offloading will drop the activation in the forward to the CPU and bring it
back during the backward pass. As always, there is a tradeoff--these savings in memory can
come at the cost of training performance and CPU resources. To recover some runtime cost,
we've added an option to enable offloading on a different stream to permit overlapping with
the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
or later and will be enabled by default if an acceptable torch version is found. Activation
offloading can be used in conjunction with activation checkpointing.
- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
Expand Down Expand Up @@ -110,6 +128,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
ValueError: If world_size is 1
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``left_pad_sequence`` is set as the data collator.
RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
"""

def __init__(self, cfg: DictConfig) -> None:
Expand All @@ -134,6 +153,13 @@ def __init__(self, cfg: DictConfig) -> None:

# training attributes
self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
self._enable_activation_offloading = cfg.get(
"enable_activation_offloading", False
)
if self._enable_activation_offloading and self._device.type != "cuda":
raise RuntimeError(
"enable_activation_offloading should only be enabled for training on CUDA"
)

# These attributes constitute the recipe state and are updated by ``load_checkpoint``
# when ``resume_from_checkpoint`` is ``True``
Expand Down Expand Up @@ -230,6 +256,7 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
Expand Down Expand Up @@ -377,6 +404,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
Expand Down Expand Up @@ -496,6 +524,23 @@ def _is_layer_name(name: str, module: nn.Module) -> bool:
# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)

self.activations_handling_ctx = contextlib.nullcontext()
if enable_activation_offloading:
self.activations_handling_ctx = OffloadActivations()

# Below is our hack to disable offloading the last output Linear in every
# step, as the cost for offloading the activation and then soon after bringing
# it back is expensive. Moreover, due to heuristics in our streaming API,
# we actually use more memory if we offload it as it interferes with chunkedCE.
if hasattr(model, "output") and isinstance(model.output, nn.Module):
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
model.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)

if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
Expand Down Expand Up @@ -733,7 +778,8 @@ def train(self) -> None:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

logits = self._model(**batch)
with self.activations_handling_ctx:
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Expand Down

0 comments on commit c48da2a

Please sign in to comment.