-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Example] add an example of running open llama model in 4D using veSc…
…ale (#22) This PR adds an 4D parallelism example of using veScale to run a [open llama model](https://huggingface.co/openlm-research/open_llama_7b) that is directly imported from HuggingFace without any model code modifications.
- Loading branch information
1 parent
364c3b2
commit 028806f
Showing
12 changed files
with
869 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# veScale Open Llama Example | ||
## Overview | ||
In this directory, we provides an 4D parallelism example of using veScale to run | ||
a [open llama model](https://huggingface.co/openlm-research/open_llama_7b) that is directly imported | ||
from HuggingFace without any model code modifications. | ||
|
||
|
||
## Run | ||
### Single Machine 8 cards | ||
``` | ||
torchrun --standalone --nnodes=1 --nproc-per-node=8 ./run_open_llama_w_vescale.py --dp=4 --tp=2 --warmup=10 --iter=40 | ||
``` | ||
This will start a 8-cards MFU benchmark for open Llama with veScale with dp=4 and tp=2. | ||
|
||
### Distributed Environment (4 Machine 32 cards example) | ||
``` | ||
torchrun --nnodes=4 --nproc-per-node=8 --node_rank=$node_rank --master_addr=$master_addr --master_port=$master_port ./run_open_llama_w_vescale.py --dp=16 --tp=2 --warmup=10 --iter=40 | ||
``` | ||
This will start a 32 cards MFU benchmark for open Llama with veScale with dp=16 and tp=2. | ||
|
||
### Options | ||
1. `--total_bsz`: the total number of batch size for one iteration. The default is 16. | ||
2. `--dp`: the amount of data parallelism (DDP). This arg has no default value. | ||
3. `--tp`: the amount of tensor parallelism. This arg has no default value. | ||
4. `--warmup`: the number of warmup iteration performed. The default is 5. | ||
5. `--iter`: the number of iteration used for calculating the MFU. The default is 10. | ||
6. `--no-ckpt"`: This arg turn off loading check points from Huggingface. | ||
|
||
## Caveats | ||
1. The scripts are purely for demonstration propose and mfu calculation. You need to write your own training script | ||
it in order to fine-tune open llama with your data. | ||
2. This is a known issue with transformer version greater than 4.37.2. We will be fixing it later. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
{ | ||
"architectures": [ | ||
"LlamaForCausalLM" | ||
], | ||
"bos_token_id": 1, | ||
"eos_token_id": 2, | ||
"hidden_act": "silu", | ||
"hidden_size": 4096, | ||
"initializer_range": 0.02, | ||
"intermediate_size": 11008, | ||
"max_position_embeddings": 2048, | ||
"model_type": "llama", | ||
"num_attention_heads": 32, | ||
"num_hidden_layers": 32, | ||
"pad_token_id": 0, | ||
"rms_norm_eps": 1e-06, | ||
"tie_word_embeddings": false, | ||
"torch_dtype": "float16", | ||
"transformers_version": "4.30.0.dev0", | ||
"use_cache": true, | ||
"vocab_size": 32000 | ||
} |
23 changes: 23 additions & 0 deletions
23
python/example/open_llama_4D_benchmark/download_open_llama_ckpt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
################################################################################ | ||
# | ||
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
################################################################################ | ||
|
||
|
||
from transformers import AutoModelForCausalLM | ||
|
||
model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b") | ||
|
||
print(model) |
29 changes: 29 additions & 0 deletions
29
python/example/open_llama_4D_benchmark/llama_mfu_calculator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
################################################################################ | ||
# | ||
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
################################################################################ | ||
|
||
# reference: https://www.adamcasson.com/posts/transformer-flops | ||
# reference: https://arxiv.org/pdf/2001.08361.pdf | ||
|
||
|
||
def estimate_llama(config, bsz, sqence_length): | ||
embed = 4 * bsz * sqence_length * config.hidden_size | ||
ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length | ||
attn_qkv = 2 * bsz * sqence_length * config.hidden_size * 3 * config.hidden_size | ||
attn_mask = 2 * sqence_length * config.hidden_size | ||
attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length | ||
attn = attn_qkv + attn_mask + attn_proj | ||
return embed + (ff + attn) * config.num_hidden_layers |
123 changes: 123 additions & 0 deletions
123
python/example/open_llama_4D_benchmark/run_open_llama_w_vescale.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
################################################################################ | ||
# | ||
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
################################################################################ | ||
import os | ||
import torch | ||
import argparse | ||
|
||
os.environ["VESCALE_DISABLE_RUN_CHECK"] = "1" | ||
|
||
from vescale.dtensor.device_mesh import init_device_mesh | ||
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP | ||
from vescale.optim.distributed_optimizer import DistributedOptimizer | ||
from vescale.dmodule.api import parallelize_module | ||
from sharding_plan import sharding_plan | ||
|
||
from transformers import AutoModelForCausalLM, AutoConfig, LlamaModel | ||
|
||
from llama_mfu_calculator import estimate_llama | ||
|
||
local_rank = int(os.environ["LOCAL_RANK"]) | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--total_bsz", type=int, default=16) | ||
parser.add_argument("--dp", type=int) | ||
parser.add_argument("--tp", type=int) | ||
parser.add_argument("--warmup", type=int, default=5) | ||
parser.add_argument("--iter", type=int, default=10) | ||
parser.add_argument("--no-ckpt", action="store_true") | ||
|
||
args = parser.parse_args() | ||
|
||
assert args.total_bsz % args.dp == 0, f"total batch size {args.total_bsz} is not divisiable by dp size {args.dp}" | ||
bsz = args.total_bsz // args.dp | ||
s = 2048 | ||
|
||
# init model | ||
if args.no_ckpt: | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
config = AutoConfig.from_pretrained(os.path.join(dir_path, "config.json")) | ||
model = LlamaModel(config) | ||
else: | ||
model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b") | ||
model = model.model | ||
config = model.config | ||
assert s <= config.max_position_embeddings | ||
|
||
# -------- training config -------- | ||
device_mesh = init_device_mesh( | ||
"cuda", | ||
( | ||
args.dp, | ||
args.tp, | ||
), | ||
mesh_dim_names=("DP", "TP"), | ||
) | ||
|
||
input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda() | ||
|
||
model = model.cuda().bfloat16() | ||
vescale_model = parallelize_module(model, device_mesh["TP"], sharding_plan) | ||
|
||
ddp_model = DDP( | ||
vescale_model, | ||
data_pg_or_device_mesh=device_mesh["DP"], | ||
use_distributed_optimizer=True, | ||
) | ||
orig_optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01) | ||
|
||
ve_optimizer = DistributedOptimizer( | ||
orig_optimizer, | ||
overlap_param_gather=True, | ||
models=[ddp_model], | ||
) | ||
|
||
start = torch.cuda.Event(enable_timing=True) | ||
end = torch.cuda.Event(enable_timing=True) | ||
|
||
# -------- warm up -------- | ||
for _ in range(args.warmup): | ||
ve_optimizer.zero_grad() | ||
vescale_output = ddp_model(input).last_hidden_state | ||
vescale_loss = vescale_output.mean() | ||
vescale_loss.backward() | ||
ve_optimizer.step() | ||
|
||
# -------- training loop -------- | ||
start.record() | ||
for _ in range(args.iter): | ||
ve_optimizer.zero_grad() | ||
vescale_output = ddp_model(input).last_hidden_state | ||
vescale_loss = vescale_output.mean() | ||
vescale_loss.backward() | ||
ve_optimizer.step() | ||
end.record() | ||
torch.cuda.synchronize() | ||
exec_t = start.elapsed_time(end) / 1000 / args.iter | ||
|
||
if local_rank == 0: | ||
flops_dict = { | ||
"A100": 312, | ||
"H100": 1000, | ||
} | ||
d_name = torch.cuda.get_device_name() | ||
total_flops = flops_dict["A100"] * (10**12) * device_mesh.ndevice | ||
for k, v in flops_dict.items(): | ||
if k in d_name: | ||
total_flops = v * (10**12) * device_mesh.ndevice | ||
break | ||
print(f"1 iter time: {exec_t}") | ||
# fwd + bwd =3 | ||
print("mfu:", estimate_llama(config, bsz, s) * 3 * args.dp * 100 / exec_t / total_flops) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
################################################################################ | ||
# | ||
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
################################################################################ | ||
|
||
from vescale.dtensor.placement_types import Replicate, Shard | ||
|
||
# forward resharding plan for a single open llama decoder | ||
_decoder_fwd_resharding_plan = { | ||
"input": {"hidden_states": [Shard(1)], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, | ||
# atten | ||
"self_attn.input": {"hidden_states": [Replicate()], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, | ||
"self_attn.o_proj.output": [[Shard(1)]], | ||
"self_attn.output": [[Shard(1)], None, None], | ||
# feedforward(mlp) | ||
"mlp.input": [[Replicate()]], | ||
"mlp.output": [[Shard(1)]], | ||
"output": [[Shard(1)], None], | ||
} | ||
|
||
# parameter sharding plan for a single open llama decoder | ||
_decoder_param_sharding_plan = { | ||
# atten weight, no bias | ||
"self_attn.q_proj.weight": [Shard(0)], | ||
"self_attn.k_proj.weight": [Shard(0)], | ||
"self_attn.v_proj.weight": [Shard(0)], | ||
"self_attn.o_proj.weight": [Shard(1)], | ||
# feedforward(mlp) | ||
"mlp.up_proj.weight": [Shard(0)], | ||
"mlp.gate_proj.weight": [Shard(0)], | ||
"mlp.down_proj.weight": [Shard(1)], | ||
} | ||
|
||
# forward resharding plan for the whole open llama model | ||
model_fwd_resharding_plan = { | ||
".input": [[Replicate()]], | ||
"embed_tokens.output": [[Shard(1)]], | ||
"norm.input": [[Shard(1)]], | ||
".output": { | ||
"last_hidden_state": [Replicate()], | ||
}, | ||
**{rf"layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()}, | ||
} | ||
|
||
# model parameter sharding plan for the whole open llama model | ||
model_param_sharding_plan = { | ||
"embed_tokens.weight": [Shard(1)], | ||
**{rf"layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()}, | ||
} | ||
|
||
sharding_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
{ | ||
"architectures": [ | ||
"LlamaForCausalLM" | ||
], | ||
"bos_token_id": 1, | ||
"eos_token_id": 2, | ||
"hidden_act": "silu", | ||
"hidden_size": 4096, | ||
"initializer_range": 0.02, | ||
"intermediate_size": 11008, | ||
"max_position_embeddings": 2048, | ||
"model_type": "llama", | ||
"num_attention_heads": 32, | ||
"num_hidden_layers": 32, | ||
"pad_token_id": 0, | ||
"rms_norm_eps": 1e-06, | ||
"tie_word_embeddings": false, | ||
"torch_dtype": "float16", | ||
"transformers_version": "4.30.0.dev0", | ||
"use_cache": true, | ||
"vocab_size": 32000 | ||
} |
Oops, something went wrong.