Separate full finetune into multi-gpu and single device recipes #482
New config file: alpaca_llama2_full_finetune_single_device

@@ -0,0 +1,57 @@
# Config for FullFinetuneRecipe in full_finetune_single_device.py
#
# To launch, run the following command from root:
#    tune --nnodes 1 --nproc_per_node 1 full_finetune_single_device --config alpaca_llama2_full_finetune_single_device model_checkpoint=<your_checkpoint_dir> ...

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/llama2/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.AlpacaDataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama2.llama2_7b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [consolidated.00.pth]
  recipe_checkpoint: null
  output_dir: /tmp/llama2
  model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
  _component_: torch.optim.SGD
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: null
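
The ``_component_`` fields above are dotted paths that torchtune's config utilities resolve into actual Python objects, with the sibling keys passed as constructor arguments. As a rough illustration only (the real ``config.instantiate`` also handles nesting and positional arguments), the mechanism looks something like this:

```python
import importlib
from omegaconf import OmegaConf

# Sketch only: a simplified stand-in for torchtune's config.instantiate.
# It imports the dotted path in `_component_` and passes the remaining
# keys as keyword arguments.
def instantiate(node):
    kwargs = OmegaConf.to_container(node, resolve=True)
    module_path, _, name = kwargs.pop("_component_").rpartition(".")
    return getattr(importlib.import_module(module_path), name)(**kwargs)

cfg = OmegaConf.create({"loss": {"_component_": "torch.nn.CrossEntropyLoss"}})
loss_fn = instantiate(cfg.loss)  # equivalent to torch.nn.CrossEntropyLoss()
```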
Changes to the distributed full finetune recipe:
@@ -14,7 +14,6 @@
from omegaconf import DictConfig

from torch import nn
from torch.cuda.amp import GradScaler
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer

@@ -37,7 +36,7 @@ class FullFinetuneRecipe(FTRecipeInterface):
    This recipe supports:
        - FSDP and activation checkpointing. This is enabled by default but can be
          configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags.
        - Mixed precision training - fp32, fp16 and bf16 are supported.
        - Full bf16 training via setting the ``dtype`` flag to bf16.
        - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed).
        - Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
        - Logging to terminal. WandB and TensorBoard are currently not supported.

@@ -51,21 +50,31 @@ class FullFinetuneRecipe(FTRecipeInterface):

    The following configs can be used to run this recipe:
        >>> tune ls
        RECIPE                       CONFIG
        full_finetune                alpaca_llama2_full_finetune
        RECIPE                       CONFIG
        full_finetune_distributed    alpaca_llama2_full_finetune_distributed

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
    """

    def __init__(self, cfg: DictConfig) -> None:

        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)

        self._training_precision = utils.get_dtype(dtype=cfg.dtype)
        # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
        # enabled necessary features such as gradient scaling.
        if self._training_precision == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )
        # logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1
        self._log_peak_memory_every_n_steps = 100
Review comment: Similar comment here, just define this in the config?
Reply: There's been discussion in the past about what should be configurable so as to not bloat configs. I'll defer to @kartikayk on this.
Reply: In #514 we hardcoded 100, so sticking with that in a variable for now seems reasonable.
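
For reference, the reviewer's suggestion would amount to something like the sketch below, reading the interval from the config with a fallback to the current hardcoded value. The ``log_peak_memory_every_n_steps`` key is hypothetical and does not exist in the config shown above.

```python
from omegaconf import OmegaConf

# Sketch only: read an optional peak-memory logging interval from the config,
# falling back to the hardcoded default of 100 when the key is absent.
cfg = OmegaConf.create({"log_every_n_steps": None})  # minimal stand-in config
log_peak_memory_every_n_steps = cfg.get("log_peak_memory_every_n_steps", 100)
print(log_peak_memory_every_n_steps)  # -> 100
```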
        # _is_rank_zero is used primarily for logging. In the future, the logger
        # should directly take care of this

@@ -153,9 +162,9 @@ def setup(self, cfg: DictConfig) -> None:
        # checkpoint. Transforming the opt state dict is handled by this method
        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            opt_state_dict=ckpt_dict[utils.OPT_KEY]
            if self._resume_from_checkpoint
            else None,
            opt_state_dict=(
                ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
            ),
        )

        self._loss_fn = config.instantiate(cfg.loss)

@@ -170,14 +179,6 @@ def setup(self, cfg: DictConfig) -> None:
            batch_size=cfg.batch_size,
        )

        # training setup
        self._autocast = utils.get_autocast(self._dtype, self._device)
        self._grad_scaler = None
        if self._dtype == torch.float16:
            self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp)
        else:
            self._grad_scaler = GradScaler(enabled=False)

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.
        #

@@ -207,7 +208,7 @@ def _setup_model(
        ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for
        running tests on CPUs.
        """
        with self._device:
        with utils.set_default_dtype(self._training_precision), self._device:
            model = config.instantiate(cfg_model)

        model = (

@@ -227,9 +228,13 @@ def _setup_model(
        )

        model.load_state_dict(model_state_dict)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model, dtype=self._training_precision)
Review comment: Why do we need this?
Reply: This validates that all params in the model are of the expected dtype. It would be useful for catching issues where some parameters don't end up as fp32, maybe due to an accidental overwrite, a state_dict hook manipulating them, etc. Can take it out if needed.
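
The implementation of ``utils.validate_expected_param_dtype`` is not shown in this diff; conceptually it is just a walk over the model's parameters. A minimal sketch, assuming that behaviour:

```python
import torch
from torch import nn

def validate_expected_param_dtype(model: nn.Module, dtype: torch.dtype) -> None:
    """Raise if any parameter is not in the expected dtype (sketch only)."""
    for name, param in model.named_parameters():
        if param.dtype != dtype:
            raise ValueError(
                f"Parameter {name} has dtype {param.dtype}, expected {dtype}"
            )

# Usage: after load_state_dict, confirm everything really is bf16.
model = nn.Linear(4, 4).to(torch.bfloat16)
validate_expected_param_dtype(model, dtype=torch.bfloat16)
```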
        if self._is_rank_zero:
            log.info("Model is initialized.")
            log.info(f"Model is initialized with precision {self._training_precision}.")
            utils.memory_stats_log(
                "Memory Stats after model init:", device=self._device
            )
        return model

    def _setup_optimizer(

@@ -339,7 +344,6 @@ def train(self) -> None:

        # zero out the gradients before starting training
        self._optimizer.zero_grad()

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):


@@ -360,16 +364,13 @@ def train(self) -> None:
                input_ids, labels = batch
                input_ids = input_ids.to(self._device)
                labels = labels.to(self._device)

                with self._autocast:
                    logits = self._model(input_ids)
                    # Shift so that tokens < n predict n
                    logits = logits[..., :-1, :].contiguous()
                    labels = labels[..., 1:].contiguous()
                    logits = logits.transpose(1, 2)
                    # Compute loss
                    loss = self._loss_fn(logits, labels)

                logits = self._model(input_ids)
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
                logits = logits.transpose(1, 2)
                # Compute loss
                loss = self._loss_fn(logits, labels)
                # Note: We're always logging the loss before normalizing it
                # Check if this is the norm or not
                if self.total_training_steps % self._log_every_n_steps == 0:

@@ -383,22 +384,24 @@ def train(self) -> None:
                        step=self.total_training_steps,
                    )

                # Does loss normalization need to happen within autocast context?
                loss = loss / self._gradient_accumulation_steps
                self._grad_scaler.scale(loss).backward()
                loss.backward()
Review comment: Did we lose grad accumulation in here somewhere?
Reply: Great call (but again, CI didn't catch it, which is unfortunate).
Reply: I think our grad accumulation test is not running for the distributed recipe. I will look into setting this up with the distributed tests.
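
For reference, the standard gradient accumulation pattern the reviewer is pointing at looks roughly like this (illustrative only, not the recipe's final code): the loss is divided by the number of accumulation steps before ``backward()``, and the optimizer only steps every N iterations.

```python
import torch
from torch import nn

# Sketch only: gradient accumulation with a toy model and random data.
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 4

optimizer.zero_grad()
for idx in range(16):
    inputs = torch.randn(2, 8)
    targets = torch.randint(0, 2, (2,))
    loss = loss_fn(model(inputs), targets) / accum_steps  # normalize for accumulation
    loss.backward()                                       # gradients accumulate across iterations
    if (idx + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```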
                if self._should_update_weights(idx):
                    self._grad_scaler.step(self._optimizer)
                    self._grad_scaler.update()
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

                    # Update the number of steps when the weights are updated
                    self.total_training_steps += 1

                # Log peak memory for iteration
                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
                    utils.memory_stats_log("Memory Stats:", device=self._device)

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

    def cleanup(self) -> None:
        self._metric_logger.close()
        torch.distributed.destroy_process_group()


@config.parse

@@ -407,11 +410,16 @@ def recipe_main(cfg: DictConfig) -> None:
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in ``alpaca_llama2_full_finetune.yaml``
        - Parameters specified in ``alpaca_llama2_full_finetune_distributed.yaml``
        - Overwritten by arguments from the command-line
    """
    if utils.is_distributed():
        init_process_group(backend="nccl")
    if not utils.is_distributed():
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
        )

    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
Review comment: Is gloo moot if we don't support CPU training?
Reply: It's there for the unit tests, until we have GPU support.
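
The guard-plus-backend-selection pattern added here can be illustrated in isolation. This is a sketch, not torchtune's ``utils.is_distributed``; it assumes the usual torchrun environment variables as a stand-in for that check.

```python
import os
import torch.distributed as dist

# Sketch only: fail fast when not launched with a distributed launcher, then
# pick gloo for CPU (handy for unit tests) and nccl for real GPU runs.
def init_distributed(device: str) -> None:
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using the tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
        )
    dist.init_process_group(backend="gloo" if device == "cpu" else "nccl")
```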
    recipe = FullFinetuneRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)
Review comment (on L38-39): We should make sure we're aligned on the right default for AC, as #514 changes the default for distributed LoRA to no AC.
Reply: The default has been on for memory efficiency, and I can't tell why #514 turns it off by default (it doesn't appear to be in the PR description), so I'm sticking with leaving it on for now.
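
For context on what the flag controls: torchtune's own activation checkpointing helper is not part of this diff, but with core PyTorch the effect of ``enable_activation_checkpointing: True`` is roughly to wrap selected submodules so their activations are recomputed during the backward pass instead of being kept in memory. A rough sketch, assuming PyTorch's ``apply_activation_checkpointing`` utility:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Sketch only: conditionally wrap layers for activation checkpointing,
# trading extra compute for lower peak memory.
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])  # stand-in for transformer layers

enable_activation_checkpointing = True  # the config flag under discussion
if enable_activation_checkpointing:
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda module: isinstance(module, nn.Linear),  # which layers to wrap
    )
```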