Separate full finetune into multi-gpu and single device recipes #482

Merged · 30 commits · Mar 19, 2024
Changes from all commits
@@ -1,7 +1,7 @@
# Config for FullFinetuneRecipe in full_finetune.py
# Config for FullFinetuneRecipe in full_finetune_distributed.py
#
# To launch, run the following command from root:
# tune --nnodes 1 --nproc_per_node 1 full_finetune --config alpaca_llama2_full_finetune model_checkpoint=<your_checkpoint_dir> ...
# tune --nnodes 1 --nproc_per_node 1 full_finetune_distributed --config alpaca_llama2_full_finetune_distributed model_checkpoint=<your_checkpoint_dir> ...

# Tokenizer
tokenizer:
@@ -38,19 +38,24 @@ loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
log_every_n_steps: null
run_generation: null


# Distributed
# Training env
device: cuda
dtype: fp32

# Distributed
enable_fsdp: True
enable_activation_checkpointing: True
cpu_offload: False

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
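The reorganized config switches the recipe from autocast-based mixed precision to full bf16 training via the `dtype` flag. For readers unfamiliar with the mechanism, here is a minimal sketch of how a default-dtype context can be built on stock PyTorch so that model parameters are created directly in the reduced precision; this illustrates the pattern only and is not torchtune's actual `utils.set_default_dtype`:

```python
import contextlib

import torch


@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    """Temporarily set the default dtype so freshly created parameters land in `dtype`."""
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)


# Example: parameters of a module constructed inside the context come out as bf16.
with default_dtype(torch.bfloat16):
    layer = torch.nn.Linear(16, 16)
assert layer.weight.dtype == torch.bfloat16
```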
recipes/configs/alpaca_llama2_full_finetune_single_device.yaml (57 additions, 0 deletions)
@@ -0,0 +1,57 @@
# Config for FullFinetuneRecipe in full_finetune_single_device.py
#
# To launch, run the following command from root:
# tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device --config alpaca_llama2_full_finetune_single_device model_checkpoint=<your_checkpoint_dir> ...

# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/llama2/tokenizer.model

# Dataset
dataset:
_component_: torchtune.datasets.AlpacaDataset
train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.utils.FullModelMetaCheckpointer
checkpoint_dir: /tmp/llama2
checkpoint_files: [consolidated.00.pth]
recipe_checkpoint: null
output_dir: /tmp/llama2
model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
_component_: torch.optim.SGD
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
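Each `_component_` entry in these configs names a dotted import path, and the remaining keys in the block are passed as keyword arguments when the recipe builds the object. Below is a minimal sketch of that instantiation pattern using generic OmegaConf/importlib code; it is an illustration of the idea, not torchtune's actual `config.instantiate`:

```python
import importlib

import torch
from omegaconf import OmegaConf


def instantiate_component(cfg, *args):
    """Import the object named by `_component_` and call it with the remaining keys as kwargs."""
    module_path, _, attr = cfg["_component_"].rpartition(".")
    component = getattr(importlib.import_module(module_path), attr)
    kwargs = {k: v for k, v in cfg.items() if k != "_component_"}
    return component(*args, **kwargs)


# e.g. the optimizer block above would resolve to torch.optim.SGD(params, lr=2e-5)
model = torch.nn.Linear(8, 8)
opt_cfg = OmegaConf.create({"_component_": "torch.optim.SGD", "lr": 2e-5})
optimizer = instantiate_component(opt_cfg, model.parameters())
```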
recipes/full_finetune.py → recipes/full_finetune_distributed.py (46 additions, 38 deletions)
@@ -14,7 +14,6 @@
from omegaconf import DictConfig

from torch import nn
from torch.cuda.amp import GradScaler
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
@@ -37,7 +36,7 @@ class FullFinetuneRecipe(FTRecipeInterface):
This recipe supports:
- FSDP and activation checkpointing. This is enabled by default but can be
configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
- Mixed precision training - fp32, fp16 and bf16 are supported.
- Full bf16 training via setting the ``dtype`` flag to bf16.
Comment on lines 37 to +39

Contributor:
(Comment is on L38-39). We should make sure we're aligned on the right default for AC, as #514 changes the default for distributed LoRA to no AC

Member Author:
Default has been on for memory efficiency, and I can't tell why #514 turns it off by default (doesn't appear to be in the PR description). So sticking with leaving it on for now.
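For context on what the `enable_activation_checkpointing` flag controls, here is a rough sketch of how activation checkpointing is typically wrapped around transformer blocks with stock PyTorch utilities. This is illustrative only; the recipe's flag gates torchtune's own wrapping logic, and the layer type checked here is an assumption for the example:

```python
from functools import partial

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def maybe_enable_ac(model: nn.Module, enable_activation_checkpointing: bool) -> None:
    """Recompute activations of each transformer block in backward to save memory."""
    if not enable_activation_checkpointing:
        return
    wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        # Only wrap the transformer blocks; embeddings and the output head are left alone.
        check_fn=lambda m: isinstance(m, nn.TransformerEncoderLayer),
    )
```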

- Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
- Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
- Logging to terminal. WandB and TensorBoard are currently not supported.
@@ -51,21 +50,31 @@ class FullFinetuneRecipe(FTRecipeInterface):

The following configs can be used to run this recipe:
>>> tune ls
RECIPE CONFIG
full_finetune alpaca_llama2_full_finetune
RECIPE CONFIG
full_finetune_distributed alpaca_llama2_full_finetune_distributed

Args:
cfg (DictConfig): OmegaConf object parsed from yaml file

Raises:
ValueError: If ``dtype`` is set to fp16.
"""

def __init__(self, cfg: DictConfig) -> None:

self._device = utils.get_device(device=cfg.device)
self._dtype = utils.get_dtype(dtype=cfg.dtype)

self._training_precision = utils.get_dtype(dtype=cfg.dtype)
# Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
# enabled necessary features such as gradient scaling.
if self._training_precision == torch.float16:
raise ValueError(
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
self._log_peak_memory_every_n_steps = 100
Contributor:
Similar comment here, just define in the config?

Member Author:
There's been discussion in the past about what should be configurable so as to not bloat configs. I'll defer to @kartikayk on this.

Member Author:
In #514, we hardcoded 100 so sticking with that in a variable for now seems reasonable.
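Logging peak memory every N steps generally amounts to querying the CUDA allocator counters. A rough sketch with plain PyTorch APIs is shown below; the helper is hypothetical and not torchtune's `utils.memory_stats_log`:

```python
import torch


def log_peak_memory(prefix: str, device: torch.device, step: int, every_n_steps: int = 100) -> None:
    """Print peak/current CUDA memory every `every_n_steps` training steps."""
    if device.type != "cuda" or step % every_n_steps != 0:
        return
    gib = 1024 ** 3
    print(
        f"{prefix} step {step}: "
        f"peak allocated {torch.cuda.max_memory_allocated(device) / gib:.2f} GiB, "
        f"currently allocated {torch.cuda.memory_allocated(device) / gib:.2f} GiB, "
        f"peak reserved {torch.cuda.max_memory_reserved(device) / gib:.2f} GiB"
    )
```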


# _is_rank_zero is used primarily for logging. In the future, the logger
# should directly take care of this
@@ -153,9 +162,9 @@ def setup(self, cfg: DictConfig) -> None:
# checkpoint. Transforming the opt state dict is handled by this method
self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
opt_state_dict=ckpt_dict[utils.OPT_KEY]
if self._resume_from_checkpoint
else None,
opt_state_dict=(
ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
),
)

self._loss_fn = config.instantiate(cfg.loss)
@@ -170,14 +179,6 @@ def setup(self, cfg: DictConfig) -> None:
batch_size=cfg.batch_size,
)

# training setup
self._autocast = utils.get_autocast(self._dtype, self._device)
self._grad_scaler = None
if self._dtype == torch.float16:
self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp)
else:
self._grad_scaler = GradScaler(enabled=False)

# Finally update the recipe state which can only be correctly set after all of the
# other components have been initialized and updated.
#
@@ -207,7 +208,7 @@ def _setup_model(
``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
running tests on CPUs.
"""
with self._device:
with utils.set_default_dtype(self._training_precision), self._device:
model = config.instantiate(cfg_model)

model = (
@@ -227,9 +228,13 @@ def _setup_model(
)

model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
utils.validate_expected_param_dtype(model, dtype=self._training_precision)
Contributor:
Why do we need this?

Member Author:
This validates that all params in the model are of the expected type. Would be useful for catching issues where some parameters don't end up as fp32, maybe due to an accidental overwrite, or a state_dict hook manipulating them, etc. Can take it out if needed.
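For concreteness, such a check can be as small as the sketch below. The helper is hypothetical and shown for illustration; torchtune's `utils.validate_expected_param_dtype` may differ:

```python
import torch
import torch.nn as nn


def validate_param_dtype(model: nn.Module, dtype: torch.dtype) -> None:
    """Raise if any parameter was left in (or cast to) a dtype other than `dtype`."""
    for name, param in model.named_parameters():
        if param.dtype != dtype:
            raise ValueError(
                f"Parameter {name} has dtype {param.dtype}, expected {dtype}"
            )
```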

if self._is_rank_zero:
log.info("Model is initialized.")
log.info(f"Model is initialized with precision {self._training_precision}.")
utils.memory_stats_log(
"Memory Stats after model init:", device=self._device
)
return model

def _setup_optimizer(
@@ -339,7 +344,6 @@ def train(self) -> None:

# zero out the gradients before starting training
self._optimizer.zero_grad()

# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):

@@ -360,16 +364,13 @@ def train(self) -> None:
input_ids, labels = batch
input_ids = input_ids.to(self._device)
labels = labels.to(self._device)

with self._autocast:
logits = self._model(input_ids)
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)

logits = self._model(input_ids)
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
logits = logits.transpose(1, 2)
# Compute loss
loss = self._loss_fn(logits, labels)
# Note: We're always logging the loss before normalizing it
# Check if this is the norm or not
if self.total_training_steps % self._log_every_n_steps == 0:
@@ -383,22 +384,24 @@ def train(self) -> None:
step=self.total_training_steps,
)

# Does loss normalization need to happen within autocast context?
loss = loss / self._gradient_accumulation_steps
self._grad_scaler.scale(loss).backward()
loss.backward()
Contributor:
Did we lose grad accumulation in here somewhere?

Member Author (@rohan-varma, Mar 19, 2024):
Great call (but again, CI didn't catch it, unfortunate)

Contributor:
I think our grad accumulation test is not running for the distributed recipe. I will look into setting this up with the distributed tests
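For reference, the gradient-accumulation pattern the reviewers are checking for typically looks like the self-contained sketch below; the recipe's `_should_update_weights(idx)` gates the optimizer step in the same way as the modulo check here. This is a generic illustration, not the recipe's exact code:

```python
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()
gradient_accumulation_steps = 4

# Gradients from several micro-batches are summed before a single optimizer step,
# so each micro-batch loss is scaled down by the accumulation factor.
for idx in range(16):
    inputs, labels = torch.randn(2, 8), torch.randint(0, 2, (2,))
    loss = loss_fn(model(inputs), labels)
    (loss / gradient_accumulation_steps).backward()
    if (idx + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```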

if self._should_update_weights(idx):
self._grad_scaler.step(self._optimizer)
self._grad_scaler.update()
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.total_training_steps += 1

# Log peak memory for iteration
if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
utils.memory_stats_log("Memory Stats:", device=self._device)

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)

def cleanup(self) -> None:
self._metric_logger.close()
torch.distributed.destroy_process_group()


@config.parse
@@ -407,11 +410,16 @@ def recipe_main(cfg: DictConfig) -> None:
Entry point for the recipe.

Configurable parameters are read in the following order:
- Parameters specified in ``alpaca_llama2_full_finetune.yaml``
- Parameters specified in ``alpaca_llama2_full_finetune_distributed.yaml``
- Overwritten by arguments from the command-line
"""
if utils.is_distributed():
init_process_group(backend="nccl")
if not utils.is_distributed():
raise RuntimeError(
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)

init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
Contributor:
Is gloo moot if we don't support CPU training?

Member Author:
For unit tests, until we have GPU support.
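For context on the `is_distributed()` guard and the backend choice: a torchrun-style launcher sets the standard rendezvous environment variables, so the check typically reduces to inspecting them, and gloo keeps CPU-only runs (e.g. unit tests) working while NCCL handles GPU collectives. A generic sketch, illustrative only and not torchtune's exact implementation:

```python
import os

import torch
from torch.distributed import init_process_group


def is_distributed_env() -> bool:
    """True when a distributed launcher (e.g. torchrun) has set the rendezvous env vars."""
    return all(v in os.environ for v in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"))


if is_distributed_env():
    # NCCL for GPU collectives; gloo as the CPU-only fallback.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    init_process_group(backend=backend)
```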


recipe = FullFinetuneRecipe(cfg=cfg)
recipe.setup(cfg=cfg)