Skip to content

Commit

Permalink
likelihood
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxMax2016 authored Nov 13, 2023
1 parent 4649279 commit ef2cea4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pitch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def forward(self, phone, lengths, score, slurs):
x = self.emb_phone(phone) + self.emb_score(score) + self.emb_slurs(slurs)
x = x * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
c = x
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
x.dtype
)
x = self.enc(x * x_mask, x_mask)
c = x
x = self.proj(x)
return x, x_mask, c

Expand All @@ -53,19 +53,19 @@ def __init__(self):
super().__init__()
self.pit_encoder = TextEncoder(hidden_channels=192, filter_channels=768,
n_heads=2, n_layers=5, kernel_size=5, p_dropout=0.1)
self.decoder = Diffusion(2, 64, 192, beta_min=0.05, beta_max=20.0, pe_scale=1000)
self.decoder = Diffusion(2, 192, 64, beta_min=0.05, beta_max=20.0, pe_scale=1000)


@torch.no_grad()
def forward(self, phone, lengths, score, slurs, n_timesteps, temperature=1.0, stoc=False):
def forward(self, phone, lengths, score, slurs, n_timesteps, temperature=1.0):
# Encoder
mu_x, mask_x, c = self.pit_encoder(phone, lengths, score, slurs)
encoder_outputs = mu_x

# Sample latent representation from terminal distribution N(mu_x, I / temperature**2)
z = mu_x + torch.randn_like(mu_x, device=mu_x.device) / temperature
# Generate sample by performing reverse dynamics
decoder_outputs = self.decoder(c, z, mask_x, mu_x, n_timesteps, stoc)
decoder_outputs = self.decoder(z, mask_x, mu_x, c, n_timesteps)
return encoder_outputs, decoder_outputs

def compute_loss(self, phone, lengths, score, slurs, pitch, out_size):
Expand Down Expand Up @@ -95,6 +95,6 @@ def compute_loss(self, phone, lengths, score, slurs, pitch, out_size):
mu_x = slice_segments(mu_x, ids, out_size)
c = slice_segments(c, ids, out_size)

diff_loss, xt = self.decoder.compute_loss(c, pitch_gt, mask_x, mu_x)
diff_loss, xt = self.decoder.compute_loss(pitch_gt, mask_x, mu_x, c)
return prior_loss, diff_loss

0 comments on commit ef2cea4

Please sign in to comment.