-
Notifications
You must be signed in to change notification settings - Fork 23
/
train_acoustic.py
63 lines (57 loc) · 3 KB
/
train_acoustic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from models.model import TTSmodel
from trainer import fastspeech_trainer,tacotron_trainer
from dataloaders import tacotron_dataloader,fastspeech_dataloader
from utils.user_config import UserConfig
import tensorflow as tf
import logging
import os
# Expose only CUDA device '0' to this process. This assignment happens before
# the GPU list is queried from TensorFlow a few lines below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Root logger: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Report how many GPUs TensorFlow sees, then switch on memory growth for each.
gpus = tf.config.experimental.list_physical_devices('GPU')
logging.info('valid gpus:%d', len(gpus))
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
class Trainer:
    """Assemble the acoustic model, its data loader and its trainer, then train.

    The data loader / trainer pair is chosen from ``TTSmodel.acoustic``:
    ``'Tacotron2'`` selects the Tacotron classes, anything else selects the
    FastSpeech classes.
    """

    def __init__(self, config):
        """Build the model, data loader, trainer and optimizer from ``config``.

        Args:
            config: Configuration object supporting ``config[key]`` lookups.
                This class reads the keys ``'learning_rate'``, ``'beta_1'``,
                ``'beta_2'``, ``'epsilon'``, ``'num_epochs'``,
                ``'save_interval_steps'`` and ``'outdir'``; the same object is
                also handed to the model, data loader and trainer.
        """
        self.config = config
        self.am_model = TTSmodel(config=config)
        # NOTE(review): the meaning of the ``True`` flag is defined by
        # TTSmodel.load_model (presumably "training mode") -- confirm there.
        self.am_model.load_model(True)

        if self.am_model.acoustic == 'Tacotron2':
            self.dg = tacotron_dataloader.TacotronDataLoader(self.config)
            self.trainer = tacotron_trainer.Tacotron2Trainer(self.config)
        else:
            self.dg = fastspeech_dataloader.FastSpeechDataLoader(self.config)
            self.trainer = fastspeech_trainer.FastSpeechTrainer(self.config)

        # Use the documented ``learning_rate`` keyword: the legacy ``lr`` alias
        # is deprecated and is not accepted by newer Keras optimizers.
        self.opt = tf.keras.optimizers.Adamax(
            learning_rate=self.config['learning_rate'],
            beta_1=self.config['beta_1'],
            beta_2=self.config['beta_2'],
            epsilon=self.config['epsilon'],
        )

        # Total step budget = steps per epoch * number of epochs.
        all_train_step = self.dg.get_per_epoch_steps() * self.config['num_epochs']
        self.trainer.set_total_train_steps(all_train_step)
        self.trainer.compile(self.am_model.acoustic_model, self.opt)
        # The data loader's batch size follows the trainer's global batch size,
        # which is read only after compile() has been called.
        self.dg.batch = self.trainer.global_batch_size

    def _build_dataset(self, generator_flag):
        """Wrap ``self.dg.generator`` in a ``tf.data.Dataset``.

        Args:
            generator_flag: Boolean forwarded as the single argument of the
                data loader's generator (``True`` for the training dataset,
                ``False`` for the evaluation dataset).
        """
        return tf.data.Dataset.from_generator(
            self.dg.generator,
            self.dg.return_data_types(),
            self.dg.return_data_shape(),
            args=(generator_flag,),
        )

    def run(self):
        """Train until the trainer reports that it has finished.

        Each iteration calls ``trainer.fit``. When the trainer is finished a
        checkpoint is saved and the loop ends; otherwise, whenever the step
        counter is a multiple of ``config['save_interval_steps']``, the data
        loader state is saved to ``config['outdir']``.
        """
        train_datasets = self._build_dataset(True)
        eval_datasets = self._build_dataset(False)
        self.trainer.set_datasets(train_datasets, eval_datasets)

        while True:
            self.trainer.fit(epoch=self.dg.epochs)
            if self.trainer._finished():
                self.trainer.save_checkpoint()
                logging.info('Finish training!')
                break
            if self.trainer.steps % self.config['save_interval_steps'] == 0:
                self.dg.save_state(self.config['outdir'])
if __name__ == '__main__':
    import argparse

    # Command-line options: (flag, default path, help text) for each config file.
    option_specs = (
        ('--data_config', './configs/common.yml', 'the am data config path'),
        ('--model_config', './configs/fastspeech.yml', 'the am model config path'),
    )
    cli = argparse.ArgumentParser()
    for flag, default_path, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_path, help=help_text)
    args = cli.parse_args()

    # Combine the data and model configs, then build the trainer and start training.
    config = UserConfig(args.data_config, args.model_config)
    Trainer(config).run()