-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_training_pipeline.py
146 lines (127 loc) · 4.75 KB
/
run_training_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import os
import random
import sys
import torch
from TrainingPipelines.AlignerPipeline import run as aligner
from TrainingPipelines.HiFiGAN_combined import run as HiFiGAN
from TrainingPipelines.StochasticToucanTTS_Nancy import run as nancystoch
from TrainingPipelines.ToucanTTS_IntegrationTest import run as tt_integration_test
from TrainingPipelines.ToucanTTS_MLS_English import run as mls
from TrainingPipelines.ToucanTTS_Massive_stage1 import run as stage1
from TrainingPipelines.ToucanTTS_Massive_stage2 import run as stage2
from TrainingPipelines.ToucanTTS_Massive_stage3 import run as stage3
from TrainingPipelines.ToucanTTS_MetaCheckpoint import run as meta
from TrainingPipelines.ToucanTTS_Nancy import run as nancy
from TrainingPipelines.finetuning_example_multilingual import (
run as fine_tuning_example_multilingual,
)
from TrainingPipelines.finetuning_example_simple import (
run as fine_tuning_example_simple,
)
from TrainingPipelines.finetuning_shan import run as fine_tuning_shan
# Dispatch table: maps the pipeline name accepted on the command line to the
# ``run`` entry point imported from the matching TrainingPipelines module.
# Insertion order is kept, so the argparse ``choices`` list follows it.
pipeline_dict = dict(
    # finetuning examples
    finetuning_example_simple=fine_tuning_example_simple,
    finetuning_example_multilingual=fine_tuning_example_multilingual,
    # integration tests
    tt_it=tt_integration_test,
    # regular ToucanTTS pipelines
    nancy=nancy,
    mls=mls,
    nancystoch=nancystoch,
    meta=meta,
    stage1=stage1,
    stage2=stage2,
    stage3=stage3,
    # training the aligner from scratch (not recommended, best to use the provided checkpoint)
    aligner=aligner,
    # vocoder training (not recommended, best to use the provided checkpoint)
    hifigan=HiFiGAN,
    # fine-tuning pipeline from TrainingPipelines.finetuning_shan
    finetuning_shan=fine_tuning_shan,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Training with the IMS Toucan Speech Synthesis Toolkit"
)
parser.add_argument(
"pipeline", choices=list(pipeline_dict.keys()), help="Select pipeline to train."
)
parser.add_argument(
"--gpu_id",
type=str,
help="Which GPU(s) to run on. If not specified runs on CPU, but other than for integration tests that doesn't make much sense.",
default="cpu",
)
parser.add_argument(
"--resume_checkpoint",
type=str,
help="Path to checkpoint to resume from.",
default=None,
)
parser.add_argument(
"--resume",
action="store_true",
help="Automatically load the highest checkpoint and continue from there.",
default=False,
)
parser.add_argument(
"--finetune",
action="store_true",
help="Whether to fine-tune from the specified checkpoint.",
default=False,
)
parser.add_argument(
"--model_save_dir",
type=str,
help="Directory where the checkpoints should be saved to.",
default=None,
)
parser.add_argument(
"--wandb",
action="store_true",
help="Whether to use weights and biases to track training runs. Requires you to run wandb login and place your auth key before.",
default=False,
)
parser.add_argument(
"--wandb_resume_id",
type=str,
help="ID of a stopped wandb run to continue tracking",
default=None,
)
args = parser.parse_args()
if args.finetune and args.resume_checkpoint is None and not args.resume:
print("Need to provide path to checkpoint to fine-tune from!")
sys.exit()
if args.gpu_id == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = torch.device("cpu")
print(
f"No GPU specified, using CPU. Training will likely not work without GPU."
)
gpu_count = 1 # for technical reasons this is set to one, indicating it's not gpu_count training, even though there is no GPU in this case
else:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
device = torch.device("cuda")
print(
f"Making GPU {os.environ['CUDA_VISIBLE_DEVICES']} the only visible device(s)."
)
gpu_count = len(args.gpu_id.replace(",", " ").split())
# example call for gpu_count training:
# torchrun --standalone --nproc_per_node=4 --nnodes=1 run_training_pipeline.py nancy --gpu_id "1,2,3"
torch.manual_seed(9665)
random.seed(9665)
torch.random.manual_seed(9665)
torch.multiprocessing.set_sharing_strategy("file_system")
pipeline_dict[args.pipeline](
gpu_id=args.gpu_id,
resume_checkpoint=args.resume_checkpoint,
resume=args.resume,
finetune=args.finetune,
model_dir=args.model_save_dir,
use_wandb=args.wandb,
wandb_resume_id=args.wandb_resume_id,
gpu_count=gpu_count,
)