From 2648ebcf44b64f1257ffead10f26defe23c76a17 Mon Sep 17 00:00:00 2001
From: Christoph Stumpf
Date: Tue, 19 Dec 2023 18:38:23 +0100
Subject: [PATCH] Fix activation checkpointing

- Create new function `activation_checkpoint_wrapper` that converts module
  outputs to a form compatible with the fairscale activation checkpoint
  wrapper
- Use `static_graph` in img_clf training
---
 examples/training/img_clf/train.py |  3 +-
 examples/training/img_clf/train.sh |  1 +
 perceiver/model/core/modules.py    | 34 +++++++++++++++---
 3 files changed, 32 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 examples/training/img_clf/train.sh

diff --git a/examples/training/img_clf/train.py b/examples/training/img_clf/train.py
index 507cd64..5b0ab8f 100644
--- a/examples/training/img_clf/train.py
+++ b/examples/training/img_clf/train.py
@@ -47,7 +47,6 @@ def configure_optimizers(self):
     num_latent_channels=128,
 )
 
-
 if __name__ == "__main__":
     lit_model = LitImageClassifier.create(config)
 
@@ -55,7 +54,7 @@ def configure_optimizers(self):
         accelerator="gpu",
         devices=2,
         max_epochs=30,
-        strategy=DDPStrategy(find_unused_parameters=False),
+        strategy=DDPStrategy(find_unused_parameters=False, static_graph=True),
         logger=TensorBoardLogger(save_dir="logs", name="img_clf"),
     )
 
diff --git a/examples/training/img_clf/train.sh b/examples/training/img_clf/train.sh
old mode 100644
new mode 100755
index 74d2a97..ac8ffba
--- a/examples/training/img_clf/train.sh
+++ b/examples/training/img_clf/train.sh
@@ -19,6 +19,7 @@ python -m perceiver.scripts.vision.image_classifier fit \
   --trainer.accelerator=gpu \
   --trainer.devices=2 \
   --trainer.max_epochs=30 \
+  --trainer.strategy=ddp_static_graph \
   --trainer.logger=TensorBoardLogger \
   --trainer.logger.save_dir=logs \
   --trainer.logger.name=img_clf
diff --git a/perceiver/model/core/modules.py b/perceiver/model/core/modules.py
index 431546f..03fc93b 100644
--- a/perceiver/model/core/modules.py
+++ b/perceiver/model/core/modules.py
@@ -406,7 +406,7 @@ def __init__(
         ]
 
         if activation_checkpointing:
-            layers = [checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) for layer in layers]
+            layers = [activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) for layer in layers]
 
         self.num_rotary_layers = num_rotary_layers
         super().__init__(*layers)
@@ -543,7 +543,8 @@ def cross_attn():
                 residual_dropout=residual_dropout,
             )
             return (
-                checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) if activation_checkpointing else layer
+                activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading)
+                if activation_checkpointing else layer
             )
 
         def self_attn():
@@ -659,7 +660,7 @@ def __init__(
         )
 
         if activation_checkpointing:
-            cross_attn = checkpoint_wrapper(cross_attn, offload_to_cpu=activation_offloading)
+            cross_attn = activation_checkpoint_wrapper(cross_attn, offload_to_cpu=activation_offloading)
 
         self.cross_attn = cross_attn
         self._init_parameters(init_scale)
@@ -738,7 +739,8 @@ def cross_attn():
                 mlp_bias=False,
             )
             return (
-                checkpoint_wrapper(layer, offload_to_cpu=activation_offloading) if activation_checkpointing else layer
+                activation_checkpoint_wrapper(layer, offload_to_cpu=activation_offloading)
+                if activation_checkpointing else layer
             )
 
         def self_attn():
@@ -926,3 +928,27 @@ def forward(
         output.logits = self.output_adapter(output.last_hidden_state, txt_embedding=self.input_adapter.txt_embedding)
 
         return output
+
+
+def activation_checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False):
+    abstract_attention_layer_original_forward = AbstractAttentionLayer.forward
+
+    def _abstract_attention_layer_patched_forward(self, *args, **kwargs):
+        output = abstract_attention_layer_original_forward(self, *args, **kwargs)
+        if self.training and isinstance(output, ModuleOutput):
+            return output.last_hidden_state
+        return output
+
+    AbstractAttentionLayer.forward = _abstract_attention_layer_patched_forward
+
+    module = checkpoint_wrapper(module, offload_to_cpu=offload_to_cpu)
+    module_original_forward = module.forward
+
+    def _module_patched_forward(*args, **kwargs):
+        output = module_original_forward(*args, **kwargs)
+        if isinstance(output, ModuleOutput):
+            return output
+        return ModuleOutput(last_hidden_state=output, kv_cache=None)
+
+    module.forward = _module_patched_forward
+    return module
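
Reviewer note: below is a minimal, self-contained sketch of the idea behind
`activation_checkpoint_wrapper`, not code from this patch. `ToyOutput`, `ToyLayer`
and `toy_activation_checkpoint_wrapper` are hypothetical stand-ins for
`ModuleOutput` and the attention layers in `perceiver/model/core/modules.py`, and
the import assumes the usual `from fairscale.nn import checkpoint_wrapper`. The
point it illustrates: in the fairscale versions this repository targets,
`checkpoint_wrapper` captures `type(module).forward` and its checkpointing
function expects plain tensor outputs, so the tensor-only conversion has to be
installed on the class (as the patch does for `AbstractAttentionLayer.forward`),
and the structured output is restored by patching the wrapped instance's
`forward` afterwards.

# Hypothetical sketch, not part of the repository.
from dataclasses import dataclass

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper  # same wrapper the patch builds on


@dataclass
class ToyOutput:
    # stand-in for ModuleOutput
    last_hidden_state: torch.Tensor
    kv_cache: object = None


class ToyLayer(nn.Module):
    """Minimal layer whose forward returns a structured output instead of a tensor."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> ToyOutput:
        return ToyOutput(last_hidden_state=torch.relu(self.proj(x)))


def toy_activation_checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    # checkpoint_wrapper captures type(module).forward, so the tensor-only
    # conversion must live on the class, mirroring the patch above.
    original_forward = type(module).forward

    def tensor_only_forward(self, *args, **kwargs):
        output = original_forward(self, *args, **kwargs)
        if self.training and isinstance(output, ToyOutput):
            return output.last_hidden_state  # plain tensor for the checkpoint function
        return output

    type(module).forward = tensor_only_forward

    # Wrap with fairscale, then re-pack the checkpointed tensor into the
    # structured output that downstream code expects.
    module = checkpoint_wrapper(module, offload_to_cpu=offload_to_cpu)
    checkpointed_forward = module.forward

    def structured_forward(*args, **kwargs):
        output = checkpointed_forward(*args, **kwargs)
        if isinstance(output, ToyOutput):
            return output
        return ToyOutput(last_hidden_state=output, kv_cache=None)

    module.forward = structured_forward
    return module


if __name__ == "__main__":
    layer = toy_activation_checkpoint_wrapper(ToyLayer()).train()
    x = torch.randn(2, 4, 16, requires_grad=True)
    out = layer(x)  # forward activations are recomputed during backward
    out.last_hidden_state.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 4, 16])

The `self.training` guard, in both the patch and this sketch, restricts the
conversion to training, where checkpointing applies; at inference the full
structured output (presumably including `kv_cache`) is returned unchanged.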