Skip to content

Commit

Permalink
chunked cross entropy (#1390)
Browse files Browse the repository at this point in the history
Co-authored-by: Felipe Mello <felipemello@fb.com>
  • Loading branch information
felipemello1 and Felipe Mello authored Aug 29, 2024
1 parent ec21546 commit 4fba6cd
Show file tree
Hide file tree
Showing 77 changed files with 496 additions and 107 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,4 @@ Losses
rlhf.loss.RSOLoss
rlhf.loss.IPOLoss
rlhf.loss.SimPOLoss
loss.CEWithChunkedOutputLoss
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_full_low_memory.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ optimizer:
lr: 2e-5
optimizer_in_bwd: True
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
compile: False

# Training env
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
compile: False

# Training env
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
compile: False

# Training env
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/dev/8B_full_experimental.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ optimizer:
foreach: False

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/13B_lora_fsdp2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/70B_lora_fsdp2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

fsdp:
cpu_offload: False
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/7B_lora_fsdp2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

fsdp:
cpu_offload: False
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 8
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_full_low_memory.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ optimizer:
lr: 2e-5
optimizer_in_bwd: True
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_qat_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/70B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ optimizer:
fused: True

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ optimizer:
foreach: False

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_full_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ optimizer:
_component_: bitsandbytes.optim.PagedAdamW8bit
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: True
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_qat_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ optimizer:
foreach: False

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/70B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ optimizer:
fused: True

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/70B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ optimizer:
foreach: False

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

Expand Down
27 changes: 26 additions & 1 deletion recipes/configs/llama3_1/8B_full_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ optimizer:
_component_: bitsandbytes.optim.PagedAdamW8bit
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: True
Expand All @@ -79,3 +79,28 @@ metric_logger:
output_dir: /tmp/full-llama3.1-finetune
log_every_n_steps: 1
log_peak_memory_stats: False

# Profiler (disabled)
profiler:
_component_: torchtune.utils.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: True
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 1
warmup_steps: 2
active_steps: 1
num_cycles: 1
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ lr_scheduler:
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
Expand Down
Loading

0 comments on commit 4fba6cd

Please sign in to comment.