Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XTTS v1.1 GPT Trainer #3086

Merged
merged 24 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a32961b
Add XTTS base training code
Edresson Oct 11, 2023
40a4e63
Update mel spectrogram for the style encoder
Edresson Oct 11, 2023
47d613d
Add reproducible evaluation
Edresson Oct 13, 2023
bafab04
Add prompting masking
Edresson Oct 16, 2023
2f868dd
Bug fix on reproducible evaluation
Edresson Oct 16, 2023
c4ceaab
Add test sentences during the training
Edresson Oct 16, 2023
9e3598c
Bug Fix on inference using XTTS trainer checkpoint
Edresson Oct 18, 2023
469d624
Update LJspeech XTTS recipe
Edresson Oct 18, 2023
5f98dbe
Update Ljspeech XTTS recipe
Edresson Oct 18, 2023
94dcf84
Rename XTTS recipe
Edresson Oct 18, 2023
1f92741
Fix issue #2971
Edresson Oct 18, 2023
affaf11
Add XTTS training unit test
Edresson Oct 18, 2023
ec7f547
Rebase bug fix and update recipe
Edresson Oct 21, 2023
e8a1a50
Remove unused vars in Delightful TTS layers tests
Edresson Oct 23, 2023
653f2e7
Update xtts trainer recipe
Edresson Oct 23, 2023
8853e1c
Update XTTS recipe to only download checkpoint if it is needed
Edresson Oct 23, 2023
6fefc36
Update XTTS docs
Edresson Oct 23, 2023
1ee8096
Update XTTS docs
Edresson Oct 23, 2023
37b7945
Update XTTS train not implemented error to point to the XTTS docs
Edresson Oct 23, 2023
67ca70a
Fix Delightful TTS layers unit test
Edresson Oct 23, 2023
0f96abb
Add FT inference example on XTTS docs
Edresson Oct 23, 2023
de1d521
Update XTTS docs
Edresson Oct 23, 2023
8af3d2d
Add a dedicated workflow for XTTS tests
Edresson Oct 24, 2023
01839af
Bug fix on XTTS masking training
Edresson Oct 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions .github/workflows/xtts_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: xtts-tests

on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"

test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.9, "3.10", "3.11"]
experimental: [false]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: set ENV
run: export TRAINER_TELEMETRY=0
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
sudo apt-get install espeak
sudo apt-get install espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Replace scarf urls
run: |
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make test_xtts
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ test_tts: ## run tts tests.
test_tts2: ## run tts tests.
nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2

test_xtts:
nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests

test_aux: ## run aux tests.
nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
./run_bash_tests.sh
Expand Down
69 changes: 40 additions & 29 deletions TTS/tts/layers/xtts/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):

if use_deepspeed:
import deepspeed

