# train_example.py
# Forked from lucidrains/e2-tts-pytorch (original: 44 lines / 882 bytes).
# Minimal example of training E2-TTS on the GLOBE dataset.
import torch
from e2_tts_pytorch import E2TTS, DurationPredictor
from torch.optim import Adam
from datasets import load_dataset
from e2_tts_pytorch.trainer import (
HFDataset,
E2Trainer
)
# Hyperparameters, hoisted to one place so they are easy to tweak.
LEARNING_RATE = 7.5e-5
NUM_WARMUP_STEPS = 20_000
EPOCHS = 10
BATCH_SIZE = 32
GRAD_ACCUMULATION_STEPS = 1
SAVE_STEP = 1000  # checkpoint every N optimizer steps


def main() -> None:
    """Train an E2-TTS model (plus its duration predictor) on the GLOBE dataset.

    Builds the model pair, wraps the HuggingFace GLOBE train split in the
    trainer's dataset adapter, and runs the training loop, checkpointing to
    ``e2tts.pt`` and logging to ``e2tts.txt``.
    """
    # Small transformer that predicts token durations for the main model.
    duration_predictor = DurationPredictor(
        transformer=dict(
            dim=512,
            depth=6,
        )
    )

    # Main E2-TTS model; 'concat' skip connections per the reference config.
    e2tts = E2TTS(
        duration_predictor=duration_predictor,
        transformer=dict(
            dim=512,
            depth=12,
            skip_connect_type='concat',
        ),
    )

    # NOTE(review): downloads the full GLOBE dataset on first run.
    train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"])

    optimizer = Adam(e2tts.parameters(), lr=LEARNING_RATE)

    trainer = E2Trainer(
        e2tts,
        optimizer,
        num_warmup_steps=NUM_WARMUP_STEPS,
        checkpoint_path='e2tts.pt',
        log_file='e2tts.txt',
    )

    trainer.train(
        train_dataset,
        EPOCHS,
        BATCH_SIZE,
        GRAD_ACCUMULATION_STEPS,
        save_step=SAVE_STEP,
    )


# Guard so importing this module does not kick off training.
if __name__ == '__main__':
    main()