Skip to content

Commit

Permalink
unit test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Aya-AlJafari committed Oct 13, 2023
1 parent 6eaecab commit ffddf10
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions TTS/tts/models/forward_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ def _forward_encoder(
- x_mask: :math:`(B, 1, T_{en})`
- g: :math:`(B, C)`
"""
g = g.type(torch.LongTensor)
if hasattr(self, "emb_g"):
g = g.type(torch.LongTensor)
g = self.emb_g(g) # [B, C, 1]
if g is not None:
g = g.unsqueeze(-1)
Expand Down Expand Up @@ -684,8 +684,7 @@ def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # p
# encoder pass
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
# duration predictor pass
o_en = o_en.squeeze()
o_dr_log = self.duration_predictor(o_en, x_mask)
o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)

Expand Down

0 comments on commit ffddf10

Please sign in to comment.