Skip to content

Commit

Permalink
Fix activation checkpointing
Browse files Browse the repository at this point in the history
- Create new function `activation_checkpoint_wrapper` to convert module outputs to be
  compatible with fairscale activation checkpoint wrapper
- Use `static_graph` (see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
  in img_clf training to allow training with activation checkpointing which otherwise
  fails with an error
  • Loading branch information
cstub committed Dec 30, 2023
1 parent 4ac9b2c commit 5ced275
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
3 changes: 1 addition & 2 deletions examples/training/img_clf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ def configure_optimizers(self):
num_latent_channels=128,
)


if __name__ == "__main__":
lit_model = LitImageClassifier.create(config)

trainer = pl.Trainer(
accelerator="gpu",
devices=2,
max_epochs=30,
strategy=DDPStrategy(find_unused_parameters=False),
strategy=DDPStrategy(find_unused_parameters=False, static_graph=True),
logger=TensorBoardLogger(save_dir="logs", name="img_clf"),
)

Expand Down
1 change: 1 addition & 0 deletions examples/training/img_clf/train.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python -m perceiver.scripts.vision.image_classifier fit \
--trainer.accelerator=gpu \
--trainer.devices=2 \
--trainer.max_epochs=30 \
--trainer.strategy=ddp_static_graph \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=img_clf
Binary file added perceiver/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
34 changes: 30 additions & 4 deletions perceiver/model/core/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def __init__(
]

if activation_checkpointing:
layers = [checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) for layer in layers]
layers = [activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) for layer in layers]

self.num_rotary_layers = num_rotary_layers
super().__init__(*layers)
Expand Down Expand Up @@ -543,7 +543,8 @@ def cross_attn():
residual_dropout=residual_dropout,
)
return (
checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) if activation_checkpointing else layer
activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading)
if activation_checkpointing else layer
)

def self_attn():
Expand Down Expand Up @@ -659,7 +660,7 @@ def __init__(
)

if activation_checkpointing:
cross_attn = checkpoint_wrapper(cross_attn, offload_to_cpu=activation_offloading)
cross_attn = activation_checkpoint_wrapper(cross_attn, offload_to_cpu=activation_offloading)

self.cross_attn = cross_attn
self._init_parameters(init_scale)
Expand Down Expand Up @@ -738,7 +739,8 @@ def cross_attn():
mlp_bias=False,
)
return (
checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) if activation_checkpointing else layer
activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading)
if activation_checkpointing else layer
)

def self_attn():
Expand Down Expand Up @@ -926,3 +928,27 @@ def forward(

output.logits = self.output_adapter(output.last_hidden_state, txt_embedding=self.input_adapter.txt_embedding)
return output


def activation_checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False):
abstract_attention_layer_original_forward = AbstractAttentionLayer.forward

def _abstract_attention_layer_patched_forward(self, *args, **kwargs):
output = abstract_attention_layer_original_forward(self, *args, **kwargs)
if self.training and isinstance(output, ModuleOutput):
return output.last_hidden_state
return output

AbstractAttentionLayer.forward = _abstract_attention_layer_patched_forward

module = checkpoint_wrapper(module, offload_to_cpu=offload_to_cpu)
module_original_forward = module.forward

def _module_patched_forward(*args, **kwargs):
output = module_original_forward(*args, **kwargs)
if isinstance(output, ModuleOutput):
return output
return ModuleOutput(last_hidden_state=output, kv_cache=None)

module.forward = _module_patched_forward
return module

0 comments on commit 5ced275

Please sign in to comment.