-
Notifications
You must be signed in to change notification settings - Fork 1
/
yaml_easy_train.py
65 lines (59 loc) · 2.29 KB
/
yaml_easy_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import random
import re
import subprocess
import sys
from glob import glob

import yaml
# Ask the host helper script which engine build architecture to use; awk keeps
# only the first whitespace-separated field of each output line, and the
# decoded result is stripped of surrounding whitespace.
best_engine_arch = (
    subprocess.check_output(
        "sh /root/misc/get_native_properties.sh | awk '{print $1}'",
        shell=True,
    )
    .decode("utf-8")
    .strip()
)

# Baseline easy_train.py options; entries from the YAML config override these.
default_args = {}
default_args['seed'] = random.randint(9999, 9_999_999)
default_args['tui'] = False
default_args['workspace-path'] = '/root/easy-train-data'
default_args['build-engine-arch'] = best_engine_arch
default_args['network-testing-book'] = (
    'https://github.com/official-stockfish/books/raw/master/UHO_Lichess_4852_v1.epd.zip'
)
# The YAML config path is the first command-line argument; fail with a usage
# message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("usage: python3 yaml_easy_train.py CONFIG.yaml [print]")
yaml_config_file = sys.argv[1]
# YAML is UTF-8 by specification, so don't depend on the locale's encoding.
with open(yaml_config_file, "r", encoding="utf-8") as f:
    args = yaml.safe_load(f)
# safe_load returns None for an empty document: treat that as "no overrides"
# rather than crashing on {**default_args, **None} below.
if args is None:
    args = {}
if not isinstance(args, dict):
    sys.exit(
        f"{yaml_config_file}: expected a mapping of easy_train.py options "
        f"at the top level, got {type(args).__name__}"
    )
# Values from the config file override the defaults.
args = {**default_args, **args}
# If the config *filename* (not a directory on its path) contains
# "gpu<digits>", treat each digit as one GPU id and override any "gpus"
# setting, e.g. "run_gpu01.yaml" -> gpus "0,1".
gpus_from_filename = re.search(r"gpu(\d+)", os.path.basename(yaml_config_file))
if gpus_from_filename:
    gpus_str = ",".join(gpus_from_filename.group(1))
    args["gpus"] = gpus_str
def _expand_dataset_component(path):
    """Return the concrete file paths for one training-dataset entry.

    An entry containing "*" is expanded with glob; the matches are sorted so
    the generated command is identical from run to run (glob itself returns
    matches in arbitrary, filesystem-dependent order). A pattern that matches
    nothing is reported on stderr rather than silently dropped. An entry
    without "*" is passed through unchanged, whether or not it exists.
    """
    if "*" not in path:
        return [path]
    matches = sorted(glob(path))
    if not matches:
        print(
            f"warning: training-dataset pattern {path!r} matched no files",
            file=sys.stderr,
        )
    return matches


# prepare an easy_train.py command for training
# Every option element starts with a space: the "print" mode joins elements
# with " \\\n", which turns that space into a continuation-line indent.
command = ["python3 easy_train.py"]
for key, value in sorted(args.items()):
    if key != "training-dataset" or not isinstance(value, (list, dict)):
        command.append(f" --{key} {value}")
        continue

    if isinstance(value, list):
        # list of files with support for wildcard *
        dataset_paths = [
            path
            for component in value
            for path in _expand_dataset_component(component)
        ]
    else:
        # mapping of base directory -> file name(s), e.g.
        # /data/hse-v1/:
        #   - leela96-filt-v2.min.high-simple-eval-1k.min-v2.binpack
        dataset_paths = []
        for basepath, filenames in value.items():
            # A single file may be written as a plain string instead of a
            # list; iterating the string would yield one character per entry.
            if isinstance(filenames, str):
                filenames = [filenames]
            for filename in filenames:
                # os.path.join also copes with a base path lacking a trailing
                # slash, which plain concatenation silently gets wrong.
                dataset_paths.extend(
                    _expand_dataset_component(os.path.join(basepath, filename))
                )

    # easy_train.py receives one --training-dataset option per file.
    command.extend(f" --{key} {path}" for path in dataset_paths)
# An optional second argument "print" selects a multi-line layout with shell
# line continuations (a backslash and newline between elements) for logging
# and debugging; otherwise the whole command is emitted on one line.
pretty_print = sys.argv[2:3] == ["print"]
print((" \\\n" if pretty_print else " ").join(command))