self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPU
Expand Down Expand Up @@ -233,6 +234,7 @@ def get_logits(
prompt=None,
get_attns=False,
return_latent=False,
attn_mask_cond=None,
attn_mask_text=None,
attn_mask_mel=None,
):
Expand All @@ -248,8 +250,11 @@ def get_logits(
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
if attn_mask_cond is not None:
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
else:
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

gpt_out = self.gpt(
inputs_embeds=emb,
Expand Down Expand Up @@ -326,7 +331,7 @@ def get_prompts(self, prompt_codes):
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
return prompt

def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
def get_style_emb(self, cond_input, return_latent=False):
"""
cond_input: (b, 80, s) or (b, 1, 80, s)
conds: (b, 1024, s)
Expand All @@ -335,26 +340,7 @@ def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_la
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
if sample:
_len_secs = random.randint(2, 6) # in secs
cond_seg_len = int((22050 / 1024) * _len_secs) # in frames
if cond_input.shape[-1] >= cond_seg_len:
new_conds = []
for i in range(cond_input.shape[0]):
cond_len = int(cond_lens[i] / 1024)
if cond_len < cond_seg_len:
start = 0
else:
start = random.randint(0, cond_len - cond_seg_len)
cond_vec = cond_input[i, :, start : start + cond_seg_len]
new_conds.append(cond_vec)
conds = torch.stack(new_conds, dim=0)
else:
cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
cond_frame_len = int((22050 / 1024) * cond_seg_len)
conds = cond_input[:, :, -cond_frame_len:]

conds = self.conditioning_encoder(conds)
conds = self.conditioning_encoder(cond_input)
else:
# already computed
conds = cond_input.unsqueeze(1)
Expand All @@ -366,22 +352,22 @@ def forward(
text_lengths,
audio_codes,
wav_lengths,
cond_lens=None,
cond_mels=None,
cond_idxs=None,
cond_latents=None,
loss_weights=None,
return_attentions=False,
return_latent=False,
):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).

cond_mels: MEL float tensor, (b, 1, 80,s)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
cond_mels: MEL float tensor, (b, 1, 80,s)
cond_idxs: cond start and end indexs, (b, 2)

If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
Expand All @@ -393,6 +379,11 @@ def forward(
max_text_len = text_lengths.max()
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

if cond_idxs is not None:
# recompute cond idxs for mel lengths
for idx, l in enumerate(code_lengths):
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len

# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()

Expand Down Expand Up @@ -435,9 +426,16 @@ def forward(
)

# Set attn_mask
attn_mask_cond = None
attn_mask_text = None
attn_mask_mel = None
if not return_latent:
attn_mask_cond = torch.ones(
cond_mels.shape[0],
cond_mels.shape[-1],
dtype=torch.bool,
device=text_inputs.device,
)
attn_mask_text = torch.ones(
text_inputs.shape[0],
text_inputs.shape[1],
Expand All @@ -451,6 +449,11 @@ def forward(
device=audio_codes.device,
)

if cond_idxs is not None:
for idx, r in enumerate(cond_idxs):
l = r[1] - r[0]
attn_mask_cond[idx, l:] = 0.0

for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0

Expand All @@ -465,7 +468,7 @@ def forward(

# Compute speech conditioning input
if cond_latents is None:
cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

# Get logits
sub = -5 # don't ask me why 😄
Expand All @@ -480,6 +483,7 @@ def forward(
prompt=cond_latents,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_cond=attn_mask_cond,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)
Expand All @@ -501,6 +505,13 @@ def forward(
0
], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."

# ignore the loss for the segment used for conditioning
# coin flip for the segment to be ignored
if cond_idxs is not None:
cond_start = cond_idxs[idx, 0]
cond_end = cond_idxs[idx, 1]
mel_targets[idx, cond_start:cond_end] = -1

# Compute losses
loss_text = F.cross_entropy(
text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
Expand Down Expand Up @@ -548,7 +559,7 @@ def generate(
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
max_length=self.max_mel_tokens,
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
Expand All @@ -561,7 +572,7 @@ def get_generator(self, fake_inputs, **hf_generate_kwargs):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
max_length=self.max_mel_tokens,
do_stream=True,
**hf_generate_kwargs,
)
29 changes: 9 additions & 20 deletions TTS/tts/layers/xtts/hifigan_decoder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import torch
import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio

from TTS.utils.io import load_fsspec


LRELU_SLOPE = 0.1


Expand Down Expand Up @@ -224,9 +223,7 @@ def __init__(
self.cond_in_each_up_layer = cond_in_each_up_layer

# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
Expand All @@ -246,14 +243,10 @@ def __init__(
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

Expand Down Expand Up @@ -318,9 +311,7 @@ def inference(self, c):
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)

def remove_weight_norm(self):
Expand All @@ -342,6 +333,7 @@ def load_checkpoint(
assert not self.training
self.remove_weight_norm()


class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
Expand Down Expand Up @@ -425,10 +417,8 @@ def forward(self, x):
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)



class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
"""This is copied from 🐸TTS to remove it from the dependencies."""

# pylint: disable=W0102
def __init__(
Expand Down Expand Up @@ -620,6 +610,7 @@ def load_checkpoint(
return criterion, state["step"]
return criterion


class HifiDecoder(torch.nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -724,9 +715,7 @@ def inference(self, c, g):
"""
return self.forward(c, g=g)

def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]
Expand Down
Loading
Loading