Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove duplicate/unused code #3243

Merged
merged 5 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions TTS/bin/train_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@

import torch
from torch.utils.data import DataLoader
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer

from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update

Expand Down Expand Up @@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,

if global_step % c.save_step == 0:
# save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
)

end_time = time.time()

Expand All @@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
best_loss = save_best_model(
eval_loss,
best_loss,
c,
model,
optimizer,
None,
global_step,
epoch,
OUT_PATH,
criterion=criterion.state_dict(),
)
model.train()

return best_loss, global_step
Expand Down Expand Up @@ -276,7 +289,7 @@ def main(args): # pylint: disable=redefined-outer-name

if c.loss == "softmaxproto" and c.model != "speaker_encoder":
c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH)
copy_model_files(c, OUT_PATH, new_fields={})

if args.restore_path:
criterion, args.restore_step = model.load_checkpoint(
Expand Down
46 changes: 0 additions & 46 deletions TTS/encoder/utils/generic_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import datetime
import glob
import os
import random
import re

import numpy as np
from scipy import signal

from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec


class AugmentWAV(object):
Expand Down Expand Up @@ -118,11 +115,6 @@ def apply_one(self, audio):
return self.additive_noise(noise_type, audio)


def to_camel(text):
    """Convert an underscore-separated string to CamelCase.

    The input is first passed through ``str.capitalize`` (upper-cases the
    first character and lower-cases the rest). Every underscore that is not
    at the very start of the string and is directly followed by an ASCII
    letter is then removed, and that letter is upper-cased.

    Args:
        text (str): string to convert.

    Returns:
        str: the converted string.
    """
    capitalized = text.capitalize()

    def _upper_after_underscore(match):
        # group(1) is the letter right after the underscore
        return match.group(1).upper()

    return re.sub(r"(?!^)_([a-zA-Z])", _upper_after_underscore, capitalized)


def setup_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
Expand All @@ -142,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
audio_config=config.audio,
)
return model


