Skip to content

Commit

Permalink
Add reinit text encoder and duration predictor parameter (#1562)
Browse files Browse the repository at this point in the history
* Add reinit encoder and duration predictor option

* Add .data to prevent any overlooked autograd hook
  • Loading branch information
Edresson committed May 12, 2022
1 parent 1827110 commit 175ca06
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@
mel_basis = {}


@torch.no_grad()
def weights_reset(m: nn.Module):
# check if the current module has reset_parameters and if it is reset the weight
reset_parameters = getattr(m, "reset_parameters", None)
if callable(reset_parameters):
m.reset_parameters()


def get_module_weights_sum(mdl: nn.Module):
dict_sums = {}
for name, w in mdl.named_parameters():
if "weight" in name:
value = w.data.sum().item()
dict_sums[name] = value
return dict_sums


def load_audio(file_path):
"""Load the audio file normalized in [-1, 1]
Expand Down Expand Up @@ -528,6 +545,8 @@ class VitsArgs(Coqpit):
freeze_waveform_decoder: bool = False
encoder_sample_rate: int = None
interpolate_z: bool = True
reinit_DP: bool = False
reinit_text_encoder: bool = False


class Vits(BaseTTS):
Expand Down Expand Up @@ -744,6 +763,28 @@ def init_upsampling(self):
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
) # pylint: disable=W0201

def on_init_end(self, trainer): # pylint: disable=W0613
"""Reinit layes if needed"""
if self.args.reinit_DP:
before_dict = get_module_weights_sum(self.duration_predictor)
# Applies weights_reset recursively to every submodule of the duration predictor
self.duration_predictor.apply(fn=weights_reset)
after_dict = get_module_weights_sum(self.duration_predictor)
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !")
print(" > Duration Predictor was reinit.")

if self.args.reinit_text_encoder:
before_dict = get_module_weights_sum(self.text_encoder)
# Applies weights_reset recursively to every submodule of the duration predictor
self.text_encoder.apply(fn=weights_reset)
after_dict = get_module_weights_sum(self.text_encoder)
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
print(" > Text Encoder was reinit.")

def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
Expand Down

0 comments on commit 175ca06

Please sign in to comment.