Skip to content

Commit

Permalink
[checkpoint] Open Source (#27)
Browse files Browse the repository at this point in the history
### In this PR, we open source our `vescale.checkpoint`, Yo. ~

`vescale.checkpoint` is a distributed LLM checkpointing system.  

`vescale.checkpoint` offers simple and straightforward APIs,
enabling users to load and save distributed model (`DModule`) and
optimizer (`DistributedOptimizer`) seamlessly,
abstracting away the complexities of underlying details such as process
rank and device mesh.

`vescale.checkpoint`supports load-time checkpoint resharding when
varying the degrees of data, tensor, or pipeline (TODO) parallelism
for both veScale distributed model (`DModule`) and optimizer
(`DistributedOptimizer`).

`vescale.checkpoint` incorporates [fast
checkpointing](https://arxiv.org/abs/2402.15627) and various I/O
optimization techinques,
enhancing I/O efficiency during large language model training.  

`vescale.checkpoint` will be a part of OmniStore project, a new open
source project coming soon.

### Credit to veScale Checkpoint Team
This endeavor would not have been possible without the contribution of
veScale Checkpoint team which includes but not limited to:
@shanesyy-1992 @MingjiHan99 @AHEADer @raywan-110 @michael4RD @lazychao
@leochen-ai

Also thanks to the great guidance and leadership of:
[@pengyanghua](https://github.com/pengyanghua)
[@eric-haibin-lin](https://github.com/eric-haibin-lin)
[@liwenchangbdbz](https://github.com/liwenchangbdbz)
[@Meteorix](https://github.com/Meteorix)

### Credit to veScale Team

We would like to sincerely acknowledge the assistance of and
collaboration with the veScale team which inlcudes but not limited to:
[@leonardo0lyj](https://github.com/leonardo0lyj) @JsBlueCat @MackZackA
@Vremold @jc-bytedance @lichen225

### Credit to PyTorch Distributed Checkpoint (DCP) Team
We would like to sincerely acknowledge the assistance of and
collaboration
with the [PyTorch Distributed Checkpoint (DCP)
team](https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint)
which includes but not limited to:
@wz337 @kumpera @fegin @LucasLLC
  • Loading branch information
MingjiHan99 authored Apr 10, 2024
1 parent 9d59f8d commit bbf2860
Show file tree
Hide file tree
Showing 43 changed files with 4,985 additions and 8 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@

# pre-commit config
./.pre-commit-config.yaml

# checkpoint
*.checkpoint_dir
*.distcp
.metadata
14 changes: 11 additions & 3 deletions python/example/nanogpt_4D_finetune/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ cd data/shakespeare && python3 prepare.py

Then, to finetune the Shakespeare dataset in an environment of multiple GPUs, run
```
torchrun --standalone --nproc_per_node={Number of GPUs} finetune_4D.py config/finetune_shakespeare.py --compile=False --dp_size={DP Size} --tp_size={TP Size}
torchrun --standalone --nproc_per_node={Number of GPUs} finetune_4D.py config/finetune_shakespeare.py --compile=False --dp_size={DP Size} --tp_size={TP Size} --save_checkpoint_path={path to save checkpoints}
```
where `DP Size` and `TP Size` denote the the degrees of Data and Tensor Parallelism that suit your environment.
where `DP Size` and `TP Size` denote the the degrees of Data and Tensor Parallelism that suit your environment, `save_checkpoint_path` is path to save checkpoints during the training.

If you want to resume training from a checkpoint, add `--load_checkpoint_path={path to load checkpoint}` in the command.

For example:
```
torchrun --standalone --nproc_per_node={Number of GPUs} finetune_4D.py config/finetune_shakespeare.py --compile=False --dp_size={DP Size} --tp_size={TP Size} --save_checkpoint_path=./nanogpt_checkpoint_dir --load_checkpoint_path=./nanogpt_checkpoint_dir/iter_5
```


To produce the single GPU result, run
```
Expand Down Expand Up @@ -53,4 +61,4 @@ For the bf16 runs, in `base_train.py`, instead of using `torch.amp.autocast`, we

2. veScale does not focus on fp16, as fp16 is ancient in industry.

3. Checkpointing is not supported.
3. Checkpointing is supported now.
19 changes: 15 additions & 4 deletions python/example/nanogpt_4D_finetune/finetune_4D.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.optim.base_optimizer import BasicOptimizer, GradOptimizerHookBase
from sharding_plan import nanoGPT_plan

import vescale

# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
Expand Down Expand Up @@ -93,7 +93,8 @@
dp_size = 4
tp_size = 1
DDP_grads_in_fp32 = True

save_checkpoint_path = "./nanogpt_checkpoint_dir"
load_checkpoint_path = ""
config = {}


Expand Down Expand Up @@ -208,8 +209,7 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_size"]:
model_args[k] = getattr(model.config, k)
elif init_from == "resume":
print("WARNING: checkpointing is not supported")
print(f"Resuming the training process from: {out_dir}")
print(f"Resuming the training process from: {load_checkpoint_path}")
# determine the vocab size we'll use for from-scratch training
if meta_vocab_size is None:
print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
Expand Down Expand Up @@ -326,6 +326,11 @@ def get_lr(it):
global config
wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# Load checkpoint
if load_checkpoint_path:
checkpoint_state = {"model": model, "optimizer": optimizer}
with mesh:
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
# training loop
X, Y = get_batch("train") # fetch the very first batch
t0 = time.time()
Expand Down Expand Up @@ -354,6 +359,12 @@ def get_lr(it):
"mfu": running_mfu * 100, # convert to percentage
}
)
if iter_num > 0:
# When iter_num == 0, the training does not start sotoptimizer state is empty,
# Don't save checkpoint
checkpoint_state = {"model": model, "optimizer": optimizer}
with mesh:
vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
if iter_num == 0 and eval_only:
break

Expand Down
2 changes: 2 additions & 0 deletions python/vescale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import torch
import torch.utils._pytree as pytree
import vescale.checkpoint as checkpoint
from vescale.dmodule.api import parallelize_module, is_dmodule, PlacementsInterface
from vescale.dtensor.api import normalize_placements, distribute_tensor, from_local, redistribute_dtensor, to_local
from vescale.dtensor.device_mesh import DeviceMesh, init_device_mesh
Expand All @@ -44,6 +45,7 @@
"set_plan_overriding_policy",
"get_plan_overriding_policy",
"auto_parallelize_module",
"checkpoint",
"DTensor",
"DeviceMesh",
"init_device_mesh",
Expand Down
48 changes: 48 additions & 0 deletions python/vescale/checkpoint/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# vescale.checkpoint

`vescale.checkpoint` is an automatic distributed checkpointing system for LLM training and inference.

## Why `vescale.checkpoint`?

1. Manually managing distributed checkpointing, such as writing model saving/loading/resharding scripts under complex distributed environments, is painful and error-prone.

2. `torch.save` and `torch.load` lacks the capability of managing checkpointing in distributed settings, let alone resharding checkpoints for different distributed settings.
Although existing systems extend `torch.save` for saving checkpoints on multiple GPUs or machines, the saved checkpoints are heavily coupled with a single distributed setting like the degrees of data, tensor and pipeline parallelism. Consequently, existing systems with `torch.load` fail to load checkpoints with varying degrees of parallelism, which is common in elastic training or switching between training and fine-tuning.

3. `PyTorch Distirbuted Checkpoint` indeed supports checkpoint resharding to some extent. Nonetheless, it currently only supports resharding for the simplest data parallelism, but not for the complex tensor nor pipeline parallelism, which are commonly used in 3D parallelism of LLM training. Furthermore, it does not support load-time resharding for Distributed Optimizer, nor provide decent performance optimizations.

## What is `vescale.checkpoint`?

`vescale.checkpoint` offers simple and straightforward APIs,
enabling users to load and save distributed model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`) seamlessly,
abstracting away the complexities of underlying details such as process rank and device mesh.

`vescale.checkpoint` supports load-time checkpoint resharding when varying the degrees of data, tensor, or pipeline (TODO) parallelism for both veScale model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`).

`vescale.checkpoint` incorporates [fast checkpointing](https://arxiv.org/abs/2402.15627) and various I/O optimization techinques, enhancing I/O efficiency during LLM training.

`vescale.checkpoint` will be a part of `OmniStore` project, a new open-source project coming soon.

`vescale.checkpoint` is built on top of `PyTorch Distributed Checkpoint` with significant differences as discussed above.

## How to use `vescale.checkpoint`?

- Saving checkpoint:

```
# prepare checkpoint state for the model and optimizer
checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
# save the checkpoint
vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state)
```
- Loading checkpoint (under different world size or 3D parallelism degrees):
```
# prepare checkpoint state for the model and optimizer
checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer }
# load the checkpoint
vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state)
```

- More examples can be found under `<repo>/test/checkpoint` and `<repo>/python/example`.

- Original examples can be found in PyTorch [Distributed Checkpoint](https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint)
44 changes: 44 additions & 0 deletions python/vescale/checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
################################################################################
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
################################################################################
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
################################################################################
# The "checkpoint" folder is ONLY USED for "open source" version veScale
# If you use veScale in ByteDance, please use OmniStore

from .api.vescale_checkpointer import VeScaleCheckpointer
from .api.meta_type import CheckpointState


def save(path: str, checkpoint_state: CheckpointState):
"""
Save a checkpoint to a given path
Args:
path: Defines the storage path for checkpoint.
checkpoint_state: A dictionary contains key-value pairs for model and optimizer.
- Model: Identified by 'model' key, value should be a model instance.
- Optimizer: Identified by 'optimizer' key, value should be an optimizer instance.
Example:
>>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer }
>>> vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state)
"""
VeScaleCheckpointer.save(path, checkpoint_state)


def load(path: str, checkpoint_state: CheckpointState):
"""
Load a checkpoint from a given path
Args:
path: Defines the storage path for checkpoint.
checkpoint_state: A dictionary contains key-value pairs for model and optimizer.
- Model: Identified by 'model' key, value should be a model instance.
- Optimizer: Identified by 'optimizer' key, value should be an optimizer instance.
Example:
>>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer }
>>> vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state)
"""
VeScaleCheckpointer.load(path, checkpoint_state)
47 changes: 47 additions & 0 deletions python/vescale/checkpoint/api/base_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from .meta_type import CheckpointState


class BaseCheckpointer:
"""
The Checkpointer class offers APIs that enable users to save and load state dictionarie.
It is designed for extension across various training frameworks.
"""

@classmethod
def save(cls, path: str, checkpoint_state: CheckpointState):
"""
A Method for saving checkpoint
Args:
path: Defines the storage path for checkpoint.
checkpoint_state: A dictionary contains key-value pairs for model, optimizer and dataloader(TODO).
- Model: Identified by 'model' key, value should be a model instance.
- Optimizer: Identified by 'optimizer' key, value should be an optimizer instance.
"""
raise NotImplementedError()

def load(cls, path: str, checkpoint_state: CheckpointState):
"""
A Method for loading checkpoint
Args:
path: Defines the storage path for checkpoint.
checkpoint_state: A dictionary contains key-value pairs for model, optimizer and dataloader(TODO).
- Model: Identified by 'model' key, value should be a model instance.
- Optimizer: Identified by 'optimizer' key, value should be an optimizer instance.
"""
raise NotImplementedError()
38 changes: 38 additions & 0 deletions python/vescale/checkpoint/api/meta_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
################################################################################
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
################################################################################
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
################################################################################
# meta_type.py saves all constants and data types commonly used in omnistore

from enum import Enum
from typing import Dict, Any, TypeVar
from typing_extensions import Protocol, runtime_checkable


STATE_DICT_TYPE = Dict[str, Any]

MODEL_STR = "model"
OPTIMIZER_STR = "optimizer"
SHM_PATH = "/dev/shm"


class SupportedStrategy(Enum):
Megatron = 0
FSDP = 1
VeScale = 2


@runtime_checkable
class Stateful(Protocol):
def state_dict(self) -> Dict[str, Any]: ...

def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


T = TypeVar("T", bound=Stateful)
CheckpointState = Dict[str, T]
Loading

0 comments on commit bbf2860

Please sign in to comment.