WIP: Support for Wavenet vocoder #21

Open

wants to merge 12 commits into base: master
3 changes: 3 additions & 0 deletions .gitignore
@@ -3,6 +3,9 @@ notebooks
foobar*
run.sh
README.rst
legacy
notebooks
run.sh
pretrained_models
deepvoice3_pytorch/version.py
checkpoints*
4 changes: 2 additions & 2 deletions README.md
@@ -197,15 +197,15 @@ python preprocess.py nikl_s ${your_nikl_root_path} data/nikl_s --preset=presets/
python train.py --data-root=./data/nikl_s --checkpoint-dir checkpoint_nikl_s --preset=presets/deepvoice3_nikls.json
```

### 4. Monitor with Tensorboard
### 3. Monitor with Tensorboard

Logs are dumped in the `./log` directory by default. You can monitor logs with TensorBoard:

```
tensorboard --logdir=log
```

### 5. Synthesize from a checkpoint
### 4. Synthesize from a checkpoint

Given a list of texts, `synthesis.py` synthesizes audio signals from a trained model. Usage is:

6 changes: 1 addition & 5 deletions dump_hparams_to_json.py
@@ -12,13 +12,9 @@
import sys
import os
from os.path import dirname, join, basename, splitext
import json

import audio

# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams
import json

if __name__ == "__main__":
args = docopt(__doc__)
177 changes: 177 additions & 0 deletions generate_aligned_predictions.py
@@ -0,0 +1,177 @@
# coding: utf-8
"""
Generate ground truth-aligned predictions

usage: generate_aligned_predictions.py [options] <checkpoint> <in_dir> <out_dir>

options:
--hparams=<params> Hyper parameters [default: ].
--preset=<json> Path of preset parameters (json).
--overwrite Overwrite audio and mel outputs.
-h, --help Show help message.
"""
from docopt import docopt
import os
from tqdm import tqdm
import importlib
from os.path import join
from warnings import warn
import sys

import numpy as np
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

# The deepvoice3 model
from deepvoice3_pytorch import frontend
from hparams import hparams

use_cuda = torch.cuda.is_available()
_frontend = None # to be set later


def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
p=0, speaker_id=None,
fast=False):
"""Generate ground truth-aligned prediction

The output of the network and corresponding audio are saved after time
resolution adjustment.
"""
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step

if use_cuda:
model = model.cuda()
model.eval()
if fast:
model.make_generation_fast_()

mel_org = np.load(join(in_dir, mel_filename))
# zero-pad
b_pad = r # imitates initial state
e_pad = r - len(mel_org) % r if len(mel_org) % r > 0 else 0
mel = np.pad(mel_org, [(b_pad, e_pad), (0, 0)],
mode="constant", constant_values=0)

mel = Variable(torch.from_numpy(mel)).unsqueeze(0).contiguous()

# Downsample mel spectrogram
if downsample_step > 1:
mel = mel[:, 0::downsample_step, :].contiguous()

decoder_target_len = mel.shape[1] // r
s, e = 1, decoder_target_len + 1
frame_positions = torch.arange(s, e).long().unsqueeze(0)
frame_positions = Variable(frame_positions)

sequence = np.array(_frontend.text_to_sequence(text, p=p))
sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
text_positions = Variable(text_positions)
speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
if use_cuda:
sequence = sequence.cuda()
text_positions = text_positions.cuda()
speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
mel = mel.cuda()
frame_positions = frame_positions.cuda()

# **Teacher forcing** decoding
mel_outputs, _, _, _ = model(
sequence, mel, text_positions=text_positions,
frame_positions=frame_positions, speaker_ids=speaker_ids)
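# With teacher forcing, the ground-truth mel is fed as decoder input at every
# step, so the predicted frames stay time-aligned with the original recording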

mel_output = mel_outputs[0].data.cpu().numpy()
# **Time resolution adjustment**
mel_output = mel_output[:-(b_pad + e_pad)]

wav = np.load(join(in_dir, audio_filename))
assert len(wav) % hparams.hop_size == 0

# Coarse upsampling just for convenience,
# so that we can upsample conditioning features by hop_size in WaveNet
if downsample_step > 0:
mel_output = np.repeat(mel_output, downsample_step, axis=0)
# after downsampling -> upsampling, the length should be equal to or larger than
# the original mel length
assert mel_output.shape[0] >= mel_org.shape[0]

# Make sure we have correct lengths
assert mel_output.shape[0] * hparams.hop_size == len(wav)
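# i.e. each (coarsely re-upsampled) mel frame covers exactly hop_size waveform
# samples, which lets the WaveNet vocoder upsample its conditioning by hop_size alone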

timesteps = len(wav)

# save
np.save(join(out_dir, audio_filename), wav, allow_pickle=False)
np.save(join(out_dir, mel_filename), mel_output.astype(np.float32),
allow_pickle=False)

if speaker_id is None:
return (audio_filename, mel_filename, timesteps, text)
else:
return (audio_filename, mel_filename, timesteps, text, speaker_id)


def write_metadata(metadata, out_dir):
with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
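# each line: <audio_filename>|<mel_filename>|<timesteps>|<text>[|<speaker_id>]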
for m in metadata:
f.write('|'.join([str(x) for x in m]) + '\n')
frames = sum([m[2] for m in metadata])
sr = hparams.sample_rate
hours = frames / sr / 3600
print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours))
print('Max input length: %d' % max(len(m[3]) for m in metadata))
print('Max output length: %d' % max(m[2] for m in metadata))


if __name__ == "__main__":
args = docopt(__doc__)
checkpoint_path = args["<checkpoint>"]
in_dir = args["<in_dir>"]
out_dir = args["<out_dir>"]
preset = args["--preset"]

# Load preset if specified
if preset is not None:
with open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args["--hparams"])
assert hparams.name == "deepvoice3"

_frontend = getattr(frontend, hparams.frontend)
import train
train._frontend = _frontend
from train import build_model

model = build_model()

# Load checkpoint
print("Load checkpoint from {}".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["state_dict"])

os.makedirs(out_dir, exist_ok=True)
results = []
with open(os.path.join(in_dir, "train.txt")) as f:
lines = f.readlines()

for idx in tqdm(range(len(lines))):
l = lines[idx]
l = l[:-1].split("|")
audio_filename, mel_filename, _, text = l[:4]
speaker_id = int(l[4]) if len(l) > 4 else None
if text == "N/A":
raise RuntimeError("No transcription available")

result = preprocess(model, in_dir, out_dir, text, audio_filename,
mel_filename, p=0,
speaker_id=speaker_id, fast=True)
results.append(result)

write_metadata(results, out_dir)

sys.exit(0)
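A minimal invocation sketch for the script above, following its docopt usage string (the checkpoint filename and output directory are hypothetical; `./data/ljspeech` assumes features previously produced by `preprocess.py`):

```
python generate_aligned_predictions.py \
    --preset=presets/deepvoice3_ljspeech_wavenet.json \
    checkpoints/checkpoint_step000720000.pth ./data/ljspeech ./data/ljspeech_aligned
```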
22 changes: 11 additions & 11 deletions hparams.py
@@ -43,33 +43,33 @@
# whether to rescale waveform or not.
# Let x be an input waveform; the rescaled waveform y is given by:
# y = x / np.abs(x).max() * rescaling_max
rescaling=False,
rescaling=True,
rescaling_max=0.999,
# mel-spectrogram is normalized to [0, 1] for each utterance and clipping may
# happen depending on min_level_db and ref_level_db, causing clipping noise.
# If False, an assertion is added to ensure no clipping happens.
allow_clipping_in_normalization=True,

# Model:
downsample_step=4, # must be 4 when builder="nyanko"
outputs_per_step=1, # must be 1 when builder="nyanko"
downsample_step=1, # must be 4 when builder="nyanko"
outputs_per_step=4, # must be 1 when builder="nyanko"
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
padding_idx=0,
# Maximum number of input text length
# try setting larger value if you want to give very long text input
max_positions=512,
dropout=1 - 0.95,
kernel_size=3,
text_embed_dim=128,
encoder_channels=256,
decoder_channels=256,
max_positions=2048,
dropout=1 - 0.90,
kernel_size=5,
text_embed_dim=256,
encoder_channels=512,
decoder_channels=512,
# Note: large converter channels require significant computational cost
converter_channels=256,
query_position_rate=1.0,
# can be computed by `compute_timestamp_ratio.py`.
key_position_rate=1.385, # 2.37 for jsut
key_projection=False,
key_projection=True,
value_projection=False,
use_memory_mask=True,
trainable_positional_encodings=False,
@@ -99,7 +99,7 @@
adam_beta1=0.5,
adam_beta2=0.9,
adam_eps=1e-6,
initial_learning_rate=5e-4, # 0.001,
initial_learning_rate=1e-3, # 0.001,
lr_schedule="noam_learning_rate_decay",
lr_schedule_kwargs={},
nepochs=2000,
1 change: 0 additions & 1 deletion preprocess.py
@@ -52,7 +52,6 @@ def write_metadata(metadata, out_dir):
# Override hyper parameters
hparams.parse(args["--hparams"])
assert hparams.name == "deepvoice3"
print(hparams_debug_string())

assert name in ["jsut", "ljspeech", "vctk", "nikl_m", "nikl_s", "json_meta"]
mod = importlib.import_module(name)
65 changes: 65 additions & 0 deletions presets/deepvoice3_ljspeech_wavenet.json
@@ -0,0 +1,65 @@
{
"name": "deepvoice3",
"frontend": "en",
"replace_pronunciation_prob": 0.5,
"builder": "deepvoice3",
"n_speakers": 1,
"speaker_embed_dim": 16,
"num_mels": 80,
"fmin": 125,
"fmax": 7600,
"fft_size": 1024,
"hop_size": 256,
"sample_rate": 22050,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"rescaling": true,
"rescaling_max": 0.999,
"allow_clipping_in_normalization": true,
"downsample_step": 1,
"outputs_per_step": 4,
"embedding_weight_std": 0.1,
"speaker_embedding_weight_std": 0.01,
"padding_idx": 0,
"max_positions": 2048,
"dropout": 0.09999999999999998,
"kernel_size": 5,
"text_embed_dim": 256,
"encoder_channels": 512,
"decoder_channels": 512,
"converter_channels": 256,
"query_position_rate": 1.0,
"key_position_rate": 1.385,
"key_projection": true,
"value_projection": false,
"use_memory_mask": true,
"trainable_positional_encodings": false,
"freeze_embedding": false,
"use_decoder_state_for_postnet_input": true,
"pin_memory": true,
"num_workers": 2,
"masked_loss_weight": 0.5,
"priority_freq": 3000,
"priority_freq_weight": 0.0,
"binary_divergence_weight": 0.1,
"use_guided_attention": true,
"guided_attention_sigma": 0.2,
"batch_size": 16,
"adam_beta1": 0.5,
"adam_beta2": 0.9,
"adam_eps": 1e-06,
"initial_learning_rate": 0.001,
"lr_schedule": "noam_learning_rate_decay",
"lr_schedule_kwargs": {},
"nepochs": 2000,
"weight_decay": 0.0,
"clip_thresh": 0.1,
"checkpoint_interval": 10000,
"eval_interval": 10000,
"save_optimizer_state": true,
"force_monotonic_attention": true,
"window_ahead": 3,
"window_backward": 1,
"power": 1.4
}
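For context, this preset is presumably passed via `--preset` to the existing entry points, mirroring the README commands for other datasets; a sketch, with a hypothetical LJSpeech corpus path:

```
python preprocess.py ljspeech ~/LJSpeech-1.1 ./data/ljspeech --preset=presets/deepvoice3_ljspeech_wavenet.json
python train.py --data-root=./data/ljspeech --checkpoint-dir checkpoint_ljspeech --preset=presets/deepvoice3_ljspeech_wavenet.json
```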