Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal to integrate into 🤗 Hub #555

Merged
merged 7 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"tensorflow-gpu==2.3.1",
"tensorflow-addons>=0.10.0",
"setuptools>=38.5.1",
"huggingface_hub==0.0.8",
"librosa>=0.7.0",
"soundfile>=0.10.2",
"matplotlib>=3.1.0",
Expand Down
19 changes: 19 additions & 0 deletions tensorflow_tts/inference/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import yaml
import os
from collections import OrderedDict

from tensorflow_tts.configs import (
Expand All @@ -28,6 +29,10 @@
ParallelWaveGANGeneratorConfig,
)

from tensorflow_tts.utils import CACHE_DIRECTORY, CONFIG_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download

CONFIG_MAPPING = OrderedDict(
[
("fastspeech", FastSpeechConfig),
Expand All @@ -50,6 +55,20 @@ def __init__(self):

@classmethod
def from_pretrained(cls, pretrained_path, **kwargs):
# load weights from hf hub
if not os.path.isfile(pretrained_path):
# retrieve correct hub url
download_url = hf_hub_url(repo_id=pretrained_path, filename=CONFIG_FILE_NAME)

pretrained_path = str(
cached_download(
url=download_url,
library_name=LIBRARY_NAME,
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)

with open(pretrained_path) as f:
config = yaml.load(f, Loader=yaml.SafeLoader)

Expand Down
35 changes: 34 additions & 1 deletion tensorflow_tts/inference/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import logging
import warnings
import os

from collections import OrderedDict

from tensorflow_tts.configs import (
Expand All @@ -40,6 +42,9 @@
SavableTFFastSpeech2,
SavableTFTacotron2
)
from tensorflow_tts.utils import CACHE_DIRECTORY, MODEL_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download


TF_MODEL_MAPPING = OrderedDict(
Expand All @@ -62,8 +67,35 @@ def __init__(self):
raise EnvironmentError("Cannot be instantiated using `__init__()`")

@classmethod
def from_pretrained(cls, config, pretrained_path=None, **kwargs):
def from_pretrained(cls, config=None, pretrained_path=None, **kwargs):
is_build = kwargs.pop("is_build", True)

# load weights from hf hub
if pretrained_path is not None:
if not os.path.isfile(pretrained_path):
# retrieve correct hub url
download_url = hf_hub_url(repo_id=pretrained_path, filename=MODEL_FILE_NAME)

downloaded_file = str(
cached_download(
url=download_url,
library_name=LIBRARY_NAME,
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)

# load config from repo as well
if config is None:
from tensorflow_tts.inference import AutoConfig

config = AutoConfig.from_pretrained(pretrained_path)

pretraine_path = downloaded_file


assert config is not None, "Please make sure to pass a config along to load a model from a local file"

for config_class, model_class in TF_MODEL_MAPPING.items():
if isinstance(config, config_class) and str(config_class.__name__) in str(
config
Expand All @@ -79,6 +111,7 @@ def from_pretrained(cls, config, pretrained_path=None, **kwargs):
pretrained_path, by_name=True, skip_mismatch=True
)
return model

raise ValueError(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"Model type should be one of {}.".format(
Expand Down
18 changes: 18 additions & 0 deletions tensorflow_tts/inference/auto_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import json
import os
from collections import OrderedDict

from tensorflow_tts.processor import (
Expand All @@ -26,6 +27,10 @@
ThorstenProcessor,
)

from tensorflow_tts.utils import CACHE_DIRECTORY, PROCESSOR_FILE_NAME, LIBRARY_NAME
from tensorflow_tts import __version__ as VERSION
from huggingface_hub import hf_hub_url, cached_download

CONFIG_MAPPING = OrderedDict(
[
("LJSpeechProcessor", LJSpeechProcessor),
Expand All @@ -46,6 +51,19 @@ def __init__(self):

@classmethod
def from_pretrained(cls, pretrained_path, **kwargs):
# load weights from hf hub
if not os.path.isfile(pretrained_path):
# retrieve correct hub url
download_url = hf_hub_url(repo_id=pretrained_path, filename=PROCESSOR_FILE_NAME)

pretrained_path = str(
cached_download(
url=download_url,
library_name=LIBRARY_NAME,
library_version=VERSION,
cache_dir=CACHE_DIRECTORY,
)
)
with open(pretrained_path, "r") as f:
config = json.load(f)

Expand Down
8 changes: 8 additions & 0 deletions tensorflow_tts/processor/baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

_pad = ["pad"]
_eos = ["eos"]
Expand Down Expand Up @@ -552,6 +553,13 @@ def __post_init__(self):
def setup_eos_token(self):
return _eos[0]

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(
os.path.join(saved_path, PROCESSOR_FILE_NAME),
{"pinyin_dict": self.pinyin_dict},
)

def create_items(self):
items = []
if self.data_dir:
Expand Down
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,8 @@ def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None)
if extra_attrs_to_save:
full_mapper = {**full_mapper, **extra_attrs_to_save}
json.dump(full_mapper, f)

@abc.abstractmethod
def save_pretrained(self, saved_path):
"""Save mappers to file"""
pass
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/kss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
Expand Down Expand Up @@ -57,6 +58,10 @@ def split_line(self, data_dir, line, split):
def setup_eos_token(self):
return "eos"

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def get_one_sample(self, item):
text, wav_path, speaker_name = item

Expand Down
7 changes: 6 additions & 1 deletion tensorflow_tts/processor/libritts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from g2p_en import g2p as grapheme_to_phonem

from tensorflow_tts.processor.base_processor import BaseProcessor
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

g2p = grapheme_to_phonem.G2p()

Expand Down Expand Up @@ -84,7 +85,11 @@ def get_one_sample(self, item):
return sample

def setup_eos_token(self):
return None # because we do not use this
return None # because we do not use this

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def text_to_sequence(self, text):
if (
Expand Down
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/ljspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

valid_symbols = [
"AA",
Expand Down Expand Up @@ -158,6 +159,10 @@ def split_line(self, data_dir, line, split):
def setup_eos_token(self):
return _eos

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def get_one_sample(self, item):
text, wav_path, speaker_name = item

Expand Down
5 changes: 5 additions & 0 deletions tensorflow_tts/processor/thorsten.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dataclasses import dataclass
from tensorflow_tts.processor import BaseProcessor
from tensorflow_tts.utils import cleaners
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME

_pad = "pad"
_eos = "eos"
Expand Down Expand Up @@ -67,6 +68,10 @@ def split_line(self, data_dir, line, split):
def setup_eos_token(self):
return _eos

def save_pretrained(self, saved_path):
os.makedirs(saved_path, exist_ok=True)
self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

def get_one_sample(self, item):
text, wav_path, speaker_name = item

Expand Down
2 changes: 1 addition & 1 deletion tensorflow_tts/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
calculate_3d_loss,
return_strategy,
)
from tensorflow_tts.utils.utils import find_files
from tensorflow_tts.utils.utils import find_files, MODEL_FILE_NAME, CONFIG_FILE_NAME, PROCESSOR_FILE_NAME, CACHE_DIRECTORY, LIBRARY_NAME
from tensorflow_tts.utils.weight_norm import WeightNormalization
7 changes: 7 additions & 0 deletions tensorflow_tts/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@
import os
import re
import tempfile
from pathlib import Path

import tensorflow as tf

MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"
LIBRARY_NAME = "tensorflow_tts"
CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)


def find_files(root_dir, query="*.wav", include_root_dir=True):
"""Find files recursively.
Expand Down
3 changes: 3 additions & 0 deletions test/test_base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def text_to_sequence(self, text):
def setup_eos_token(self):
return None

def save_pretrained(self, saved_path):
return super().save_pretrained(saved_path)


@pytest.fixture
def processor(tmpdir):
Expand Down