Activation offloading for full finetuning + fix tied embedding #1847

Merged: 21 commits, Oct 30, 2024
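This PR threads a new `enable_activation_offloading` flag through every recipe config, alongside the existing `enable_activation_checkpointing` flag. Conceptually, activation offloading parks the tensors that autograd saves for the backward pass in CPU memory during the forward pass, then copies them back to the GPU as backward needs them, trading transfer time for peak GPU memory. Below is a minimal sketch of the idea using PyTorch's `torch.autograd.graph.save_on_cpu` saved-tensor hooks; this illustrates the technique and is not necessarily the mechanism torchtune implements:

```python
import torch
import torch.nn as nn

# Assumes a CUDA device; offloading is only meaningful when activations
# would otherwise occupy GPU memory.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda")

# While this context is active, every tensor autograd saves for backward
# is parked in pinned host memory instead of staying resident on the GPU.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    y = model(x)

# Backward copies each saved activation back to the GPU as it is needed.
y.sum().backward()
```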
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -69,6 +69,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: True # True reduces memory
 dtype: bf16

 # Logging
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -77,7 +77,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory
 dtype: bf16

 # Logging
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -76,7 +76,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory
 dtype: bf16

 # Logging
1 change: 1 addition & 0 deletions recipes/configs/dev/8B_full_experimental.yaml
@@ -65,6 +65,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
 ac_mode: 'selective' # ['selective', 'full']
 ac_option: 2 # [int] = ac every positive int layer
 memory_efficient_fsdp_wrap: False
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_full.yaml
@@ -62,6 +62,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora.yaml
@@ -74,6 +74,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_lora_single_device.yaml
@@ -73,7 +73,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -73,7 +73,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_full.yaml
@@ -64,6 +64,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora.yaml
@@ -76,6 +76,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora_single_device.yaml
@@ -75,7 +75,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -75,7 +75,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_full.yaml
@@ -66,6 +66,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_lora.yaml
@@ -89,3 +89,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -85,7 +85,7 @@ device: cuda
 dtype: bf16

 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_lora.yaml
@@ -88,3 +88,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_qlora.yaml
@@ -98,3 +98,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full_low_memory.yaml
@@ -70,6 +70,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: True # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora.yaml
@@ -85,6 +85,7 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory

 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_single_device.yaml
@@ -86,7 +86,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora.yaml
@@ -89,3 +89,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_full.yaml
@@ -93,6 +93,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
 custom_sharded_layers: ['tok_embeddings', 'output']
 fsdp_cpu_offload: True
 compile: False # pytorch compile, set to true for perf/memory improvement
1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_lora.yaml
@@ -104,3 +104,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora.yaml
@@ -79,3 +79,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora_single_device.yaml
@@ -81,6 +81,7 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
 custom_sharded_layers: ['tok_embeddings', 'output']

 # Reduced precision
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_full_single_device.yaml
@@ -69,6 +69,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora.yaml
@@ -84,3 +84,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_lora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qdora_single_device.yaml
@@ -82,6 +82,7 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Show case the usage of pytorch profiler
 # Set enabled to False as it's only needed for debugging training
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -84,7 +84,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: True
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/405B_qlora.yaml
@@ -82,3 +82,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_full.yaml
@@ -95,6 +95,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
 custom_sharded_layers: ['tok_embeddings', 'output']
 fsdp_cpu_offload: True
 compile: False # pytorch compile, set to true for perf/memory improvement
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_lora.yaml
@@ -103,3 +103,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -68,6 +68,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
 custom_sharded_layers: ['tok_embeddings', 'output']
 compile: False # pytorch compile, set to true for perf/memory improvement
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_full_single_device.yaml
@@ -69,6 +69,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora.yaml
@@ -87,3 +87,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -88,7 +88,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -87,7 +87,7 @@ dtype: bf16

 # Activations Offloading
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/1B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
 compile: False # pytorch compile, set to true for perf/memory improvement

 # Reduced precision
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -66,6 +66,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/1B_lora.yaml
@@ -84,3 +84,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: False
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/1B_qlora_single_device.yaml
@@ -84,7 +84,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: False
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/3B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
 compile: False # pytorch compile, set to true for perf/memory improvement

 # Reduced precision
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/3B_full_single_device.yaml
@@ -67,6 +67,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

 # Reduced precision
 dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/3B_lora.yaml
@@ -85,3 +85,4 @@ log_peak_memory_stats: True
 device: cuda
 dtype: bf16
 enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/3B_lora_single_device.yaml
@@ -86,7 +86,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/3B_qlora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

 # Profiler (disabled)
 profiler:
@@ -106,7 +106,6 @@ dtype: bf16

 # Activations Memory
 enable_activation_checkpointing: False
-enable_activation_offloading: False

 # Profiler (disabled)
 profiler:
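Across the diff the pattern is consistent: configs that already target minimal memory (code_llama2/7B_full_low_memory.yaml and llama2/7B_full_low_memory.yaml) enable the new flag by default, while every other config gains it as `False` with a `# True reduces memory` hint, keeping the throughput cost opt-in. Assuming torchtune's usual `key=value` config-override syntax, the flag can also be flipped from the command line without editing the YAML, e.g. `tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device enable_activation_offloading=True`.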