forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
full_finetune_single_device.py
391 lines (335 loc) · 15.2 KB
/
full_finetune_single_device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
from functools import partial
from typing import Any, Dict, Optional, Tuple
from warnings import warn
import torch
from omegaconf import DictConfig
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.recipe_interfaces import FTRecipeInterface
from tqdm import tqdm
# Module-level logger used throughout this recipe. "DEBUG" is forwarded to
# utils.get_logger (presumably the logging level — confirm against torchtune.utils).
log = utils.get_logger("DEBUG")
class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
    """
    Full finetuning recipe for dense transformer-based LLMs such as Llama2.

    This recipe supports:
        - Activation checkpointing, controlled by the ``enable_activation_checkpointing``
          config flag.
        - Full bf16 training via setting the ``dtype`` flag to bf16.
        - Checkpointing of model weights, optimizer state and the recipe state
          (seed, epochs run, total epochs and max steps per epoch).
        - Resuming from checkpoints saved using the ``save_checkpoint`` functionality.
        - Metric logging through the configured ``metric_logger``.

    Assumptions:
        - Training is launched with the Tune CLI (recommended) which uses TorchRun under the
          hood. Setting up the env variables is handled by TorchRun.
        - Training happens on CUDA (CPU training is not supported)
        - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported.
        - Datasets are Map-style and data fits in memory (not streamed).

    The following configs can be used to run this recipe:
        >>> tune ls
        RECIPE                          CONFIG
        full_finetune_single_device     full_finetune_single_device

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)

        # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
        # enabled necessary features such as gradient scaling.
        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        # For non-CPU devices, check that the hardware supports bf16 if bf16 is requested.
        if (
            self._dtype == torch.bfloat16
            and self._device != torch.device("cpu")
            and not torch.cuda.is_bf16_supported()
        ):
            raise RuntimeError("Full bf16 training is not supported on this hardware.")

        # logging attributes; a falsy log_every_n_steps (None / 0) falls back to 1
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.log_every_n_steps or 1
        self._log_peak_memory_every_n_steps = 100

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = utils.set_seed(seed=cfg.seed)
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.total_training_steps = 0

    def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
        """
        Instantiate the checkpointer from ``cfg`` and load the checkpoint state.

        If ``resume_from_checkpoint`` is True, the recipe state (seed, epochs run,
        total epochs, max steps per epoch) is also restored from the checkpoint.

        Args:
            cfg (DictConfig): Checkpointer config.

        Returns:
            Dict[str, Any]: The loaded checkpoint dict.

        Raises:
            KeyError: If resuming and the checkpoint lacks the recipe-state keys.
        """
        self._checkpointer = config.instantiate(
            cfg,
            resume_from_checkpoint=self._resume_from_checkpoint,
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()

        if self._resume_from_checkpoint:
            self._update_recipe_state(checkpoint_dict)
        return checkpoint_dict

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Overwrite the recipe state with the values stored in ``ckpt_dict``.

        Warns if the configured seed, epochs or max_steps_per_epoch differ from the
        checkpoint; the checkpoint values always win.

        Raises:
            KeyError: If ``ckpt_dict`` is missing any of the recipe-state keys.
        """
        # If seed, total_epoch or max_steps_per_epoch don't match,
        # warn the user and overwrite
        try:
            if (
                self.seed != ckpt_dict[utils.SEED_KEY]
                or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
                or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
            ):
                warn(
                    message=(
                        "Configured value for seed, epochs or max_steps_per_epoch\n"
                        "does not match the value stored in checkpoint."
                    )
                )
            self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
            self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
            self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
            self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]
        except KeyError as e:
            # Build the new exception with its message and chain the original so the
            # missing key stays visible in the traceback.
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            ) from e

    def setup(self, cfg: DictConfig) -> None:
        """
        Set up every training component: metric logger, checkpoint, model, tokenizer,
        optimizer, loss, sampler/dataloader and the per-epoch step count.

        Recipe attributes are restored from the checkpoint when
        ``resume_from_checkpoint`` is set.
        """
        self._metric_logger = config.instantiate(cfg.metric_logger)

        ckpt_dict = self.load_checkpoint(cfg.checkpointer)

        # ``_setup_model`` handles initialization and loading the state dict. This method
        # should be called before ``_setup_optimizer`` since transforming the optimizer
        # state dict requires the model
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )

        self._tokenizer = config.instantiate(cfg.tokenizer)
        log.info("Tokenizer is initialized from file.")

        # _setup_optimizer should take in ckpt_dict only if training is resumed from
        # checkpoint. Transforming the opt state dict is handled by this method
        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            opt_state_dict=(
                ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
            ),
        )

        self._loss_fn = config.instantiate(cfg.loss)
        log.info("Loss is initialized.")

        # sampler and dataloader depend on the tokenizer and loss_fn and should be
        # setup after both of these are initialized
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
        )

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.
        #
        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader, the max_steps_per_epoch param set by the user and the
        # gradient_accumulation_steps param. This value is used for logging and tracking
        # training state. The computation should happen after the dataloader has been setup
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.total_training_steps = self.epochs_run * self._steps_per_epoch

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        """
        Instantiate the model on ``self._device`` in ``self._dtype``, optionally wrap
        ``TransformerDecoderLayer`` modules with activation checkpointing, and load
        ``model_state_dict`` into it.
        """
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(cfg_model)

        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )

        model.load_state_dict(model_state_dict)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model, dtype=self._dtype)
        log.info(f"Model is initialized with precision {self._dtype}.")
        log.info(
            utils.memory_stats_log(
                "Memory Stats after model init:", device=self._device
            )
        )
        return model

    def _setup_optimizer(
        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
    ) -> Optimizer:
        """
        Instantiate the optimizer over the model parameters and, if a non-empty
        ``opt_state_dict`` is given, load it.
        """
        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
        if opt_state_dict:
            optimizer.load_state_dict(opt_state_dict)

        log.info("Optimizer is initialized.")
        return optimizer

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports the
        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
        iterable datasets and streaming datasets are not supported.

        Returns:
            Tuple[DistributedSampler, DataLoader]: A single-replica sampler and the
            dataloader that draws from it with padded collation.
        """
        ds = config.instantiate(
            cfg_dataset,
            tokenizer=self._tokenizer,
        )
        # Single-replica DistributedSampler so ``set_epoch`` can reshuffle per epoch.
        # NOTE(review): the shuffle seed is fixed at 0 rather than ``self.seed``.
        sampler = DistributedSampler(
            ds,
            num_replicas=1,
            rank=0,
            shuffle=shuffle,
            seed=0,
        )
        dataloader = DataLoader(
            dataset=ds,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=partial(
                utils.padded_collate,
                padding_idx=self._tokenizer.pad_id,
                ignore_idx=self._loss_fn.ignore_index,  # TODO support loss without ignore_index
            ),
        )

        log.info("Dataset and Sampler are initialized.")

        return sampler, dataloader

    def save_checkpoint(self, epoch: int) -> None:
        """
        Save state dict to file. The recipe save_checkpoint method is responsible for
        correctly creating the checkpoint dict and passing to the checkpointer.

        Model weights are always saved. For every epoch except the last, the optimizer
        state and recipe state are saved too so training can be resumed.

        Args:
            epoch (int): Zero-based index of the epoch that just finished.
        """
        ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()}
        intermediate_checkpoint = epoch + 1 < self.total_epochs

        # if training is in-progress, checkpoint the optimizer and recipe state as well
        if intermediate_checkpoint:
            ckpt_dict.update(
                {
                    utils.OPT_KEY: self._optimizer.state_dict(),
                    utils.SEED_KEY: self.seed,
                    utils.EPOCHS_KEY: self.epochs_run,
                    utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                    utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                }
            )
        self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
            intermediate_checkpoint=intermediate_checkpoint,
        )

    def _should_update_weights(self, current_iteration: int) -> bool:
        """
        Determine whether the optimizer should step on the current iteration.

        Returns True exactly when ``current_iteration + 1`` is a multiple of
        ``gradient_accumulation_steps``, i.e. once a full accumulation window has been
        seen. No step is forced on the last iteration of an epoch.
        """
        return (current_iteration + 1) % self._gradient_accumulation_steps == 0

    def train(self) -> None:
        """
        The core training loop.

        Trains from ``self.epochs_run`` up to ``self.total_epochs``, optionally limiting
        each epoch to ``max_steps_per_epoch`` optimizer steps, and saves a checkpoint
        at the end of every epoch.
        """
        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Start every epoch with clean gradients. When len(dataloader) is not a
            # multiple of gradient_accumulation_steps, the trailing micro-batches never
            # reach an optimizer step; clearing here stops their gradients from leaking
            # into the next epoch's first update, so an uninterrupted run behaves like
            # one resumed from an epoch-boundary checkpoint.
            self._optimizer.zero_grad(set_to_none=True)

            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            for idx, batch in enumerate(pbar := tqdm(self._dataloader)):
                # Stop once max_steps_per_epoch optimizer steps have been taken.
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                input_ids, labels = batch
                input_ids = input_ids.to(self._device)
                labels = labels.to(self._device)

                logits = self._model(input_ids)
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
                # Swap dims 1 and 2 so the class dim comes second, as the loss expects
                # (assumes logits are (batch, seq, vocab) — confirm).
                logits = logits.transpose(1, 2)
                # Compute loss
                loss = self._loss_fn(logits, labels)

                # Note: the loss is logged before it is normalized by
                # gradient_accumulation_steps. Read it once to avoid a second
                # device-to-host sync.
                loss_value = loss.item()
                pbar.set_description(f"{curr_epoch + 1}|{idx + 1}|Loss: {loss_value}")

                # NOTE(review): this check runs on every micro-batch, so with gradient
                # accumulation the same step can be logged more than once.
                if self.total_training_steps % self._log_every_n_steps == 0:
                    self._metric_logger.log_dict(
                        {
                            "loss": loss_value,
                            "lr": self._optimizer.param_groups[0]["lr"],
                            "gpu_resources": torch.cuda.memory_allocated(),
                        },
                        step=self.total_training_steps,
                    )

                # Divide so the gradients summed over an accumulation window equal the
                # gradient of the mean loss over that window.
                loss = loss / self._gradient_accumulation_steps
                loss.backward()

                if self._should_update_weights(idx):
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

                    # Update the number of steps when the weights are updated
                    self.total_training_steps += 1

                # Log peak memory for iteration
                if self.total_training_steps % self._log_peak_memory_every_n_steps == 0:
                    log.info(
                        utils.memory_stats_log("Memory Stats:", device=self._device)
                    )

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

    def cleanup(self) -> None:
        """Close the metric logger."""
        self._metric_logger.close()
@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in ``full_finetune_single_device.yaml``
        - Overwritten by arguments from the command-line
    """
    recipe = FullFinetuneRecipeSingleDevice(cfg=cfg)
    recipe.setup(cfg=cfg)
    try:
        recipe.train()
    finally:
        # Close the metric logger even if training raises part-way through.
        # cleanup() is only reached after setup() succeeded, so the logger exists.
        recipe.cleanup()
# Script entry point: run the recipe and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(recipe_main())