❓ Questions and Help

What is your question?
I found some differences between the fairseq `checkpoint_wrapper` and the fairscale `checkpoint_wrapper`.
Code
fairseq checkpoint_wrapper
```python
def checkpoint_wrapper(m, offload_to_cpu=False):
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward

    Usage::
        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
    """
    # should I check whether original_forward has already been set?
    assert not hasattr(
        m, "precheckpoint_forward"
    ), "checkpoint function has already been applied?"
    m.precheckpoint_forward = m.forward
    m.forward = functools.partial(
        _checkpointed_forward,
        m.precheckpoint_forward,  # original_forward
        offload_to_cpu,
    )
    return m
```
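To make the difference concrete (the fairscale version below calls it out in a comment), here is a minimal, hypothetical sketch — not fairseq code; `fairseq_style_wrap` and the trivial `_checkpointed_forward` stand-in are made up for illustration. Binding a `functools.partial` over the bound method `m.precheckpoint_forward` back onto `m.forward` creates a reference cycle `m -> m.forward -> m.precheckpoint_forward -> m`, so the module is only reclaimed once the cycle collector runs:

```python
# Sketch only: a toy stand-in for _checkpointed_forward that just calls the
# original forward, so the focus is on object lifetime rather than checkpointing.
import functools
import gc
import weakref

from torch import nn


def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
    # The real helper would run torch.utils.checkpoint; here we only forward the call.
    return original_forward(*args, **kwargs)


def fairseq_style_wrap(m, offload_to_cpu=False):
    # Same binding pattern as the fairseq wrapper above: the partial captures the
    # *bound* method m.precheckpoint_forward, which holds a strong reference to m.
    m.precheckpoint_forward = m.forward
    m.forward = functools.partial(
        _checkpointed_forward, m.precheckpoint_forward, offload_to_cpu
    )
    return m


gc.disable()  # keep the experiment deterministic; no automatic cycle collection
m = fairseq_style_wrap(nn.Linear(4, 4))
alive = weakref.ref(m)

del m
print("freed by refcounting alone:", alive() is None)  # False: the cycle keeps it alive

gc.collect()                                           # the cycle collector reclaims it
print("freed after gc.collect():", alive() is None)    # True
gc.enable()
```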
fairscale checkpoint_wrapper
```python
def checkpoint_wrapper(
    module: nn.Module,
    offload_to_cpu: bool = False,
) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward
    - supports offloading activations to CPU

    Usage::
        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    To understand the benefits of checkpointing and the `offload_to_cpu` flag,
    let's divide activations into 2 types: inner activations and outer
    activations w.r.t. the checkpointed modules. The inner ones are saved
    by activation checkpointing, the outer ones are saved by offload_to_cpu.

    In terms of GPU memory savings:

        - When inner ones are large in size and outer ones are small,
          checkpointing helps a lot, offload_to_cpu may help a little.
        - When inner ones are small and outer ones are large,
          checkpointing helps little, offload_to_cpu helps a lot.
        - When both inner and outer are large, both help and the benefit
          is additive.

    .. Note::

        The first and last layers are not likely to benefit from the
        `offload_to_cpu` flag because (1) there are typically other references
        to the first layer's input, so the GPU memory won't be freed;
        (2) the input to the last layer is immediately used by the backward
        pass and won't result in memory savings.

    Args:
        module (nn.Module): The module to be wrapped
        offload_to_cpu (bool): Whether to offload activations to CPU.

    Returns:
        (nn.Module): Wrapped module
    """
    # Patch the batchnorm layers in case there are any in this module.
    patch_batchnorm(module)

    # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
    # When such cycle exists, gc won't collect the module when the module is freed.
    # That causes GPU memory to be leaked. See the unit test for how we catch that.
    #
    # We prefer this over a class wrapper since the class wrapper would have to
    # proxy a lot of fields and methods.
    module.forward = functools.partial(  # type: ignore
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module
```
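For comparison, here is the same kind of sketch for the fairscale-style binding (again with hypothetical stand-ins; the real `_checkpointed_forward` and `patch_batchnorm` are omitted). Because the partial captures the unbound `type(module).forward` and a `weakref.ref(module)` instead of a bound method, there is no strong reference cycle, and plain reference counting frees the module as soon as the last external reference is gone:

```python
# Sketch only: mirrors the fairscale binding above, with a trivial stand-in for
# _checkpointed_forward and without batchnorm patching.
import functools
import gc
import weakref

import torch
from torch import nn


def _checkpointed_forward(original_forward, weak_self, offload_to_cpu, *args, **kwargs):
    # Resolve the weak reference and call the unbound forward with the module as `self`.
    module = weak_self()
    assert module is not None, "module was garbage collected"
    return original_forward(module, *args, **kwargs)


def fairscale_style_wrap(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    # No bound method is captured, so module.forward does not strongly reference module.
    module.forward = functools.partial(
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module


gc.disable()  # same deterministic setup as the previous sketch
m = fairscale_style_wrap(nn.Linear(4, 4))
y = m(torch.randn(2, 4))          # the patched forward still works
alive = weakref.ref(m)

del m, y
print("freed by refcounting alone:", alive() is None)  # True: no cycle to break
gc.enable()
```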
What have you tried?
I want to know which `checkpoint_wrapper` is safer. Should we change the fairseq `checkpoint_wrapper` to work like the fairscale one?
What's your environment?
fairseq Version (e.g., 0.12.2 or main):
PyTorch Version: 2.4.0+cu121
OS (e.g., Linux): Ubuntu 22.04
How you installed fairseq (pip, source): source
Build command you used (if compiling from source): pip install -e .
Python version: 3.9
CUDA/cuDNN version: 12.1
GPU models and configuration: A100
Any other relevant information: