From ffddf1045874c98e0ebede168db3803502a58e4c Mon Sep 17 00:00:00 2001 From: Aya Jafari Date: Fri, 13 Oct 2023 10:56:47 -0300 Subject: [PATCH] unit test fix --- TTS/tts/models/forward_tts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 8dfc6c03bf..9e1b1c4097 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -395,8 +395,8 @@ def _forward_encoder( - x_mask: :math:`(B, 1, T_{en})` - g: :math:`(B, C)` """ - g = g.type(torch.LongTensor) if hasattr(self, "emb_g"): + g = g.type(torch.LongTensor) g = self.emb_g(g) # [B, C, 1] if g is not None: g = g.unsqueeze(-1) @@ -684,8 +684,7 @@ def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # p # encoder pass o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) # duration predictor pass - o_en = o_en.squeeze() - o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) y_lengths = o_dr.sum(1)