-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtrain.py
69 lines (54 loc) · 2.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os
from omegaconf import OmegaConf
from muscall.utils.logger import Logger
from muscall.utils.utils import (
load_conf,
merge_conf,
get_root_dir,
update_conf_with_cli_params,
)
from muscall.models.muscall import MusCALL
from muscall.trainers.muscall_trainer import MusCALLTrainer
from muscall.datasets.audiocaption import AudioCaptionDataset
def parse_args():
    """Parse the command-line options of the MusCALL training script.

    Recognised options:
        --experiment_id: id of a previously saved experiment (default ``None``).
        --config_path: path to the base config file (default
            ``<get_root_dir()>/configs/training.yaml``).
        --dataset: name of the dataset (default ``"audiocaption"``).
        --device_num: string value, default ``"0"``.

    Returns:
        argparse.Namespace: the values parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Train a MusCALL model")

    # One (flag, keyword-arguments) pair per option, registered in order below.
    option_table = (
        (
            "--experiment_id",
            {
                "type": str,
                "help": "experiment id under which checkpoint was saved",
                "default": None,
            },
        ),
        (
            "--config_path",
            {
                "type": str,
                "help": "path to base config file",
                "default": os.path.join(get_root_dir(), "configs", "training.yaml"),
            },
        ),
        (
            "--dataset",
            {"type": str, "help": "name of the dataset", "default": "audiocaption"},
        ),
        ("--device_num", {"type": str, "default": "0"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: obtain a run configuration (reloaded from a saved
    # experiment, or assembled from the base/dataset/model config files),
    # then build the trainer and start training.
    args = parse_args()

    if args.experiment_id is not None:
        # An experiment id was supplied: reuse the config stored for it.
        saved_config_path = "./save/experiments/{}/config.yaml".format(
            args.experiment_id
        )
        config = OmegaConf.load(saved_config_path)
    else:
        # 1. Load config (base + dataset + model)
        base_conf = load_conf(args.config_path)

        # Only the "audiocaption" dataset is handled by this script.
        if args.dataset != "audiocaption":
            raise ValueError("{} dataset not supported".format(args.dataset))

        # Dataset and model config paths are resolved against env.base_dir
        # taken from the base config.
        conf_root = base_conf.env.base_dir
        dataset_conf_path = os.path.join(conf_root, AudioCaptionDataset.config_path())
        model_conf_path = os.path.join(conf_root, MusCALL.config_path())

        config = merge_conf(args.config_path, dataset_conf_path, model_conf_path)
        update_conf_with_cli_params(args, config)

    logger = Logger(config)

    # Exported before the trainer is constructed.
    # NOTE(review): presumably selects which GPU(s) the trainer can see - confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_num

    trainer = MusCALLTrainer(config, logger)
    trainable_count = trainer.count_parameters()
    print("# of trainable parameters:", trainable_count)

    trainer.train()