forked from skypilot-org/skypilot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ray_train.yaml
30 lines (25 loc) · 887 Bytes
/
ray_train.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Cluster shape: every node requests 2 NVIDIA L4 GPUs.
resources:
  accelerators: L4:2

num_nodes: 2

# Local directory synced to the nodes; `run` invokes `python train.py`
# from it, so train.py is expected to live here.
workdir: .
# Node preparation script: installs the job's own Ray/PyTorch into a
# dedicated `ray` conda environment.
setup: |
  # Reuse the `ray` env if it already exists; otherwise create it.
  if ! conda activate ray; then
    conda create -n ray python=3.10 -y
    conda activate ray
  fi
  # Have to use ray < 2.8: later versions conflict with the existing Ray
  # cluster on the node. The pin enforces that requirement.
  pip install "ray[train]<2.8"
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Job script: rank 0 starts a Ray head on port 6379 and launches training;
# every other node joins that head as a worker.
run: |
  # NOTE(review): `chmod 777` drops the sticky bit /var/tmp normally carries
  # and makes everything beneath it world-writable; confirm it is still needed.
  sudo chmod 777 -R /var/tmp
  conda activate ray
  # SKYPILOT_NODE_IPS holds one IP per line; the first one is used as the head.
  head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  num_nodes=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  if [ "$SKYPILOT_NODE_RANK" == "0" ]; then
    # Skip `ray start` if a process mentioning both "ray" and 6379 is already
    # running (presumably so re-running the task reuses the existing head).
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --head --disable-usage-stats --port 6379
    sleep 5
    python train.py --num-workers $num_nodes
  else
    # Wait (up to ~2 min) for the head's port 6379 to accept connections
    # instead of assuming it is ready after a fixed delay. If it never opens,
    # `ray start` below still runs and fails loudly.
    for _ in $(seq 1 40); do
      timeout 2 bash -c "echo > /dev/tcp/$head_ip/6379" &> /dev/null && break
      sleep 1
    done
    ps aux | grep ray | grep 6379 &> /dev/null || ray start --address $head_ip:6379 --disable-usage-stats
  fi