-
Notifications
You must be signed in to change notification settings - Fork 1
/
nano-gpt.sh
93 lines (81 loc) · 2.61 KB
/
nano-gpt.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
################### Running on single TPU chip ############################
# Provision a v2-8 TPU VM named flash-nano-gpt-tpu in us-central1-f. Every
# command in this section must address the VM with this same name and zone.
gcloud compute tpus tpu-vm create flash-nano-gpt-tpu \
  --version=tpu-ubuntu2204-base \
  --accelerator-type=v2-8 \
  --zone=us-central1-f
# Interactive SSH session that forwards local port 6006 to port 6006 on the VM
# (arguments after `--` are handed to ssh) — presumably for TensorBoard, whose
# profile plugin is installed in the setup step.
gcloud compute tpus tpu-vm ssh flash-nano-gpt-tpu --zone=us-central1-f \
  -- -L 6006:127.0.0.1:6006
# One-time setup on the single-chip VM: clone the repository, install the
# Python dependencies and tmux.
#   - The clone is skipped when flash-nanoGPT already exists, so this step can
#     be re-run safely.
#   - 'jax[tpu]' is quoted so the remote shell never treats the brackets as a
#     glob pattern.
#   - apt-get -y: the install must not stop at apt's confirmation prompt, and
#     apt-get (unlike apt) has a CLI intended for scripts.
gcloud compute tpus tpu-vm ssh flash-nano-gpt-tpu \
  --zone=us-central1-f \
  --command="
[ -d flash-nanoGPT ] || git clone https://github.com/azzeddineCH/flash-nanoGPT.git
cd flash-nanoGPT
pip install tyro
pip install orbax-checkpoint
pip install tiktoken
pip install tensorflow
pip install tensorboard-plugin-profile
pip install flax
pip install optax
pip install git+https://github.com/deepmind/jmp
pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install wandb
pip install datasets
sudo apt-get install -y tmux
"
# Launch training on the single-chip VM inside a detached tmux session named
# "nano-gpt" (reattach on the VM with `tmux attach -t nano-gpt`).
#   - The zone must be the one the VM was created in (us-central1-f).
#   - WANDB_API_KEY is taken from the *local* environment and substituted into
#     the command before it is sent. The command aborts if the variable is
#     unset. The key may be visible in local/remote process listings.
#   - The tmux script is wrapped in single quotes, so it must not itself
#     contain single quotes.
gcloud compute tpus tpu-vm ssh flash-nano-gpt-tpu \
  --zone=us-central1-f \
  --worker=all \
  --command="
cd flash-nanoGPT
tmux new -d -s nano-gpt '
export GPT_CONFIG_PATH=yaml/nano-gpt.yaml
export WANDB_API_KEY=${WANDB_API_KEY:?set WANDB_API_KEY in your local shell first}
git pull
python3 train.py --grad_accum_steps 32
'
"
################### Running on TPU VM ############################
# Provision a preemptible v3-32 TPU VM named flash-nano-gpt-tpu-vm in
# europe-west4-a. It is addressed with --worker=all in the commands that
# follow. Every command in this section uses this same name and zone.
gcloud compute tpus tpu-vm create flash-nano-gpt-tpu-vm \
  --preemptible \
  --version=tpu-ubuntu2204-base \
  --accelerator-type=v3-32 \
  --zone=europe-west4-a
# To tear the VM down when finished:
# gcloud compute tpus tpu-vm delete flash-nano-gpt-tpu-vm --zone=europe-west4-a
# One-time setup on every worker of the v3-32 VM: clone the repository,
# install the Python dependencies and tmux.
#   - The clone is skipped when flash-nanoGPT already exists, so this step can
#     be re-run safely (e.g. after adding a worker-side fix).
#   - 'jax[tpu]' is quoted so the remote shell never treats the brackets as a
#     glob pattern.
#   - apt-get -y: the install must not stop at apt's confirmation prompt, and
#     apt-get (unlike apt) has a CLI intended for scripts.
gcloud compute tpus tpu-vm ssh flash-nano-gpt-tpu-vm \
  --zone=europe-west4-a \
  --worker=all \
  --command="
[ -d flash-nanoGPT ] || git clone https://github.com/azzeddineCH/flash-nanoGPT.git
cd flash-nanoGPT
pip install tyro
pip install orbax-checkpoint
pip install tiktoken
pip install tensorflow
pip install flax
pip install optax
pip install git+https://github.com/deepmind/jmp
pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install wandb
pip install datasets
sudo apt-get install -y tmux
"
# Smoke test: copy the local test.py into the home directory of every worker
# (empty remote path after the colon), then run it on all workers.
gcloud compute tpus tpu-vm scp test.py flash-nano-gpt-tpu-vm: \
  --zone=europe-west4-a \
  --worker=all
gcloud compute tpus tpu-vm ssh flash-nano-gpt-tpu-vm \
  --worker=all \
  --zone=europe-west4-a \
  --command="python3 test.py"
# Launch training on every worker of the v3-32 VM inside a detached tmux
# session named "nano-gpt" (reattach on a worker with `tmux attach -t nano-gpt`).
#   - Keep a space before each trailing backslash. Otherwise the line
#     continuation glues the next flag onto the VM name
#     (e.g. "flash-nano-gpt-tpu-vm--zone=...").
#   - WANDB_API_KEY is taken from the *local* environment and substituted into
#     the command before it is sent. The command aborts if the variable is
#     unset. The key may be visible in local/remote process listings.
#   - The tmux script is wrapped in single quotes, so it must not itself
#     contain single quotes.
gcloud compute tpus tpu-vm ssh flash-nano-gpt-tpu-vm \
  --zone=europe-west4-a \
  --worker=all \
  --command="
cd flash-nanoGPT
tmux new -d -s nano-gpt '
export GPT_CONFIG_PATH=yaml/train-nano-gpt.yaml
export WANDB_API_KEY=${WANDB_API_KEY:?set WANDB_API_KEY in your local shell first}
git pull
python3 train.py --no-save-checkpoint --no-amp
'
"