def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
    """Save a training checkpoint as ``checkpoint_<current_step>.pth``.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``, or ``None``.
        criterion: object exposing ``state_dict()``.
        model_loss: loss value stored under the ``"loss"`` key.
        out_path (str): folder the checkpoint file is written into.
        current_step: step number; used in the file name and stored in the state.
        epoch: epoch number stored in the state.
    """
    file_name = "checkpoint_{}.pth".format(current_step)
    checkpoint_path = os.path.join(out_path, file_name)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

    # Gather the state dicts in the same order the entries are stored.
    model_state = model.state_dict()
    optimizer_state = None if optimizer is None else optimizer.state_dict()
    criterion_state = criterion.state_dict()

    state = {
        "model": model_state,
        "optimizer": optimizer_state,
        "criterion": criterion_state,
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    save_fsspec(state, checkpoint_path)


def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
    """Save the model as ``best_model.pth`` when ``model_loss`` beats ``best_loss``.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``, or ``None``.
        criterion: object exposing ``state_dict()``.
        model_loss: loss of the current model.
        best_loss: best loss seen so far.
        out_path (str): folder the best-model file is written into.
        current_step: step number stored in the state.
        epoch: epoch number stored in the state.

    Returns:
        ``model_loss`` if it is strictly lower than ``best_loss`` (a new best
        model was saved), otherwise ``best_loss`` unchanged.
    """
    # Strict comparison: an equal loss does not trigger a new save.
    if not model_loss < best_loss:
        return best_loss

    state = {
        "model": model.state_dict(),
        # Tolerate a missing optimizer, matching the encoder save_checkpoint helper.
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "criterion": criterion.state_dict(),
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    bestmodel_path = os.path.join(out_path, "best_model.pth")
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
    save_fsspec(state, bestmodel_path)
    return model_loss
38 changes: 0 additions & 38 deletions TTS/encoder/utils/io.py

This file was deleted.

2 changes: 1 addition & 1 deletion TTS/encoder/utils/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger

from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files


@dataclass
Expand Down
146 changes: 0 additions & 146 deletions TTS/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import datetime
import json
import os
import pickle as pickle_tts
import shutil
from typing import Any, Callable, Dict, Union

import fsspec
import torch
from coqpit import Coqpit

from TTS.utils.generic_utils import get_user_data_dir

Expand All @@ -28,34 +24,6 @@ def __init__(self, *args, **kwargs):
self.__dict__ = self


def copy_model_files(config: Coqpit, out_path, new_fields=None):
    """Write ``config.json`` into the training folder and copy the stats file.

    Args:
        config (Coqpit): Coqpit config defining the training run.
        out_path (str): output folder the files are written to.
        new_fields (dict): new fields to be added or edited in the config
            before it is written. Skipped when empty or ``None``.
    """
    # add extra information fields
    if new_fields:
        config.update(new_fields, allow_new=True)

    config_target = os.path.join(out_path, "config.json")
    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
    with fsspec.open(config_target, "w", encoding="utf8") as config_file:
        json.dump(config.to_dict(), config_file, indent=4)

    # copy model stats file if available
    if config.audio.stats_path is None:
        return

    stats_target = os.path.join(out_path, "scale_stats.npy")
    # an already existing copy in the output folder is left untouched
    if fsspec.get_mapper(stats_target).fs.exists(stats_target):
        return

    with fsspec.open(config.audio.stats_path, "rb") as source_file, fsspec.open(
        stats_target, "wb"
    ) as target_file:
        shutil.copyfileobj(source_file, target_file)


def load_fsspec(
path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
Expand Down Expand Up @@ -100,117 +68,3 @@ def load_checkpoint(
if eval:
model.eval()
return model, state


def save_fsspec(state: Any, path: str, **kwargs):
    """Like torch.save but can save to other locations (e.g. s3:// , gs://).

    Args:
        state: State object to save.
        path: Any path or url supported by fsspec.
        **kwargs: Keyword arguments forwarded to torch.save.
    """
    # Open the destination in binary write mode through fsspec.
    target = fsspec.open(path, "wb")
    with target as handle:
        torch.save(state, handle, **kwargs)


def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
    """Collect the training state into one dict and save it to ``output_path``.

    Args:
        config: run configuration; a ``Coqpit`` instance is converted with
            ``to_dict()``, anything else is stored as given.
        model: object exposing ``state_dict()``; if it has a ``module``
            attribute, the state of ``model.module`` is saved instead.
        optimizer: a single optimizer, a list of optimizers, a
            ``CapacitronOptimizer`` (matched by class name), or ``None``.
        scaler: a single scaler, a list of scalers, or ``None``.
        current_step: step number stored in the state.
        epoch: epoch number stored in the state.
        output_path (str): destination handed to ``save_fsspec``.
        **kwargs: extra entries merged into the saved state; they override
            the default entries on key collision.
    """
    # A "module" attribute presumably indicates a parallel wrapper - TODO confirm.
    unwrapped = model.module if hasattr(model, "module") else model
    model_state = unwrapped.state_dict()

    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    elif optimizer.__class__.__name__ == "CapacitronOptimizer":
        # this optimizer type wraps two optimizers; both states are stored
        optimizer_state = [
            optimizer.primary_optimizer.state_dict(),
            optimizer.secondary_optimizer.state_dict(),
        ]
    elif optimizer is None:
        optimizer_state = None
    else:
        optimizer_state = optimizer.state_dict()

    if isinstance(scaler, list):
        scaler_state = [s.state_dict() for s in scaler]
    elif scaler is None:
        scaler_state = None
    else:
        scaler_state = scaler.state_dict()

    config_payload = config.to_dict() if isinstance(config, Coqpit) else config

    state = {
        "config": config_payload,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        **kwargs,
    }
    save_fsspec(state, output_path)


def save_checkpoint(config, model, optimizer, scaler, current_step, epoch, output_folder, **kwargs):
    """Save the training state as ``checkpoint_<current_step>.pth``.

    Builds the checkpoint path inside ``output_folder``, prints it, and
    passes all arguments on to ``save_model``.

    Args:
        config: run configuration passed on to ``save_model``.
        model: model passed on to ``save_model``.
        optimizer: optimizer(s) passed on to ``save_model``.
        scaler: scaler(s) passed on to ``save_model``.
        current_step: step number; also used in the file name.
        epoch: epoch number.
        output_folder (str): folder the checkpoint file is written into.
        **kwargs: extra entries passed on to ``save_model``.
    """
    checkpoint_path = os.path.join(output_folder, "checkpoint_{}.pth".format(current_step))
    print("\n > CHECKPOINT : {}".format(checkpoint_path))
    save_model(config, model, optimizer, scaler, current_step, epoch, checkpoint_path, **kwargs)


def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    **kwargs,
):
    """Save a new best model when ``current_loss`` is strictly below ``best_loss``.

    On improvement the state is saved as ``best_model_<current_step>.pth``,
    older ``best_model*.pth`` files are optionally removed, and the new file
    is copied to ``best_model.pth``.

    Args:
        current_loss: loss of the current model.
        best_loss: best loss seen so far.
        config, model, optimizer, scaler: passed on to ``save_model``.
        current_step: step number; also used in the file name.
        epoch: epoch number.
        out_path (str): folder the model files are written into.
        keep_all_best (bool): when true, older best models are kept once
            ``current_step`` has reached ``keep_after``.
        keep_after (int): step threshold used together with ``keep_all_best``.
        **kwargs: extra entries passed on to ``save_model``.

    Returns:
        ``current_loss`` if a new best model was saved, otherwise ``best_loss``.
    """
    if not current_loss < best_loss:
        return best_loss

    best_model_name = f"best_model_{current_step}.pth"
    checkpoint_path = os.path.join(out_path, best_model_name)
    print(" > BEST MODEL : {}".format(checkpoint_path))
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        model_loss=current_loss,
        **kwargs,
    )

    fs = fsspec.get_mapper(out_path).fs
    # only delete previous if current is saved successfully
    prune_previous = (not keep_all_best) or (current_step < keep_after)
    if prune_previous:
        pattern = os.path.join(out_path, "best_model*.pth")
        stale_models = [name for name in fs.glob(pattern) if os.path.basename(name) != best_model_name]
        for stale in stale_models:
            fs.rm(stale)

    # create a shortcut which always points to the currently best model
    shortcut_path = os.path.join(out_path, "best_model.pth")
    fs.copy(checkpoint_path, shortcut_path)
    return current_loss
4 changes: 2 additions & 2 deletions tests/aux_tests/test_embedding_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import numpy as np
import torch
from trainer.io import save_checkpoint

from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.managers import EmbeddingManager
from TTS.utils.audio import AudioProcessor

Expand All @@ -31,7 +31,7 @@ def test_speaker_embedding():

# create a dummy speaker encoder
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())

# load audio processor and speaker encoder
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
Expand Down
4 changes: 2 additions & 2 deletions tests/aux_tests/test_speaker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import numpy as np
import torch
from trainer.io import save_checkpoint

from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

Expand All @@ -30,7 +30,7 @@ def test_speaker_embedding():

# create a dummy speaker encoder
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())

# load audio processor and speaker encoder
ap = AudioProcessor(**config.audio)
Expand Down
Loading
Loading