diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index bae760c67e..ffe48249a7 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 1ada63446b..6533420441 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -77,7 +77,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index e7910d73cc..afda975b9f 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -76,7 +76,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml index ee1e0f650c..f70ec01004 100644 --- a/recipes/configs/dev/8B_full_experimental.yaml +++ b/recipes/configs/dev/8B_full_experimental.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory ac_mode: 'selective' # ['selective', 'full'] ac_option: 2 # [int] = ac every positive int layer memory_efficient_fsdp_wrap: False diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index a3b8ed59f7..2bfe5995be 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -62,6 +62,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 8ed92dd115..7169236759 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -74,6 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index b661710caf..9bf463181e 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -73,7 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 2b5cbf96bb..250d6ef178 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -73,7 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index eb6b8c9426..8c7ff001fd 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index 4d74f93671..209277c9d5 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -76,6 +76,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index 369ba715e5..57be9a3be0 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -75,7 +75,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 301a7b4a5d..0b52716d60 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -75,7 +75,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index be5a4e8b1d..fef60b7c21 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index 797abc2a63..6dd3017c06 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -89,3 +89,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 9e8faaa800..5e37ee820a 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -85,7 +85,7 @@ device: cuda dtype: bf16 enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 9502690be2..7b936696ad 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -88,3 +88,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml index c0e2e320f3..5d778e13e3 100644 --- a/recipes/configs/llama2/70B_qlora.yaml +++ b/recipes/configs/llama2/70B_qlora.yaml @@ -98,3 +98,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index 3a6e3c35f2..eea691ea86 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index b9b933c2df..7380bd0756 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -70,6 +70,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 82276fa317..7841eea584 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -85,6 +85,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index a1c001b868..b96d139174 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -86,7 +86,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_qlora.yaml b/recipes/configs/llama2/7B_qlora.yaml index 26fc4faf11..97cdae7dac 100644 --- a/recipes/configs/llama2/7B_qlora.yaml +++ b/recipes/configs/llama2/7B_qlora.yaml @@ -89,3 +89,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index 611c5b155b..ad6667b2fb 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 608992f737..e950b91dab 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -93,6 +93,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True compile: False # pytorch compile, set to true for perf/memory improvement diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index 247daba5cc..4ab6c13793 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -104,3 +104,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/8B_dora.yaml b/recipes/configs/llama3/8B_dora.yaml index a9ea97986e..43b0fa6066 100644 --- a/recipes/configs/llama3/8B_dora.yaml +++ b/recipes/configs/llama3/8B_dora.yaml @@ -79,3 +79,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml index 188b54f757..20f5804082 100644 --- a/recipes/configs/llama3/8B_dora_single_device.yaml +++ b/recipes/configs/llama3/8B_dora_single_device.yaml @@ -81,6 +81,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index baa4a79417..27f569aa16 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] # Reduced precision diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index 6b8e1ad4b8..b86272842e 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 69a2349035..41537ccdbb 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -84,3 +84,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 661bbe86db..6c6aefa525 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3/8B_qdora_single_device.yaml b/recipes/configs/llama3/8B_qdora_single_device.yaml index fafda9a123..18c625a956 100644 --- a/recipes/configs/llama3/8B_qdora_single_device.yaml +++ b/recipes/configs/llama3/8B_qdora_single_device.yaml @@ -82,6 +82,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index 83c0dcb9d1..5486ae9f1a 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: True +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml index f640581ba1..67b2a0cb39 100644 --- a/recipes/configs/llama3_1/405B_qlora.yaml +++ b/recipes/configs/llama3_1/405B_qlora.yaml @@ -82,3 +82,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 97ca0a7052..34fabe663f 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -95,6 +95,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True compile: False # pytorch compile, set to true for perf/memory improvement diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index ad1bc64110..f988c97b2a 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -103,3 +103,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index da27c91852..71ab8eedeb 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] compile: False # pytorch compile, set to true for perf/memory improvement diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 04ba339b23..b26df8cb67 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index d0a5202847..0793b8a57c 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -87,3 +87,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index bc9a3956f3..12ef984db9 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -88,7 +88,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index b194acb181..0b44eaf383 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -87,7 +87,7 @@ dtype: bf16 # Activations Offloading enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml index c90fea966f..694a14b573 100644 --- a/recipes/configs/llama3_2/1B_full.yaml +++ b/recipes/configs/llama3_2/1B_full.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory compile: False # pytorch compile, set to true for perf/memory improvement # Reduced precision diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml index e4d1f87fac..fe641f3479 100644 --- a/recipes/configs/llama3_2/1B_full_single_device.yaml +++ b/recipes/configs/llama3_2/1B_full_single_device.yaml @@ -66,6 +66,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_2/1B_lora.yaml b/recipes/configs/llama3_2/1B_lora.yaml index b5e53900ef..17ee6a8625 100644 --- a/recipes/configs/llama3_2/1B_lora.yaml +++ b/recipes/configs/llama3_2/1B_lora.yaml @@ -84,3 +84,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index 8c94bb0582..3e23a6e56a 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/1B_qlora_single_device.yaml b/recipes/configs/llama3_2/1B_qlora_single_device.yaml index 282d0d9e89..d4530df081 100644 --- a/recipes/configs/llama3_2/1B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_qlora_single_device.yaml @@ -84,7 +84,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: False -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index bfe9ef6420..2d9e9d2f3a 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory compile: False # pytorch compile, set to true for perf/memory improvement # Reduced precision diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml index 14a5369e71..16f5840edf 100644 --- a/recipes/configs/llama3_2/3B_full_single_device.yaml +++ b/recipes/configs/llama3_2/3B_full_single_device.yaml @@ -67,6 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index 076f9d9171..a2f00ad19e 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -85,3 +85,4 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index b36d18f872..4add5d63aa 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -86,7 +86,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/3B_qlora_single_device.yaml b/recipes/configs/llama3_2/3B_qlora_single_device.yaml index 3efbd6c43c..520f616a79 100644 --- a/recipes/configs/llama3_2/3B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_qlora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml b/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml index ba39474639..6a3f85f257 100644 --- a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml +++ b/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml @@ -106,7 +106,6 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: False -enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml index a27f5f3510..803e4c2420 100644 --- a/recipes/configs/llama3_2_vision/11B_lora.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora.yaml @@ -78,7 +78,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 45288521a1..6f45e6a7c9 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -77,7 +77,6 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index 25cf783846..db242d2b6f 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -68,6 +68,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index a6cf37fa8c..f25c150325 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index a2dc801925..9ba9976f2a 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -82,6 +82,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index 21212f4983..6380448331 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -79,7 +79,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index e2f6884a9f..42c88af742 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -80,7 +80,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index 0be89337a7..bd5b00702c 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index 470f4a1afe..1fbb10d10f 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -67,6 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: True # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index 1af4929985..2391f9f383 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -76,6 +76,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index 21a12a3cc1..cec51773dc 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -74,7 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index 21c9403bef..ceaa5b3530 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -74,7 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 39748ee052..133e24b1cc 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 2d2afe883e..14ed13e213 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -65,6 +65,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index 33b5e968d0..a605229d2b 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -86,6 +86,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index beeb21b072..0052086a03 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -85,7 +85,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index 8e850bae50..725d7fa65f 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -64,6 +64,7 @@ device: cuda # Memory management enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index cc7fd5f566..6e140085c4 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -70,6 +70,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 845cb71184..d5a23b571e 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -81,6 +81,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index f2e8d2beb4..88e18352b8 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -83,7 +83,7 @@ dtype: bf16 # Activations Memory enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index 06083d908f..3c159f90fc 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -67,6 +67,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index 13290d82a0..5cc2c8b4b5 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -69,6 +69,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index 6e778ecd7d..612b48d156 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -87,6 +87,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 enable_activation_checkpointing: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index e0b19d03a3..1297d1bbe1 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -87,7 +87,7 @@ dtype: bf16 # Activations Offloading enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 4e1e3f24c5..8831c4f442 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -45,13 +45,25 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). DDP is currently not supported. Training on CPU is not supported. - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -97,6 +109,8 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -138,6 +152,50 @@ def __init__(self, cfg: DictConfig) -> None: self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -218,7 +276,8 @@ def setup(self, cfg: DictConfig) -> None: self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -358,6 +417,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, fsdp_cpu_offload: bool, reshard_after_forward: bool, model_state_dict: Dict[str, Any], @@ -435,6 +495,11 @@ def _setup_model( cpu_offload=fsdp_cpu_offload, ) + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) @@ -682,7 +747,8 @@ def train(self) -> None: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - logits = self._model(**batch) + with self.activations_handling_ctx: + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index fd01aabf15..6819b6c210 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -37,13 +37,25 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): for single GPU training. Training on CPU is not supported. Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -100,6 +112,8 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`. RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -128,6 +142,28 @@ def __init__(self, cfg: DictConfig) -> None: self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._optimizer_in_bwd = cfg.optimizer_in_bwd + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + # TODO: find a better place / way to perform validation of args that don't yet # compose with each other. if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: @@ -218,7 +254,8 @@ def setup(self, cfg: DictConfig) -> None: self._compile = cfg.compile self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], ) @@ -361,6 +398,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, compile_model: bool, model_state_dict: Dict[str, Any], ) -> nn.Module: @@ -384,6 +422,12 @@ def _setup_model( training.validate_expected_param_dtype( model.named_parameters(), dtype=self._dtype ) + + # Enable activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + log.info(f"Model is initialized with precision {self._dtype}.") if self._device.type == "cuda": @@ -569,7 +613,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - logits = self._model(**batch) + with self.activations_handling_ctx: + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 86147e08ca..769a58379a 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import sys import time @@ -35,12 +34,7 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import ( - DummyProfiler, - NoOpManager, - OffloadActivations, - PROFILER_KEY, -) +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -74,9 +68,9 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): back during the backward pass. As always, there is a tradeoff--these savings in memory can come at the cost of training performance and CPU resources. To recover some runtime cost, we've added an option to enable offloading on a different stream to permit overlapping with - the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 - or later and will be enabled by default if an acceptable torch version is found. Activation - offloading can be used in conjunction with activation checkpointing. + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In @@ -129,6 +123,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``left_pad_sequence`` is set as the data collator. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -157,16 +152,6 @@ def __init__(self, cfg: DictConfig) -> None: ) self._log_peak_memory_stats = False - # training attributes - self._enable_activation_checkpointing = cfg.enable_activation_checkpointing - self._enable_activation_offloading = cfg.get( - "enable_activation_offloading", False - ) - if self._enable_activation_offloading and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" - ) - # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` self.seed = training.set_seed(seed=cfg.seed) @@ -180,6 +165,28 @@ def __init__(self, cfg: DictConfig) -> None: self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ Extract the checkpoint state from file and validate. This includes the @@ -261,7 +268,7 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, enable_activation_offloading=self._enable_activation_offloading, fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -519,23 +526,12 @@ def _setup_model( # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) - self.activations_handling_ctx = contextlib.nullcontext() - if enable_activation_offloading: - self.activations_handling_ctx = OffloadActivations() - - # Below is our hack to disable offloading the last output Linear in every - # step, as the cost for offloading the activation and then soon after bringing - # it back is expensive. Moreover, due to heuristics in our streaming API, - # we actually use more memory if we offload it as it interferes with chunkedCE. - if hasattr(model, "output") and isinstance(model.output, nn.Module): - noop_ctx = NoOpManager() - model.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + # log if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index cbde0305f0..bc4018b810 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib import sys import time @@ -32,12 +31,7 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import ( - DummyProfiler, - NoOpManager, - OffloadActivations, - PROFILER_KEY, -) +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -64,9 +58,9 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): back during the backward pass. As always, there is a tradeoff--these savings in memory can come at the cost of training performance and CPU resources. To recover some runtime cost, we've added an option to enable offloading on a different stream to permit overlapping with - the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 - or later and will be enabled by default if an acceptable torch version is found. Activation - offloading can be used in conjunction with activation checkpointing. + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In @@ -120,6 +114,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. RuntimeError: If ``left_pad_sequence`` is set as the data collator """ @@ -158,12 +153,27 @@ def __init__(self, cfg: DictConfig) -> None: self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) self._enable_activation_offloading = cfg.get( "enable_activation_offloading", False ) - if self._enable_activation_offloading and self._device.type != "cuda": - raise RuntimeError( - "enable_activation_offloading should only be enabled for training on CUDA" + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." ) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -248,7 +258,7 @@ def setup(self, cfg: DictConfig) -> None: # set up model self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, enable_activation_offloading=self._enable_activation_offloading, compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], @@ -451,22 +461,10 @@ def _setup_model( self.adapter_params.items(), dtype=self._dtype ) - self.activations_handling_ctx = contextlib.nullcontext() - if enable_activation_offloading: - self.activations_handling_ctx = OffloadActivations() - - # Below is our hack to disable offloading the last output Linear in every - # step, as the cost for offloading the activation and then soon after bringing - # it back is expensive. Moreover, due to heuristics in our streaming API, - # we actually use more memory if we offload it as it interferes with chunkedCE. - if hasattr(model, "output") and isinstance(model.output, nn.Module): - noop_ctx = NoOpManager() - model.output.register_forward_pre_hook( - lambda *args: noop_ctx.__enter__() - ) - model.output.register_forward_hook( - lambda *args: noop_ctx.__exit__(), always_call=True - ) + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) log.info(f"Model is initialized with precision {self._dtype}.") diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 8e5a5fca2b..a381b6ce58 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -33,6 +33,7 @@ def _get_test_config_overrides(self): return [ "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", "epochs=2", diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index bd90fbbfad..6d3bea10c6 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -36,6 +36,7 @@ def _get_test_config_overrides(self): "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", "epochs=2", diff --git a/tests/recipes/test_knowledge_distillation_single_device.py b/tests/recipes/test_knowledge_distillation_single_device.py index 81b1c8aba2..713e05c98f 100644 --- a/tests/recipes/test_knowledge_distillation_single_device.py +++ b/tests/recipes/test_knowledge_distillation_single_device.py @@ -35,6 +35,7 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): "device=cpu", f"dtype={dtype_str}", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", f"epochs={epochs}", diff --git a/tests/recipes/test_lora_dpo_single_device.py b/tests/recipes/test_lora_dpo_single_device.py index d8cdca76c2..703ac2e471 100644 --- a/tests/recipes/test_lora_dpo_single_device.py +++ b/tests/recipes/test_lora_dpo_single_device.py @@ -83,6 +83,7 @@ def test_training_state_on_resume( save_adapter_weights_only={save_adapter_weights_only} \ metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -113,6 +114,7 @@ def test_training_state_on_resume( tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config monkeypatch.setattr(sys, "argv", cmd_2) @@ -144,6 +146,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 7be6a13f03..c8515b43c4 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -85,6 +85,7 @@ def test_loss( tokenizer.prompt_template=null \ reshard_after_forward={reshard_after_forward} \ enable_activation_checkpointing=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -154,6 +155,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ enable_activation_checkpointing=True \ + enable_activation_offloading=True \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] @@ -181,6 +183,7 @@ def test_training_state_on_resume( resume_from_checkpoint=True \ metric_logger.filename={log_file} \ enable_activation_checkpointing=True \ + enable_activation_offloading=True \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config @@ -226,6 +229,7 @@ def test_save_and_load_merged_weights( tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=True \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index 80bc5dc072..ca10076f5f 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -156,6 +156,7 @@ def test_loss_qlora( tokenizer.prompt_template=null \ compile={compile} \ enable_activation_checkpointing=False \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_qlora"] @@ -214,6 +215,7 @@ def test_training_state_on_resume( tokenizer.prompt_template=null \ save_adapter_weights_only={save_adapter_weights_only} \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] @@ -242,6 +244,7 @@ def test_training_state_on_resume( tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config monkeypatch.setattr(sys, "argv", cmd_2) @@ -274,6 +277,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): tokenizer.path=/tmp/test-artifacts/tokenizer.model \ tokenizer.prompt_template=null \ enable_activation_checkpointing=True \ + enable_activation_offloading=False \ """.split() model_config = MODEL_TEST_CONFIGS["llama2_lora"] diff --git a/tests/recipes/test_ppo_full_finetune_single_device.py b/tests/recipes/test_ppo_full_finetune_single_device.py index 63a1e68dcd..d40645acf6 100644 --- a/tests/recipes/test_ppo_full_finetune_single_device.py +++ b/tests/recipes/test_ppo_full_finetune_single_device.py @@ -41,6 +41,7 @@ def _get_test_config_overrides(self): "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "tokenizer.path=/tmp/test-artifacts/tokenizer.model", "tokenizer._component_=torchtune.models.llama2.llama2_tokenizer", "tokenizer.prompt_template=null", diff --git a/tests/recipes/test_qat_distributed.py b/tests/recipes/test_qat_distributed.py index 18e87a71d1..f5174fb46a 100644 --- a/tests/recipes/test_qat_distributed.py +++ b/tests/recipes/test_qat_distributed.py @@ -33,6 +33,7 @@ def _get_test_config_overrides(self): return [ "dtype=fp32", "enable_activation_checkpointing=False", + "enable_activation_offloading=False", "dataset.train_on_input=False", "seed=9", "epochs=2", diff --git a/torchtune/modules/tied_linear.py b/torchtune/modules/tied_linear.py index 718abd5c67..67c6fea3f5 100644 --- a/torchtune/modules/tied_linear.py +++ b/torchtune/modules/tied_linear.py @@ -9,13 +9,33 @@ import torch.nn.functional as F +class Linear(nn.Module): + """ + nn.Module used in :func:`~torchtune.modules.tied_linear.TiedLinear`, added to work with the hooks + :class:`~torchtune.training._activation_offloading.NoOpManager` that ignore activation + offloading context manager. + + Without this class, we can't add NoOp hooks, and we will offload the activation of + the tied linear layer, which is slow. + + For more information, see how NoOpManager is called in the recipes. + """ + + def forward(self, x: torch.Tensor, weight: torch.Tensor): + return F.linear(x, weight) + + class TiedLinear: """ A tied linear layer, without bias, that shares the same weight as another linear layer. This is useful for models that use tied weights, such as :func:`~torchtune.models.qwen2_0_5b`, - :func:`~torchtune.models.qwen2_1_5b` and all of the :func:`~torchtune.models.gemma` models. + :func:`~torchtune.models.qwen2_1_5b` and all of the :func:`~torchtune.models.gemma` and + :func:`~torchtune.models.llama3_2` models. + It requires as input an nn.Module, instead of the weight of the module, so it - can work with FSDP. Otherwise, the memory reference will be lost after FSDP is applied. + can work with FSDP. When FSDP is applied, the memory pointer to the weight is different, + but the nn.Module remains the same. This is why we need to pass the nn.Module instead of + the weight, if we want to keep the weights tied. Args: tied_module (nn.Module): The module whose weight is shared. Only @@ -26,6 +46,7 @@ class TiedLinear: def __init__(self, tied_module: nn.Module): self.tied_module = tied_module + self.linear = Linear() if not hasattr(tied_module, "weight"): raise AttributeError( "Provided module does not have attribute 'weight'. Please check your tied_module." @@ -40,4 +61,4 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: The output tensor, having shape ``(..., out_dim)``, where ``out_dim`` is \ the output dimension of the tied module. """ - return F.linear(x, self.tied_module.weight) + return self.linear(x, self.tied_module.weight) diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index f4ce81b449..db52e44cbd 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -3,7 +3,11 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.training._activation_offloading import NoOpManager, OffloadActivations +from torchtune.training._activation_offloading import ( + get_act_offloading_ctx_manager, + NoOpManager, + OffloadActivations, +) from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( contains_fsdp, @@ -72,6 +76,7 @@ from torchtune.training.seed import set_seed __all__ = [ + "get_act_offloading_ctx_manager", "apply_selective_activation_checkpointing", "get_dtype", "set_default_dtype", diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index c536e7f5ee..bee9adce6d 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -4,15 +4,22 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +import contextlib +from typing import Optional, Union from warnings import warn import psutil import torch import torchao +from torch import nn from torch.autograd.graph import saved_tensors_hooks from torchao.dtypes.nf4tensor import NF4Tensor +from torchtune.modules import TiedLinear +from torchtune.utils import get_logger + +log = get_logger("DEBUG") + class OffloadActivations(saved_tensors_hooks): """Context manager under which activation tensors created in the forward pass will be offloaded. @@ -345,3 +352,82 @@ def noop(tensor): return tensor super().__init__(noop, noop) + + +def get_act_offloading_ctx_manager( + model: nn.Module, enable_activation_offloading: bool +) -> Union[OffloadActivations, contextlib.nullcontext]: + """Returns the activation offloading context manager for the model, which will be + a null context if enable_activation_offloading is False. + + If activation offloading is enabled, we return the OffloadActivations context manager. + If activation offloading is disabled, we return a NoOpManager context manager. + + Args: + model (nn.Module): the model to wrap with the activation offloading context manager. + enable_activation_offloading (bool): whether or not to enable activation offloading + for the model. + + Returns: + contextlib.ContextDecorator: the activation offloading context manager for the model. + + Raises: + NotImplementedError: If the model is a multimodal model and activation offloading is enabled. + """ + if enable_activation_offloading: + activations_handling_ctx = OffloadActivations() + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + output_head_detected = False + noop_ctx = NoOpManager() + if hasattr(model, "output"): + if isinstance(model.output, nn.Module): + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + elif isinstance(model.output, TiedLinear): + model.output.linear.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.linear.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + elif hasattr(model, "decoder"): + # TODO: it errors out. Needs debugging. + # assert_size_stride(rsqrt_2, (4, 32, 1601, 1), (52224, 1632, 1, 1)) + # AssertionError: expected size 4==4, stride 51232==52224 at dim=0; + # # expected size 32==32, stride 1601==1632 at dim=1 + raise NotImplementedError( + "Multimodal model does not support activation offloading yet. Please set enable_activation_offloading=False" + ) + # if isinstance(model.decoder, nn.Module): + # model.decoder.output.register_forward_pre_hook( + # lambda *args: noop_ctx.__enter__() + # ) + # model.decoder.output.register_forward_hook( + # lambda *args: noop_ctx.__exit__(), always_call=True + # ) + # output_head_detected = True + + if not output_head_detected: + log.warning( + "During activation offloading, no output head was detected. " + "If your model has an output head, it will be offloaded. " + "This usually greatly slows training, given the large vocabulary size. " + "To change this behavior, set your output head as model.output and make it " + "an nn.Module." + ) + + else: + activations_handling_ctx = contextlib.nullcontext() + + return activations_handling_ctx