From 7c540b9f306a6a99124effe7c04f6c7064ab5139 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 14 Oct 2024 12:27:55 -0700 Subject: [PATCH 01/19] first comit --- recipes/full_finetune_distributed.py | 60 +++++++++++++++++++++++-- recipes/full_finetune_single_device.py | 62 ++++++++++++++++++++++++-- recipes/lora_finetune_distributed.py | 28 +++++++----- recipes/lora_finetune_single_device.py | 25 +++++++---- 4 files changed, 146 insertions(+), 29 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 6e83e575f9..96f3b6f9c3 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time @@ -24,7 +25,12 @@ from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from torchtune.training.activations import apply_selective_activation_checkpointing from tqdm import tqdm @@ -44,13 +50,25 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). DDP is currently not supported. Training on CPU is not supported. - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 + or later and will be enabled by default if an acceptable torch version is found. Activation + offloading can be used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -96,6 +114,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. """ def __init__(self, cfg: DictConfig) -> None: @@ -210,7 +229,10 @@ def setup(self, cfg: DictConfig) -> None: self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=cfg.get( + "enable_activation_checkpointing", False + ), + enable_activation_offloading=cfg.get("enable_activation_offloading", False), custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -347,6 +369,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, reshard_after_forward: bool, @@ -440,6 +463,34 @@ def _is_layer_fqn(s: str) -> bool: cpu_offload=fsdp_cpu_offload, ) + # activation offloading + if enable_activation_checkpointing and self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be enabled for training on CUDA" + ) + + if enable_activation_checkpointing and not enable_activation_offloading: + log.warning( + "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations() + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) @@ -632,7 +683,8 @@ def train(self) -> None: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - logits = self._model(**batch) + with self.activations_handling_ctx: + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 2addd92944..de9ae28818 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time from functools import partial @@ -22,7 +23,12 @@ from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from tqdm import tqdm @@ -36,13 +42,25 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): for single GPU training. Training on CPU is not supported. Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 + or later and will be enabled by default if an acceptable torch version is found. Activation + offloading can be used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -99,6 +117,7 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`. RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. """ def __init__(self, cfg: DictConfig) -> None: @@ -208,10 +227,14 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model + self._compile = cfg.compile self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=cfg.get( + "enable_activation_checkpointing", False + ), + enable_activation_offloading=cfg.get("enable_activation_offloading", False), compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], ) @@ -354,6 +377,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, compile_model: bool, model_state_dict: Dict[str, Any], ) -> nn.Module: @@ -377,6 +401,35 @@ def _setup_model( training.validate_expected_param_dtype( model.named_parameters(), dtype=self._dtype ) + + # activation checkpointing/offloading + if enable_activation_checkpointing and self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be enabled for training on CUDA" + ) + + if enable_activation_checkpointing and not enable_activation_offloading: + log.warning( + "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations() + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + log.info(f"Model is initialized with precision {self._dtype}.") if self._device.type == "cuda": @@ -560,7 +613,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - logits = self._model(**batch) + with self.activations_handling_ctx: + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 1569dfee63..f1eac8c9a3 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -151,16 +151,6 @@ def __init__(self, cfg: DictConfig) -> None: self._log_every_n_steps = cfg.get("log_every_n_steps", 1) self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - # training attributes - self._enable_activation_checkpointing = cfg.enable_activation_checkpointing - self._enable_activation_offloading = cfg.get( - "enable_activation_offloading", False - ) - if self._enable_activation_offloading and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` self.seed = training.set_seed(seed=cfg.seed) @@ -255,8 +245,10 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, - enable_activation_offloading=self._enable_activation_offloading, + enable_activation_checkpointing=cfg.get( + "enable_activation_checkpointing", False + ), + enable_activation_offloading=cfg.get("enable_activation_offloading", False), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), base_model_state_dict=checkpoint_dict[training.MODEL_KEY], @@ -524,6 +516,18 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) + # activation checkpointing/offloading + if enable_activation_checkpointing and self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be enabled for training on CUDA" + ) + + if enable_activation_checkpointing and not enable_activation_offloading: + log.warning( + "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + self.activations_handling_ctx = contextlib.nullcontext() if enable_activation_offloading: self.activations_handling_ctx = OffloadActivations() diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 5d39b72086..525a62d850 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -152,13 +152,6 @@ def __init__(self, cfg: DictConfig) -> None: self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) - self._enable_activation_offloading = cfg.get( - "enable_activation_offloading", False - ) - if self._enable_activation_offloading and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -242,8 +235,10 @@ def setup(self, cfg: DictConfig) -> None: # set up model self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, - enable_activation_offloading=self._enable_activation_offloading, + enable_activation_checkpointing=cfg.get( + "enable_activation_checkpointing", False + ), + enable_activation_offloading=cfg.get("enable_activation_offloading", False), compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( @@ -445,6 +440,18 @@ def _setup_model( self.adapter_params.items(), dtype=self._dtype ) + # activation checkpointing/offloading + if enable_activation_checkpointing and self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be enabled for training on CUDA" + ) + + if enable_activation_checkpointing and not enable_activation_offloading: + log.warning( + "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + self.activations_handling_ctx = contextlib.nullcontext() if enable_activation_offloading: self.activations_handling_ctx = OffloadActivations() From e31a10f4bfccb6c5a7e841a670e340405db2a820 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 13:00:57 -0700 Subject: [PATCH 02/19] things work --- recipes/full_finetune_distributed.py | 58 +++++++------------- recipes/full_finetune_single_device.py | 48 +++++----------- recipes/lora_finetune_distributed.py | 50 ++++++----------- recipes/lora_finetune_single_device.py | 35 +++--------- torchtune/modules/tied_linear.py | 30 ++++++++-- torchtune/training/__init__.py | 7 ++- torchtune/training/_activation_offloading.py | 52 +++++++++++++++++- 7 files changed, 141 insertions(+), 139 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 96f3b6f9c3..578839f339 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import sys import time @@ -25,12 +24,7 @@ from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import ( - DummyProfiler, - NoOpManager, - OffloadActivations, - PROFILER_KEY, -) +from torchtune.training import DummyProfiler, PROFILER_KEY from torchtune.training.activations import apply_selective_activation_checkpointing from tqdm import tqdm @@ -244,9 +238,11 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=checkpoint_dict[training.OPT_KEY] - if self._resume_from_checkpoint - else None, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), ) # initialize loss @@ -463,33 +459,19 @@ def _is_layer_fqn(s: str) -> bool: cpu_offload=fsdp_cpu_offload, ) - # activation offloading + # activation checkpointing/offloading if enable_activation_checkpointing and self._device.type != "cuda": raise RuntimeError( "enable_activation_offloading should only be enabled for training on CUDA" ) - if enable_activation_checkpointing and not enable_activation_offloading: - log.warning( - "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) - self.activations_handling_ctx = contextlib.nullcontext() - if enable_activation_offloading: - self.activations_handling_ctx = OffloadActivations() - - # Below is our hack to disable offloading the last output Linear in every - # step, as the cost for offloading the activation and then soon after bringing - # it back is expensive. Moreover, due to heuristics in our streaming API, - # we actually use more memory if we offload it as it interferes with chunkedCE. - if hasattr(model, "output") and isinstance(model.output, nn.Module): - noop_ctx = NoOpManager() - model.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) @@ -560,13 +542,15 @@ def _setup_data( sampler=sampler, # dropping last avoids shape issues with compile + flex attention drop_last=True, - collate_fn=partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else padded_collate_packed, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), ) if self._is_rank_zero: diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index de9ae28818..c37faad036 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import sys import time from functools import partial @@ -23,12 +22,7 @@ from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import ( - DummyProfiler, - NoOpManager, - OffloadActivations, - PROFILER_KEY, -) +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -407,28 +401,14 @@ def _setup_model( raise RuntimeError( "enable_activation_offloading should only be enabled for training on CUDA" ) - if enable_activation_checkpointing and not enable_activation_offloading: - log.warning( - "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) - self.activations_handling_ctx = contextlib.nullcontext() - if enable_activation_offloading: - self.activations_handling_ctx = OffloadActivations() - - # Below is our hack to disable offloading the last output Linear in every - # step, as the cost for offloading the activation and then soon after bringing - # it back is expensive. Moreover, due to heuristics in our streaming API, - # we actually use more memory if we offload it as it interferes with chunkedCE. - if hasattr(model, "output") and isinstance(model.output, nn.Module): - noop_ctx = NoOpManager() - model.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) log.info(f"Model is initialized with precision {self._dtype}.") @@ -570,13 +550,15 @@ def _setup_data( sampler=sampler, # dropping last avoids shape issues with compile + flex attention drop_last=True, - collate_fn=partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else padded_collate_packed, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), ) log.info("Dataset and Sampler are initialized.") diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index f1eac8c9a3..2aff02bd21 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import sys import time @@ -35,12 +34,7 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import ( - DummyProfiler, - NoOpManager, - OffloadActivations, - PROFILER_KEY, -) +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -521,30 +515,16 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: raise RuntimeError( "enable_activation_offloading should only be enabled for training on CUDA" ) - if enable_activation_checkpointing and not enable_activation_offloading: - log.warning( - "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) - self.activations_handling_ctx = contextlib.nullcontext() - if enable_activation_offloading: - self.activations_handling_ctx = OffloadActivations() - - # Below is our hack to disable offloading the last output Linear in every - # step, as the cost for offloading the activation and then soon after bringing - # it back is expensive. Moreover, due to heuristics in our streaming API, - # we actually use more memory if we offload it as it interferes with chunkedCE. - if hasattr(model, "output") and isinstance(model.output, nn.Module): - noop_ctx = NoOpManager() - model.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) - + # log if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" @@ -628,13 +608,15 @@ def _setup_data( sampler=sampler, # dropping last avoids shape issues with compile + flex attention drop_last=True, - collate_fn=partial( - collate_fn, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else padded_collate_packed, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), ) if self._is_rank_zero: diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 525a62d850..6ecd79c47e 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import sys import time @@ -32,12 +31,7 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import ( - DummyProfiler, - NoOpManager, - OffloadActivations, - PROFILER_KEY, -) +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -445,29 +439,14 @@ def _setup_model( raise RuntimeError( "enable_activation_offloading should only be enabled for training on CUDA" ) - if enable_activation_checkpointing and not enable_activation_offloading: - log.warning( - "enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) - - self.activations_handling_ctx = contextlib.nullcontext() - if enable_activation_offloading: - self.activations_handling_ctx = OffloadActivations() - - # Below is our hack to disable offloading the last output Linear in every - # step, as the cost for offloading the activation and then soon after bringing - # it back is expensive. Moreover, due to heuristics in our streaming API, - # we actually use more memory if we offload it as it interferes with chunkedCE. - if hasattr(model, "output") and isinstance(model.output, nn.Module): - noop_ctx = NoOpManager() - model.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) log.info(f"Model is initialized with precision {self._dtype}.") @@ -758,7 +737,7 @@ def train(self) -> None: self.epochs_run += 1 start_save_checkpoint = time.perf_counter() log.info("Starting checkpoint save...") - self.save_checkpoint(epoch=curr_epoch) + # self.save_checkpoint(epoch=curr_epoch) log.info( "Checkpoint saved in {:.2f} seconds.".format( time.perf_counter() - start_save_checkpoint diff --git a/torchtune/modules/tied_linear.py b/torchtune/modules/tied_linear.py index 718abd5c67..1e8368072e 100644 --- a/torchtune/modules/tied_linear.py +++ b/torchtune/modules/tied_linear.py @@ -4,18 +4,37 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch import torch.nn as nn import torch.nn.functional as F +class Linear(nn.Module): + """ + nn.Module used in :func:`~torchtune.modules.tied_linear.TiedLinear`, added to work with the hooks + :class:`~torchtune.training._activation_offloading.NoOpManager` that ignore activation + offloading context manager. + + Without this class, we can't add NoOp hooks, and we will offload the activation of + the tied linear layer, which is slow. + + For more information, see how NoOpManager is called in the recipes. + """ + + def forward(self, x: torch.Tensor, weight: torch.Tensor): + return F.linear(x, weight) + + class TiedLinear: """ A tied linear layer, without bias, that shares the same weight as another linear layer. This is useful for models that use tied weights, such as :func:`~torchtune.models.qwen2_0_5b`, - :func:`~torchtune.models.qwen2_1_5b` and all of the :func:`~torchtune.models.gemma` models. + :func:`~torchtune.models.qwen2_1_5b` and all of the :func:`~torchtune.models.gemma` and + :func:`~torchtune.models.llama3_2` models. + It requires as input an nn.Module, instead of the weight of the module, so it - can work with FSDP. Otherwise, the memory reference will be lost after FSDP is applied. + can work with FSDP. When FSDP is applied, the memory pointer to the weight is different, + but the nn.Module remains the same. This is why we need to pass the nn.Module instead of + the weight, if we want to keep the weights tied. Args: tied_module (nn.Module): The module whose weight is shared. Only @@ -26,12 +45,13 @@ class TiedLinear: def __init__(self, tied_module: nn.Module): self.tied_module = tied_module + self.linear = Linear() if not hasattr(tied_module, "weight"): raise AttributeError( "Provided module does not have attribute 'weight'. Please check your tied_module." ) - def __call__(self, x: torch.Tensor) -> torch.Tensor: + def __call__(self, x: torch.Tensor): """ Args: x (torch.Tensor): Input tensor. Should have shape ``(..., in_dim)``, where ``in_dim`` @@ -40,4 +60,4 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: The output tensor, having shape ``(..., out_dim)``, where ``out_dim`` is \ the output dimension of the tied module. """ - return F.linear(x, self.tied_module.weight) + return self.linear(x, self.tied_module.weight) diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 9e33f40067..d9d9377cfb 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -3,7 +3,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.training._activation_offloading import NoOpManager, OffloadActivations +from torchtune.training._activation_offloading import ( + get_act_offloading_ctx_manager, + NoOpManager, + OffloadActivations, +) from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( contains_fsdp, @@ -70,6 +74,7 @@ from torchtune.training.seed import set_seed __all__ = [ + "get_act_offloading_ctx_manager", "apply_selective_activation_checkpointing", "get_dtype", "set_default_dtype", diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index 5156281aa8..ef384691a3 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -4,15 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +import contextlib +from typing import Optional, Union from warnings import warn import psutil import torch import torchao +from torch import nn from torch.autograd.graph import saved_tensors_hooks from torchao.dtypes.nf4tensor import NF4Tensor +from torchtune.modules import TiedLinear + class OffloadActivations(saved_tensors_hooks): """Context manager under which activation tensors created in the forward pass will be offloaded. @@ -327,3 +331,49 @@ def noop(tensor): return tensor super().__init__(noop, noop) + + +def get_act_offloading_ctx_manager( + model: nn.Module, enable_activation_offloading: bool +) -> Union[OffloadActivations, contextlib.nullcontext]: + """Returns the activation offloading context manager for the model, which will be + a null context if enable_activation_offloading is False. + + If activation offloading is enabled, we return the OffloadActivations context manager. + If activation offloading is disabled, we return a NoOpManager context manager. + + Args: + model (nn.Module): the model to wrap with the activation offloading context manager. + enable_activation_offloading (bool): whether or not to enable activation offloading + for the model. + + Returns: + contextlib.ContextDecorator: the activation offloading context manager for the model. + """ + if enable_activation_offloading: + activations_handling_ctx = OffloadActivations() + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output"): + noop_ctx = NoOpManager() + if isinstance(model.output, nn.Module): + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + elif isinstance(model.output, TiedLinear): + model.output.linear.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.linear.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + else: + activations_handling_ctx = contextlib.nullcontext() + + return activations_handling_ctx From 9005c0d4cc340286927192a894fe381d9670b9b7 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 13:26:15 -0700 Subject: [PATCH 03/19] fix logging and typing --- recipes/full_finetune_distributed.py | 17 ++++++++++++----- recipes/full_finetune_single_device.py | 17 ++++++++++++----- recipes/lora_finetune_distributed.py | 17 ++++++++++++----- recipes/lora_finetune_single_device.py | 17 ++++++++++++----- torchtune/modules/tied_linear.py | 1 + 5 files changed, 49 insertions(+), 20 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 578839f339..b98747bf7b 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -109,6 +109,7 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``left_pad_sequence`` is set as the data collator. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -460,15 +461,21 @@ def _is_layer_fqn(s: str) -> bool: ) # activation checkpointing/offloading - if enable_activation_checkpointing and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - if enable_activation_checkpointing and not enable_activation_offloading: + if enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True for training on CUDA" + ) + if not enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif enable_activation_checkpointing: log.info( "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index c37faad036..f5227e05f9 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -112,6 +112,7 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`. RuntimeError: If ``left_pad_sequence`` is set as the data collator. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -397,15 +398,21 @@ def _setup_model( ) # activation checkpointing/offloading - if enable_activation_checkpointing and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - if enable_activation_checkpointing and not enable_activation_offloading: + if enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif enable_activation_checkpointing: log.info( "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 2aff02bd21..bf7dcb33d8 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -123,6 +123,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``left_pad_sequence`` is set as the data collator. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -511,15 +512,21 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: training.validate_no_params_on_meta_device(model) # activation checkpointing/offloading - if enable_activation_checkpointing and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - if enable_activation_checkpointing and not enable_activation_offloading: + if enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True for training on CUDA" + ) + if not enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif enable_activation_checkpointing: log.info( "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 6ecd79c47e..201c5aa7fb 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -114,6 +114,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. RuntimeError: If ``left_pad_sequence`` is set as the data collator """ @@ -435,15 +436,21 @@ def _setup_model( ) # activation checkpointing/offloading - if enable_activation_checkpointing and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - if enable_activation_checkpointing and not enable_activation_offloading: + if enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True for training on CUDA" + ) + if not enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif enable_activation_checkpointing: log.info( "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " "Enabling activation offloading should reduce memory further." ) + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/torchtune/modules/tied_linear.py b/torchtune/modules/tied_linear.py index 1e8368072e..5cc1cc8d0c 100644 --- a/torchtune/modules/tied_linear.py +++ b/torchtune/modules/tied_linear.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch import torch.nn as nn import torch.nn.functional as F From b523cdcc5a0b3ba66068d5f1eaa2c47fbd44bcca Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 14:28:06 -0700 Subject: [PATCH 04/19] moved raise error to init --- recipes/full_finetune_distributed.py | 45 +++++++++++++----------- recipes/full_finetune_single_device.py | 46 +++++++++++++------------ recipes/lora_finetune_distributed.py | 45 +++++++++++++----------- recipes/lora_finetune_single_device.py | 47 ++++++++++++++------------ 4 files changed, 101 insertions(+), 82 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index b98747bf7b..cbe21cdae8 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -144,6 +144,28 @@ def __init__(self, cfg: DictConfig) -> None: self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -224,10 +246,8 @@ def setup(self, cfg: DictConfig) -> None: self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.get( - "enable_activation_checkpointing", False - ), - enable_activation_offloading=cfg.get("enable_activation_offloading", False), + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -460,22 +480,7 @@ def _is_layer_fqn(s: str) -> bool: cpu_offload=fsdp_cpu_offload, ) - # activation checkpointing/offloading - if enable_activation_offloading: - if self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be True for training on CUDA" - ) - if not enable_activation_checkpointing: - raise RuntimeError( - "enable_activation_offloading should only be True when enable_activation_checkpointing is True" - ) - elif enable_activation_checkpointing: - log.info( - "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." - ) - + # activation offloading self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index f5227e05f9..4a9f07838c 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -135,6 +135,28 @@ def __init__(self, cfg: DictConfig) -> None: self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._optimizer_in_bwd = cfg.optimizer_in_bwd + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + # TODO: find a better place / way to perform validation of args that don't yet # compose with each other. if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: @@ -222,14 +244,11 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model - self._compile = cfg.compile self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.get( - "enable_activation_checkpointing", False - ), - enable_activation_offloading=cfg.get("enable_activation_offloading", False), + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], ) @@ -397,22 +416,7 @@ def _setup_model( model.named_parameters(), dtype=self._dtype ) - # activation checkpointing/offloading - if enable_activation_offloading: - if self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be True when training on CUDA" - ) - if not enable_activation_checkpointing: - raise RuntimeError( - "enable_activation_offloading should only be True when enable_activation_checkpointing is True" - ) - elif enable_activation_checkpointing: - log.info( - "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." - ) - + # Enable activation offloading self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index bf7dcb33d8..78bc2e6f5f 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -159,6 +159,28 @@ def __init__(self, cfg: DictConfig) -> None: self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. This includes the @@ -240,10 +262,8 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.get( - "enable_activation_checkpointing", False - ), - enable_activation_offloading=cfg.get("enable_activation_offloading", False), + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), base_model_state_dict=checkpoint_dict[training.MODEL_KEY], @@ -511,22 +531,7 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) - # activation checkpointing/offloading - if enable_activation_offloading: - if self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be True for training on CUDA" - ) - if not enable_activation_checkpointing: - raise RuntimeError( - "enable_activation_offloading should only be True when enable_activation_checkpointing is True" - ) - elif enable_activation_checkpointing: - log.info( - "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." - ) - + # activation offloading self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 201c5aa7fb..3e23a783bb 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -148,6 +148,28 @@ def __init__(self, cfg: DictConfig) -> None: self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. This includes the @@ -230,10 +252,8 @@ def setup(self, cfg: DictConfig) -> None: # set up model self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.get( - "enable_activation_checkpointing", False - ), - enable_activation_offloading=cfg.get("enable_activation_offloading", False), + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( @@ -435,22 +455,7 @@ def _setup_model( self.adapter_params.items(), dtype=self._dtype ) - # activation checkpointing/offloading - if enable_activation_offloading: - if self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be True for training on CUDA" - ) - if not enable_activation_checkpointing: - raise RuntimeError( - "enable_activation_offloading should only be True when enable_activation_checkpointing is True" - ) - elif enable_activation_checkpointing: - log.info( - "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." - ) - + # activation offloading self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) @@ -744,7 +749,7 @@ def train(self) -> None: self.epochs_run += 1 start_save_checkpoint = time.perf_counter() log.info("Starting checkpoint save...") - # self.save_checkpoint(epoch=curr_epoch) + self.save_checkpoint(epoch=curr_epoch) log.info( "Checkpoint saved in {:.2f} seconds.".format( time.perf_counter() - start_save_checkpoint From 3e02a45b1b985c7be0c3e5de20652a1ebfd75b0b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 15:31:18 -0700 Subject: [PATCH 05/19] add qat --- recipes/qat_distributed.py | 57 ++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index eb2e44fae2..721f92cedb 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -141,6 +141,28 @@ def __init__(self, cfg: DictConfig) -> None: self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None) self._quantizer_mode = None + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -220,7 +242,8 @@ def setup(self, cfg: DictConfig) -> None: self._model_compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -233,9 +256,11 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=checkpoint_dict[training.OPT_KEY] - if self._resume_from_checkpoint - else None, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), ) # initialize loss @@ -363,6 +388,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, reshard_after_forward: bool, @@ -465,6 +491,11 @@ def _is_layer_fqn(s: str) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" @@ -525,14 +556,16 @@ def _setup_data( sampler=sampler, # dropping last avoids shape issues with compile + flex attention drop_last=True, - collate_fn=partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else partial( - padded_collate_packed, + collate_fn=( + partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, + ) ), ) From 6ae18ede2a342bfd7a4d223348172d2c66747a99 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 15:31:50 -0700 Subject: [PATCH 06/19] update configs --- recipes/configs/dev/8B_full_experimental.yaml | 1 + recipes/configs/gemma/2B_full.yaml | 1 + recipes/configs/gemma/2B_lora.yaml | 1 + recipes/configs/gemma/7B_full.yaml | 1 + recipes/configs/gemma/7B_lora.yaml | 1 + recipes/configs/llama2/13B_full.yaml | 1 + recipes/configs/llama2/13B_lora.yaml | 1 + recipes/configs/llama2/70B_lora.yaml | 1 + recipes/configs/llama2/70B_qlora.yaml | 1 + recipes/configs/llama2/7B_full.yaml | 1 + recipes/configs/llama2/7B_full_low_memory.yaml | 1 + recipes/configs/llama2/7B_lora.yaml | 1 + recipes/configs/llama2/7B_qat_full.yaml | 1 + recipes/configs/llama2/7B_qlora.yaml | 1 + recipes/configs/llama3/70B_full.yaml | 1 + recipes/configs/llama3/70B_lora.yaml | 1 + recipes/configs/llama3/8B_dora.yaml | 1 + recipes/configs/llama3/8B_dora_single_device.yaml | 1 + recipes/configs/llama3/8B_full.yaml | 1 + recipes/configs/llama3/8B_full_single_device.yaml | 1 + recipes/configs/llama3/8B_lora.yaml | 1 + recipes/configs/llama3/8B_qat_full.yaml | 1 + recipes/configs/llama3/8B_qdora_single_device.yaml | 1 + recipes/configs/llama3_1/405B_qlora.yaml | 1 + recipes/configs/llama3_1/70B_full.yaml | 1 + recipes/configs/llama3_1/70B_lora.yaml | 1 + recipes/configs/llama3_1/8B_full.yaml | 1 + recipes/configs/llama3_1/8B_full_single_device.yaml | 1 + recipes/configs/llama3_1/8B_lora.yaml | 1 + recipes/configs/llama3_2/1B_full.yaml | 1 + recipes/configs/llama3_2/1B_full_single_device.yaml | 1 + recipes/configs/llama3_2/1B_lora.yaml | 1 + recipes/configs/llama3_2/3B_full.yaml | 1 + recipes/configs/llama3_2/3B_full_single_device.yaml | 1 + recipes/configs/llama3_2/3B_lora.yaml | 1 + recipes/configs/llama3_2_vision/11B_full.yaml | 1 + recipes/configs/llama3_2_vision/11B_full_single_device.yaml | 1 + recipes/configs/mistral/7B_full.yaml | 1 + recipes/configs/mistral/7B_full_low_memory.yaml | 1 + recipes/configs/mistral/7B_lora.yaml | 1 + recipes/configs/phi3/mini_full.yaml | 1 + recipes/configs/phi3/mini_full_low_memory.yaml | 1 + recipes/configs/phi3/mini_lora.yaml | 1 + recipes/configs/qwen2/0.5B_full.yaml | 1 + recipes/configs/qwen2/0.5B_full_single_device.yaml | 1 + recipes/configs/qwen2/0.5B_lora.yaml | 1 + recipes/configs/qwen2/1.5B_full.yaml | 1 + recipes/configs/qwen2/1.5B_full_single_device.yaml | 1 + recipes/configs/qwen2/1.5B_lora.yaml | 1 + recipes/configs/qwen2/7B_full.yaml | 1 + recipes/configs/qwen2/7B_full_single_device.yaml | 1 + recipes/configs/qwen2/7B_lora.yaml | 1 + 52 files changed, 52 insertions(+) diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml index 4ed8a80e09..03960e85b7 100644 --- a/recipes/configs/dev/8B_full_experimental.yaml +++ b/recipes/configs/dev/8B_full_experimental.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False ac_mode: 'selective' # ['selective', 'full'] ac_option: 2 # [int] = ac every positive int layer memory_efficient_fsdp_wrap: False diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index e1bd3272d0..0278b37658 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -60,6 +60,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 5364ec2bce..28e20e76f0 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -72,6 +72,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index a8924836fe..fa39b5e529 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -62,6 +62,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index a4ee960c17..af90aeb7f8 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -74,6 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index f5ecffc2ab..6d3c3468da 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index 267725ab92..7b3b834d4a 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -87,3 +87,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index ff4f56493b..00ccabd98e 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -87,3 +87,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml index b8ff55c01b..d7348f2dc1 100644 --- a/recipes/configs/llama2/70B_qlora.yaml +++ b/recipes/configs/llama2/70B_qlora.yaml @@ -97,3 +97,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index 2e80276c84..29d05242d9 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index beb2248b23..a13c64d00d 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 68e1d302df..5115c1e0f1 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -84,6 +84,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml index 6fca6c4d4a..e0a967d8c8 100644 --- a/recipes/configs/llama2/7B_qat_full.yaml +++ b/recipes/configs/llama2/7B_qat_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False memory_efficient_fsdp_wrap: False # Reduced precision diff --git a/recipes/configs/llama2/7B_qlora.yaml b/recipes/configs/llama2/7B_qlora.yaml index 630d1f6357..3fc6fb215a 100644 --- a/recipes/configs/llama2/7B_qlora.yaml +++ b/recipes/configs/llama2/7B_qlora.yaml @@ -88,3 +88,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index a8b7ba619c..0c8baaf8a2 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -97,6 +97,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True compile: False # set it to True for better memory and performance diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index 84bed19a02..aaeec323f3 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -103,3 +103,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False diff --git a/recipes/configs/llama3/8B_dora.yaml b/recipes/configs/llama3/8B_dora.yaml index 3911e856c2..a8891b9fdb 100644 --- a/recipes/configs/llama3/8B_dora.yaml +++ b/recipes/configs/llama3/8B_dora.yaml @@ -77,3 +77,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml index 1f91dadda8..59443921b8 100644 --- a/recipes/configs/llama3/8B_dora_single_device.yaml +++ b/recipes/configs/llama3/8B_dora_single_device.yaml @@ -80,6 +80,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 7f24376db7..2ed4e7a351 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False custom_sharded_layers: ['tok_embeddings', 'output'] # Reduced precision diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index 1d5479ccbc..69b2198639 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 5c3510f466..6027c7f11b 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -82,3 +82,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml index ff4d9c3195..0f8bbf7ff4 100644 --- a/recipes/configs/llama3/8B_qat_full.yaml +++ b/recipes/configs/llama3/8B_qat_full.yaml @@ -63,6 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False memory_efficient_fsdp_wrap: True # Reduced precision diff --git a/recipes/configs/llama3/8B_qdora_single_device.yaml b/recipes/configs/llama3/8B_qdora_single_device.yaml index 29a2a2d84f..0c80161200 100644 --- a/recipes/configs/llama3/8B_qdora_single_device.yaml +++ b/recipes/configs/llama3/8B_qdora_single_device.yaml @@ -81,6 +81,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml index 69583dd9d4..395cc931fd 100644 --- a/recipes/configs/llama3_1/405B_qlora.yaml +++ b/recipes/configs/llama3_1/405B_qlora.yaml @@ -85,3 +85,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index fcae062999..f88ab85789 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -99,6 +99,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True compile: False # set it to True for better memory and performance diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index c4fa8d589c..689ccf56b3 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -102,3 +102,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index 4420b0cae5..cd088386f5 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -67,6 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False custom_sharded_layers: ['tok_embeddings', 'output'] compile: False # set it to True for better memory and performance diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 9f7d9472ce..9baf5cbb93 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index c6e94e0aab..f622a8a27e 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -86,3 +86,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml index 23b699f754..1d56a1fb95 100644 --- a/recipes/configs/llama3_2/1B_full.yaml +++ b/recipes/configs/llama3_2/1B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False compile: False # set it to True for better memory and performance # Reduced precision diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml index fc4b0a507c..e208cdb004 100644 --- a/recipes/configs/llama3_2/1B_full_single_device.yaml +++ b/recipes/configs/llama3_2/1B_full_single_device.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_2/1B_lora.yaml b/recipes/configs/llama3_2/1B_lora.yaml index 1fb0f483b3..630c4110f3 100644 --- a/recipes/configs/llama3_2/1B_lora.yaml +++ b/recipes/configs/llama3_2/1B_lora.yaml @@ -83,3 +83,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 6d738331ae..ac6a728b1d 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False compile: False # set it to True for better memory and performance # Reduced precision diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml index 9b21f4f865..fdea2e622f 100644 --- a/recipes/configs/llama3_2/3B_full_single_device.yaml +++ b/recipes/configs/llama3_2/3B_full_single_device.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index 9a628f2c29..ce6af86487 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -84,3 +84,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml index 2c8f1f58fd..f03299f157 100644 --- a/recipes/configs/llama3_2_vision/11B_full.yaml +++ b/recipes/configs/llama3_2_vision/11B_full.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False custom_sharded_layers: ['tok_embeddings', 'output'] dtype: bf16 diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml index d42fb971e6..9ba4a8bf74 100644 --- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index 602b3fe082..6a34aa4e9c 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index 7e68ee8066..0f34b1dd16 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index fd2c637df7..8165e3b692 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -80,6 +80,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index 0ee746ddd4..b7f7afe785 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -63,6 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index 182a4f6a98..f097c3ba96 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index 721a61790b..0d5ce3b278 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -74,6 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 5bf14591f9..4520ac932c 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -63,6 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 67091a4e8a..08daf3ab0f 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index 9ccd400897..6dd856de11 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -84,6 +84,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index cb7b5e2318..82819f06fe 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -63,6 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index 5da79ceb69..9c21bd4d89 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 84fd73696b..3345e364b2 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -79,6 +79,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index 7ffc07e457..67550203f4 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index 560dd5fc9f..e29aeda677 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index f6a4cc2ac6..5bf2c97cb0 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -85,6 +85,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training From 69a506e42b913035ac5387a075e8eadf3cbad8f7 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 16:13:23 -0700 Subject: [PATCH 07/19] remove offloading from qat --- .../code_llama2/7B_full_low_memory.yaml | 1 + recipes/configs/llama2/7B_qat_full.yaml | 1 - recipes/configs/llama3/8B_qat_full.yaml | 1 - recipes/qat_distributed.py | 31 +------------------ 4 files changed, 2 insertions(+), 32 deletions(-) diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index 6bca6c378f..0b74792ab9 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -67,6 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True dtype: bf16 # Logging diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml index e0a967d8c8..6fca6c4d4a 100644 --- a/recipes/configs/llama2/7B_qat_full.yaml +++ b/recipes/configs/llama2/7B_qat_full.yaml @@ -64,7 +64,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False memory_efficient_fsdp_wrap: False # Reduced precision diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml index 0f8bbf7ff4..ff4d9c3195 100644 --- a/recipes/configs/llama3/8B_qat_full.yaml +++ b/recipes/configs/llama3/8B_qat_full.yaml @@ -63,7 +63,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False memory_efficient_fsdp_wrap: True # Reduced precision diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index 721f92cedb..1613257b08 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -141,28 +141,6 @@ def __init__(self, cfg: DictConfig) -> None: self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None) self._quantizer_mode = None - # activation checkpointing/offloading - self._enable_activation_checkpointing = cfg.get( - "enable_activation_checkpointing", False - ) - self._enable_activation_offloading = cfg.get( - "enable_activation_offloading", False - ) - if self._enable_activation_offloading: - if self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be True when training on CUDA" - ) - if not self._enable_activation_checkpointing: - raise RuntimeError( - "enable_activation_offloading should only be True when enable_activation_checkpointing is True" - ) - elif self._enable_activation_checkpointing: - log.info( - "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." - ) - # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -242,8 +220,7 @@ def setup(self, cfg: DictConfig) -> None: self._model_compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=self._enable_activation_checkpointing, - enable_activation_offloading=self._enable_activation_offloading, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -388,7 +365,6 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, - enable_activation_offloading: bool, custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, reshard_after_forward: bool, @@ -491,11 +467,6 @@ def _is_layer_fqn(s: str) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) - # activation offloading - self.activations_handling_ctx = training.get_act_offloading_ctx_manager( - model, enable_activation_offloading - ) - if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" From cc67901496434b30473237b8b757fce96899ff63 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 16:20:38 -0700 Subject: [PATCH 08/19] update unit tests --- tests/recipes/test_full_finetune_distributed.py | 1 + tests/recipes/test_full_finetune_single_device.py | 2 ++ tests/recipes/test_knowledge_distillation_single_device.py | 1 + tests/recipes/test_lora_dpo_single_device.py | 3 +++ tests/recipes/test_lora_finetune_distributed.py | 4 ++++ tests/recipes/test_lora_finetune_single_device.py | 4 ++++ tests/recipes/test_ppo_full_finetune_single_device.py | 1 + tests/recipes/test_qat_distributed.py | 1 + 8 files changed, 17 insertions(+) diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 99f56a19b8..4f7e49086f 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -34,6 +34,7 @@ def _get_test_config_overrides(self): "batch_size=4", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", "epochs=2", diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 170e4008d3..36a9272903 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -39,6 +39,7 @@ def _get_test_config_overrides(self): "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", "epochs=2", @@ -187,6 +188,7 @@ def _get_test_config_overrides(self): "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloadingg=False", "tokenizer.path=/tmp/test-artifacts/tokenizer.model", "tokenizer.prompt_template=null", "dataset=tests.recipes.utils.DummyDataset", diff --git a/tests/recipes/test_knowledge_distillation_single_device.py b/tests/recipes/test_knowledge_distillation_single_device.py index e389460b71..81f5a1e80c 100644 --- a/tests/recipes/test_knowledge_distillation_single_device.py +++ b/tests/recipes/test_knowledge_distillation_single_device.py @@ -35,6 +35,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): "device=cpu", f"dtype={dtype_str}", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", f"epochs={epochs}", diff --git a/tests/recipes/test_lora_dpo_single_device.py b/tests/recipes/test_lora_dpo_single_device.py index d8cdca76c2..79b1e4e767 100644 --- a/tests/recipes/test_lora_dpo_single_device.py +++ b/tests/recipes/test_lora_dpo_single_device.py @@ -83,6 +83,7 @@ def test_training_state_on_resume( save_adapter_weights_only={save_adapter_weights_only} \ metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -113,6 +114,7 @@ def test_training_state_on_resume( tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config monkeypatch.setattr(sys, "argv", cmd_2) @@ -144,6 +146,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 7777b02862..7c3480c7cb 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -81,6 +81,7 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): tokenizer.prompt_template=null \ reshard_after_forward={reshard_after_forward} \ enable_activation_checkpointing=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -148,6 +149,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] @@ -173,6 +175,7 @@ def test_training_state_on_resume( resume_from_checkpoint=True \ metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config @@ -216,6 +219,7 @@ def test_save_and_load_merged_weights( tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index f2d7409042..f2490788a9 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -133,6 +133,7 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): tokenizer.prompt_template=null \ compile={compile} \ enable_activation_checkpointing=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_qlora"] @@ -189,6 +190,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -215,6 +217,7 @@ def test_training_state_on_resume( tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config monkeypatch.setattr(sys, "argv", cmd_2) @@ -247,6 +250,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] diff --git a/tests/recipes/test_ppo_full_finetune_single_device.py b/tests/recipes/test_ppo_full_finetune_single_device.py index 63a1e68dcd..d40645acf6 100644 --- a/tests/recipes/test_ppo_full_finetune_single_device.py +++ b/tests/recipes/test_ppo_full_finetune_single_device.py @@ -41,6 +41,7 @@ def _get_test_config_overrides(self): "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "tokenizer.path=/tmp/test-artifacts/tokenizer.model", "tokenizer._component_=torchtune.models.llama2.llama2_tokenizer", "tokenizer.prompt_template=null", diff --git a/tests/recipes/test_qat_distributed.py b/tests/recipes/test_qat_distributed.py index 5d4d7069f1..d614afadbe 100644 --- a/tests/recipes/test_qat_distributed.py +++ b/tests/recipes/test_qat_distributed.py @@ -35,6 +35,7 @@ def _get_test_config_overrides(self): "batch_size=4", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", "epochs=2", From 387fc970d9a5a7b736d6cb4c02635a3ec93b393c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 16:25:34 -0700 Subject: [PATCH 09/19] typos --- tests/recipes/test_full_finetune_single_device.py | 2 +- tests/recipes/test_lora_dpo_single_device.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 36a9272903..65f7b5bc75 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -188,7 +188,7 @@ def _get_test_config_overrides(self): "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", - "enable_activation_offloadingg=False", + "enable_activation_offloading=False", "tokenizer.path=/tmp/test-artifacts/tokenizer.model", "tokenizer.prompt_template=null", "dataset=tests.recipes.utils.DummyDataset", diff --git a/tests/recipes/test_lora_dpo_single_device.py b/tests/recipes/test_lora_dpo_single_device.py index 79b1e4e767..703ac2e471 100644 --- a/tests/recipes/test_lora_dpo_single_device.py +++ b/tests/recipes/test_lora_dpo_single_device.py @@ -83,7 +83,7 @@ def test_training_state_on_resume( save_adapter_weights_only={save_adapter_weights_only} \ metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ - enable_activation_offloading=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] From cbc58d8d234635f4396f24b6e58286306f7c443b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 15 Oct 2024 16:30:46 -0700 Subject: [PATCH 10/19] update configs --- recipes/configs/code_llama2/7B_full_low_memory.yaml | 2 +- recipes/configs/code_llama2/7B_lora_single_device.yaml | 2 +- recipes/configs/code_llama2/7B_qlora_single_device.yaml | 2 +- recipes/configs/dev/8B_full_experimental.yaml | 2 +- recipes/configs/gemma/2B_full.yaml | 2 +- recipes/configs/gemma/2B_lora.yaml | 2 +- recipes/configs/gemma/2B_lora_single_device.yaml | 2 +- recipes/configs/gemma/2B_qlora_single_device.yaml | 2 +- recipes/configs/gemma/7B_full.yaml | 2 +- recipes/configs/gemma/7B_lora.yaml | 2 +- recipes/configs/gemma/7B_lora_single_device.yaml | 2 +- recipes/configs/gemma/7B_qlora_single_device.yaml | 2 +- recipes/configs/llama2/13B_full.yaml | 2 +- recipes/configs/llama2/13B_lora.yaml | 2 +- recipes/configs/llama2/13B_qlora_single_device.yaml | 2 +- recipes/configs/llama2/70B_lora.yaml | 2 +- recipes/configs/llama2/70B_qlora.yaml | 2 +- recipes/configs/llama2/7B_full.yaml | 2 +- recipes/configs/llama2/7B_full_low_memory.yaml | 2 +- recipes/configs/llama2/7B_lora.yaml | 2 +- recipes/configs/llama2/7B_lora_single_device.yaml | 2 +- recipes/configs/llama2/7B_qlora.yaml | 2 +- recipes/configs/llama2/7B_qlora_single_device.yaml | 2 +- recipes/configs/llama3/70B_full.yaml | 2 +- recipes/configs/llama3/70B_lora.yaml | 2 +- recipes/configs/llama3/8B_dora.yaml | 2 +- recipes/configs/llama3/8B_dora_single_device.yaml | 2 +- recipes/configs/llama3/8B_full.yaml | 2 +- recipes/configs/llama3/8B_full_single_device.yaml | 2 +- recipes/configs/llama3/8B_lora.yaml | 2 +- recipes/configs/llama3/8B_lora_single_device.yaml | 2 +- recipes/configs/llama3/8B_qdora_single_device.yaml | 2 +- recipes/configs/llama3/8B_qlora_single_device.yaml | 2 +- recipes/configs/llama3_1/405B_qlora.yaml | 2 +- recipes/configs/llama3_1/70B_full.yaml | 2 +- recipes/configs/llama3_1/70B_lora.yaml | 2 +- recipes/configs/llama3_1/8B_full.yaml | 2 +- recipes/configs/llama3_1/8B_full_single_device.yaml | 2 +- recipes/configs/llama3_1/8B_lora.yaml | 2 +- recipes/configs/llama3_1/8B_lora_single_device.yaml | 2 +- recipes/configs/llama3_1/8B_qlora_single_device.yaml | 2 +- recipes/configs/llama3_2/1B_full.yaml | 2 +- recipes/configs/llama3_2/1B_full_single_device.yaml | 2 +- recipes/configs/llama3_2/1B_lora.yaml | 2 +- recipes/configs/llama3_2/1B_lora_single_device.yaml | 2 +- recipes/configs/llama3_2/1B_qlora_single_device.yaml | 2 +- recipes/configs/llama3_2/3B_full.yaml | 2 +- recipes/configs/llama3_2/3B_full_single_device.yaml | 2 +- recipes/configs/llama3_2/3B_lora.yaml | 2 +- recipes/configs/llama3_2/3B_lora_single_device.yaml | 2 +- recipes/configs/llama3_2/3B_qlora_single_device.yaml | 2 +- .../configs/llama3_2/knowledge_distillation_single_device.yaml | 1 - recipes/configs/llama3_2_vision/11B_full.yaml | 2 +- recipes/configs/llama3_2_vision/11B_full_single_device.yaml | 2 +- recipes/configs/llama3_2_vision/11B_lora.yaml | 2 +- recipes/configs/llama3_2_vision/11B_lora_single_device.yaml | 2 +- recipes/configs/mistral/7B_full.yaml | 2 +- recipes/configs/mistral/7B_full_low_memory.yaml | 2 +- recipes/configs/mistral/7B_lora.yaml | 2 +- recipes/configs/mistral/7B_lora_single_device.yaml | 2 +- recipes/configs/mistral/7B_qlora_single_device.yaml | 2 +- recipes/configs/phi3/mini_full.yaml | 2 +- recipes/configs/phi3/mini_full_low_memory.yaml | 2 +- recipes/configs/phi3/mini_lora.yaml | 2 +- recipes/configs/phi3/mini_lora_single_device.yaml | 2 +- recipes/configs/phi3/mini_qlora_single_device.yaml | 2 +- recipes/configs/qwen2/0.5B_full.yaml | 2 +- recipes/configs/qwen2/0.5B_full_single_device.yaml | 2 +- recipes/configs/qwen2/0.5B_lora.yaml | 2 +- recipes/configs/qwen2/0.5B_lora_single_device.yaml | 2 +- recipes/configs/qwen2/1.5B_full.yaml | 2 +- recipes/configs/qwen2/1.5B_full_single_device.yaml | 2 +- recipes/configs/qwen2/1.5B_lora.yaml | 2 +- recipes/configs/qwen2/1.5B_lora_single_device.yaml | 2 +- recipes/configs/qwen2/7B_full.yaml | 2 +- recipes/configs/qwen2/7B_full_single_device.yaml | 2 +- recipes/configs/qwen2/7B_lora.yaml | 2 +- recipes/configs/qwen2/7B_lora_single_device.yaml | 2 +- 78 files changed, 77 insertions(+), 78 deletions(-) diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index 0b74792ab9..6118b5be01 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -67,7 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: True +enable_activation_offloading: True # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 75daa2b454..ab4322728b 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -75,7 +75,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index ab6b4e2b55..a6e12d8c8b 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -75,7 +75,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml index 03960e85b7..b8edc9403a 100644 --- a/recipes/configs/dev/8B_full_experimental.yaml +++ b/recipes/configs/dev/8B_full_experimental.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory ac_mode: 'selective' # ['selective', 'full'] ac_option: 2 # [int] = ac every positive int layer memory_efficient_fsdp_wrap: False diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index 0278b37658..24112446ae 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -60,7 +60,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 28e20e76f0..daa0ca2cd9 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -72,7 +72,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index 786b0c7f2f..140f750174 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -72,7 +72,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 39ebc088e7..f6f3ad00c0 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -72,7 +72,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index fa39b5e529..da32f3d2bc 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -62,7 +62,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index af90aeb7f8..f9f1e473b3 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -74,7 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index 2edeab2047..b707603319 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -74,7 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 23d7465770..f90931eb01 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -74,7 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index 6d3c3468da..8118a39efc 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index 7b3b834d4a..d5977a16e9 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -87,4 +87,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 539d692382..993d65d392 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -84,7 +84,7 @@ device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 00ccabd98e..ba0ca3dfa4 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -87,4 +87,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml index d7348f2dc1..b7a3f27e65 100644 --- a/recipes/configs/llama2/70B_qlora.yaml +++ b/recipes/configs/llama2/70B_qlora.yaml @@ -97,4 +97,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index 29d05242d9..c92e986fd1 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index a13c64d00d..c598be26ef 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -69,7 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: True +enable_activation_offloading: True # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 5115c1e0f1..14866c3c87 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -84,7 +84,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index 6608bdc48d..13c4880754 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_qlora.yaml b/recipes/configs/llama2/7B_qlora.yaml index 3fc6fb215a..980efee782 100644 --- a/recipes/configs/llama2/7B_qlora.yaml +++ b/recipes/configs/llama2/7B_qlora.yaml @@ -88,4 +88,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index 062e66d833..1049bb2cad 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 0c8baaf8a2..48653abdd1 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -97,7 +97,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True compile: False # set it to True for better memory and performance diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index aaeec323f3..5fd7563a9b 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -103,4 +103,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/8B_dora.yaml b/recipes/configs/llama3/8B_dora.yaml index a8891b9fdb..50cbc923f1 100644 --- a/recipes/configs/llama3/8B_dora.yaml +++ b/recipes/configs/llama3/8B_dora.yaml @@ -77,4 +77,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml index 59443921b8..233f7f4a9e 100644 --- a/recipes/configs/llama3/8B_dora_single_device.yaml +++ b/recipes/configs/llama3/8B_dora_single_device.yaml @@ -80,7 +80,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 2ed4e7a351..1c3f2a4b62 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] # Reduced precision diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index 69b2198639..997b9617e4 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -68,7 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 6027c7f11b..e69dac1ea5 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -82,4 +82,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 0d9cb71a16..f59f4e4635 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3/8B_qdora_single_device.yaml b/recipes/configs/llama3/8B_qdora_single_device.yaml index 0c80161200..039ccbfb90 100644 --- a/recipes/configs/llama3/8B_qdora_single_device.yaml +++ b/recipes/configs/llama3/8B_qdora_single_device.yaml @@ -81,7 +81,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index 0d831a8b77..3171bf0d14 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -83,7 +83,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: True +enable_activation_offloading: True # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml index 395cc931fd..406997b589 100644 --- a/recipes/configs/llama3_1/405B_qlora.yaml +++ b/recipes/configs/llama3_1/405B_qlora.yaml @@ -85,4 +85,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index f88ab85789..4f56d42392 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -99,7 +99,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True compile: False # set it to True for better memory and performance diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index 689ccf56b3..638b6e3bb6 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -102,4 +102,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index cd088386f5..801ac40547 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -67,7 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] compile: False # set it to True for better memory and performance diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 9baf5cbb93..39ec2251b2 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -68,7 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index f622a8a27e..17eb1a4b44 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -86,4 +86,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index c951abc3a5..71082b1a1f 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -87,7 +87,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 0b3e615bc9..41f8f5292d 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -86,7 +86,7 @@ dtype: bf16 # Activations Offloading enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml index 1d56a1fb95..39be42bd7b 100644 --- a/recipes/configs/llama3_2/1B_full.yaml +++ b/recipes/configs/llama3_2/1B_full.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory compile: False # set it to True for better memory and performance # Reduced precision diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml index e208cdb004..a94ced2c6f 100644 --- a/recipes/configs/llama3_2/1B_full_single_device.yaml +++ b/recipes/configs/llama3_2/1B_full_single_device.yaml @@ -65,7 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_2/1B_lora.yaml b/recipes/configs/llama3_2/1B_lora.yaml index 630c4110f3..5f1bac66e5 100644 --- a/recipes/configs/llama3_2/1B_lora.yaml +++ b/recipes/configs/llama3_2/1B_lora.yaml @@ -83,4 +83,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index c69728ac0d..ea5b525364 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/1B_qlora_single_device.yaml b/recipes/configs/llama3_2/1B_qlora_single_device.yaml index ca60a687eb..c016fbb751 100644 --- a/recipes/configs/llama3_2/1B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_qlora_single_device.yaml @@ -83,7 +83,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index ac6a728b1d..8544e794e1 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory compile: False # set it to True for better memory and performance # Reduced precision diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml index fdea2e622f..de63818dfc 100644 --- a/recipes/configs/llama3_2/3B_full_single_device.yaml +++ b/recipes/configs/llama3_2/3B_full_single_device.yaml @@ -66,7 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index ce6af86487..db9407d5c0 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -84,4 +84,4 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index 8fd65dd913..35dde222cf 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/3B_qlora_single_device.yaml b/recipes/configs/llama3_2/3B_qlora_single_device.yaml index 4547459282..f7ab860ecf 100644 --- a/recipes/configs/llama3_2/3B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_qlora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml b/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml index c621467582..6f0ad1900b 100644 --- a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml +++ b/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml @@ -104,7 +104,6 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: False -enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml index f03299f157..a07c9e6ab8 100644 --- a/recipes/configs/llama3_2_vision/11B_full.yaml +++ b/recipes/configs/llama3_2_vision/11B_full.yaml @@ -66,7 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] dtype: bf16 diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml index 9ba4a8bf74..b5e359d72c 100644 --- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml @@ -68,7 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml index e39ff367ba..1eeef47c57 100644 --- a/recipes/configs/llama3_2_vision/11B_lora.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora.yaml @@ -76,7 +76,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 827e04a815..ac6c47a075 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -75,7 +75,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index 6a34aa4e9c..9d7dc881b0 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -66,7 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index 0f34b1dd16..fa5d85d00e 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -68,7 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: True +enable_activation_offloading: True # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index 8165e3b692..9425c2f839 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -80,7 +80,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index ccfb0c2cd4..ba3c72caaf 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -78,7 +78,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index 0e2fa20d94..aa8843315c 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -79,7 +79,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index b7f7afe785..d1058a97fa 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -63,7 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index f097c3ba96..16e1171e01 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -66,7 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: True +enable_activation_offloading: True # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index 0d5ce3b278..c844c2de68 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -74,7 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index 7de8a30c94..e113637d3f 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -73,7 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index 1d2d5c5cbc..1635f8f3fc 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -73,7 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 4520ac932c..cfae11cec8 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -63,7 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 08daf3ab0f..ec3bf00095 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -64,7 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index 6dd856de11..ce33c5bf65 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -84,7 +84,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index 343eb8ea14..1590d92371 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index 82819f06fe..6a0cde0c79 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -63,7 +63,7 @@ device: cuda # Memory management enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index 9c21bd4d89..c899405219 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -69,7 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 3345e364b2..60e397a510 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -79,7 +79,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index 3e8377b6a1..340a86263a 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -82,7 +82,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index 67550203f4..61f3c80875 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -66,7 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index e29aeda677..92c5977619 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -68,7 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index 5bf2c97cb0..34d4ec6054 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -85,7 +85,7 @@ log_peak_memory_stats: False device: cuda dtype: bf16 enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index 8b8d470f6d..f01eb2ba08 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -86,7 +86,7 @@ dtype: bf16 # Activations Offloading enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training From 438738e8146ea8113336950094bd5d62e45aa7e7 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 16 Oct 2024 07:49:28 -0700 Subject: [PATCH 11/19] update config --- recipes/configs/llama3/8B_qlora_single_device.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index 3171bf0d14..aa3c26407c 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -83,7 +83,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: From 047456eb3d2bb6dfeaa36583a6eac450e7d25902 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 16 Oct 2024 07:49:40 -0700 Subject: [PATCH 12/19] add offloading to test --- tests/recipes/test_lora_finetune_distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 7c3480c7cb..ef4c9c9bc9 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -149,7 +149,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ enable_activation_checkpointing=True \ - enable_activation_offloading=False \ + enable_activation_offloading=True \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] @@ -175,7 +175,7 @@ def test_training_state_on_resume( resume_from_checkpoint=True \ metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ - enable_activation_offloading=False \ + enable_activation_offloading=True \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config @@ -219,7 +219,7 @@ def test_save_and_load_merged_weights( tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ - enable_activation_offloading=False \ + enable_activation_offloading=True \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] From 4913fc13b6e429b853d7f0ecc7c107e165ca2eba Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 16 Oct 2024 08:41:27 -0700 Subject: [PATCH 13/19] update docstring --- recipes/full_finetune_distributed.py | 6 +++--- recipes/full_finetune_single_device.py | 6 +++--- recipes/lora_finetune_distributed.py | 6 +++--- recipes/lora_finetune_single_device.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index cbe21cdae8..e7d7c5f7c2 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -59,9 +59,9 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): back during the backward pass. As always, there is a tradeoff--these savings in memory can come at the cost of training performance and CPU resources. To recover some runtime cost, we've added an option to enable offloading on a different stream to permit overlapping with - the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 - or later and will be enabled by default if an acceptable torch version is found. Activation - offloading can be used in conjunction with activation checkpointing. + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 4a9f07838c..580bd77007 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -51,9 +51,9 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): back during the backward pass. As always, there is a tradeoff--these savings in memory can come at the cost of training performance and CPU resources. To recover some runtime cost, we've added an option to enable offloading on a different stream to permit overlapping with - the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 - or later and will be enabled by default if an acceptable torch version is found. Activation - offloading can be used in conjunction with activation checkpointing. + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 78bc2e6f5f..6c175daab3 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -68,9 +68,9 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): back during the backward pass. As always, there is a tradeoff--these savings in memory can come at the cost of training performance and CPU resources. To recover some runtime cost, we've added an option to enable offloading on a different stream to permit overlapping with - the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 - or later and will be enabled by default if an acceptable torch version is found. Activation - offloading can be used in conjunction with activation checkpointing. + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 3e23a783bb..c22f20f8ff 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -58,9 +58,9 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): back during the backward pass. As always, there is a tradeoff--these savings in memory can come at the cost of training performance and CPU resources. To recover some runtime cost, we've added an option to enable offloading on a different stream to permit overlapping with - the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 - or later and will be enabled by default if an acceptable torch version is found. Activation - offloading can be used in conjunction with activation checkpointing. + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In From ef56d57c07add55ff7b90cf15a67835f30e64446 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 16 Oct 2024 08:43:24 -0700 Subject: [PATCH 14/19] update vision --- recipes/configs/llama3_2_vision/11B_full.yaml | 1 - .../11B_full_single_device.yaml | 1 - recipes/configs/llama3_2_vision/11B_lora.yaml | 1 - .../11B_lora_single_device.yaml | 1 - torchtune/training/_activation_offloading.py | 24 +++++++++++++++++++ 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml index a07c9e6ab8..2c8f1f58fd 100644 --- a/recipes/configs/llama3_2_vision/11B_full.yaml +++ b/recipes/configs/llama3_2_vision/11B_full.yaml @@ -66,7 +66,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] dtype: bf16 diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml index b5e359d72c..d42fb971e6 100644 --- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml @@ -68,7 +68,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml index 712471c6dc..b67ff6601b 100644 --- a/recipes/configs/llama3_2_vision/11B_lora.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora.yaml @@ -76,7 +76,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 3d3a6c8170..5ada32dcd1 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -75,7 +75,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index ef384691a3..bfddc19be9 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -357,6 +357,7 @@ def get_act_offloading_ctx_manager( # step, as the cost for offloading the activation and then soon after bringing # it back is expensive. Moreover, due to heuristics in our streaming API, # we actually use more memory if we offload it as it interferes with chunkedCE. + output_head_detected = False if hasattr(model, "output"): noop_ctx = NoOpManager() if isinstance(model.output, nn.Module): @@ -366,6 +367,7 @@ def get_act_offloading_ctx_manager( model.output.register_forward_hook( lambda *args: noop_ctx.__exit__(), always_call=True ) + output_head_detected = True elif isinstance(model.output, TiedLinear): model.output.linear.register_forward_pre_hook( lambda *args: noop_ctx.__enter__() @@ -373,6 +375,28 @@ def get_act_offloading_ctx_manager( model.output.linear.register_forward_hook( lambda *args: noop_ctx.__exit__(), always_call=True ) + output_head_detected = True + + elif hasattr(model, "decoder"): + noop_ctx = NoOpManager() + if isinstance(model.decoder, nn.Module): + model.decoder.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.decoder.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + if not output_head_detected: + log.warning( + "During activation offloading, no output head was detected. " + "If your model has an output head, it will be offloaded. " + "This usually greatly slows training, given the large vocabulary size. " + "To change this behavior, set your output head as model.output and make it " + "an nn.Module." + ) + else: activations_handling_ctx = contextlib.nullcontext() From 152313855b804c5870584189ced888237312b655 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 28 Oct 2024 12:52:30 -0700 Subject: [PATCH 15/19] add back type hint --- torchtune/modules/tied_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/tied_linear.py b/torchtune/modules/tied_linear.py index 5cc1cc8d0c..67c6fea3f5 100644 --- a/torchtune/modules/tied_linear.py +++ b/torchtune/modules/tied_linear.py @@ -52,7 +52,7 @@ def __init__(self, tied_module: nn.Module): "Provided module does not have attribute 'weight'. Please check your tied_module." ) - def __call__(self, x: torch.Tensor): + def __call__(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (torch.Tensor): Input tensor. Should have shape ``(..., in_dim)``, where ``in_dim`` From 2c24d487d2e36db24d16f30e83ebdeb4f333b0fa Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 28 Oct 2024 13:00:12 -0700 Subject: [PATCH 16/19] merge conflict --- recipes/lora_finetune_distributed.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 7deef38637..769a58379a 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -152,16 +152,6 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False - # training attributes - self._enable_activation_checkpointing = cfg.enable_activation_checkpointing - self._enable_activation_offloading = cfg.get( - "enable_activation_offloading", False - ) - if self._enable_activation_offloading and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` self.seed = training.set_seed(seed=cfg.seed) From 82238f81ea6952df32c6c8c6732b90488ffe7dbb Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 28 Oct 2024 13:36:09 -0700 Subject: [PATCH 17/19] added missing logger --- torchtune/training/_activation_offloading.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index 349ec98423..e28ddc5f2b 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -16,6 +16,9 @@ from torchao.dtypes.nf4tensor import NF4Tensor from torchtune.modules import TiedLinear +from torchtune.utils import get_logger + +log = get_logger("DEBUG") class OffloadActivations(saved_tensors_hooks): @@ -376,8 +379,8 @@ def get_act_offloading_ctx_manager( # it back is expensive. Moreover, due to heuristics in our streaming API, # we actually use more memory if we offload it as it interferes with chunkedCE. output_head_detected = False + noop_ctx = NoOpManager() if hasattr(model, "output"): - noop_ctx = NoOpManager() if isinstance(model.output, nn.Module): model.output.register_forward_pre_hook( lambda *args: noop_ctx.__enter__() @@ -396,7 +399,6 @@ def get_act_offloading_ctx_manager( output_head_detected = True elif hasattr(model, "decoder"): - noop_ctx = NoOpManager() if isinstance(model.decoder, nn.Module): model.decoder.output.register_forward_pre_hook( lambda *args: noop_ctx.__enter__() From f43297d7497dd52d370a51e6154c066071780c5d Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 28 Oct 2024 13:51:48 -0700 Subject: [PATCH 18/19] raise not implemented error for multimodal --- torchtune/training/_activation_offloading.py | 26 ++++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index e28ddc5f2b..5a18ab9bd0 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -370,6 +370,9 @@ def get_act_offloading_ctx_manager( Returns: contextlib.ContextDecorator: the activation offloading context manager for the model. + + Raises: + NotImplementedError: If the model is a multimodal model and activation offloading is enabled. """ if enable_activation_offloading: activations_handling_ctx = OffloadActivations() @@ -399,14 +402,21 @@ def get_act_offloading_ctx_manager( output_head_detected = True elif hasattr(model, "decoder"): - if isinstance(model.decoder, nn.Module): - model.decoder.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.decoder.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) - output_head_detected = True + # TODO: it errors out. Needs debugging. + # assert_size_stride(rsqrt_2, (4, 32, 1601, 1), (52224, 1632, 1, 1)) + # AssertionError: expected size 4==4, stride 51232==52224 at dim=0; + # # expected size 32==32, stride 1601==1632 at dim=1 + raise NotImplementedError( + "Multimodal model does not support activation offloading yet. Please set it to False" + ) + # if isinstance(model.decoder, nn.Module): + # model.decoder.output.register_forward_pre_hook( + # lambda *args: noop_ctx.__enter__() + # ) + # model.decoder.output.register_forward_hook( + # lambda *args: noop_ctx.__exit__(), always_call=True + # ) + # output_head_detected = True if not output_head_detected: log.warning( From 785393875d45f0e5ce4727334f7d878b934f5b59 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 28 Oct 2024 13:53:01 -0700 Subject: [PATCH 19/19] update error msg --- torchtune/training/_activation_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index 5a18ab9bd0..bee9adce6d 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -407,7 +407,7 @@ def get_act_offloading_ctx_manager( # AssertionError: expected size 4==4, stride 51232==52224 at dim=0; # # expected size 32==32, stride 1601==1632 at dim=1 raise NotImplementedError( - "Multimodal model does not support activation offloading yet. Please set it to False" + "Multimodal model does not support activation offloading yet. Please set enable_activation_offloading=False" ) # if isinstance(model.decoder, nn.Module): # model.decoder.output.register_forward_pre_hook(