diff --git a/pitch/diffusion.py b/pitch/diffusion.py
index 61528bd..c6e93ad 100644
--- a/pitch/diffusion.py
+++ b/pitch/diffusion.py
@@ -9,24 +9,6 @@ def forward(self, x):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
-class Upsample(BaseModule):
-    def __init__(self, dim):
-        super(Upsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, (1,4), (1,2), 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Downsample(BaseModule):
-    def __init__(self, dim):
-        super(Downsample, self).__init__()
-        self.conv = torch.nn.Conv2d(dim, dim, (1,3), (1,2), 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Rezero(BaseModule):
     def __init__(self, fn):
         super(Rezero, self).__init__()
@@ -117,15 +99,15 @@ def forward(self, x, scale=1000):
 
 
 class GradLogPEstimator2d(BaseModule):
-    def __init__(self, dim, c_dim, n_mels, dim_mults=(1, 2, 4), groups=8, pe_scale=1000):
+    def __init__(self, n_feat, n_cond, dim, dim_mults=(1, 2, 4), groups=8, pe_scale=1000):
         super(GradLogPEstimator2d, self).__init__()
         self.dim = dim
         self.dim_mults = dim_mults
         self.groups = groups
         self.pe_scale = pe_scale
 
-        self.cond = torch.nn.Sequential(torch.nn.Conv1d(c_dim, dim * 4, 1), Mish(),
-                                        torch.nn.Conv1d(dim * 4, n_mels, 1))
+        self.cond = torch.nn.Sequential(torch.nn.Conv1d(n_cond, dim * 4, 1), Mish(),
+                                        torch.nn.Conv1d(dim * 4, n_feat, 1))
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
                                        torch.nn.Linear(dim * 4, dim))
@@ -142,7 +124,7 @@ def __init__(self, dim, c_dim, n_mels, dim_mults=(1, 2, 4), groups=8, pe_scale=1
                 ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                 ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                 Residual(Rezero(LinearAttention(dim_out))),
-                Downsample(dim_out) if not is_last else torch.nn.Identity()]))
+                torch.nn.Identity()]))
 
         mid_dim = dims[-1]
         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
@@ -151,14 +133,14 @@
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):  # 2 ups
             self.ups.append(torch.nn.ModuleList([
-                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
+                ResnetBlock(dim_out, dim_in, time_emb_dim=dim),
                 ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                 Residual(Rezero(LinearAttention(dim_in))),
-                Upsample(dim_in)]))
+                torch.nn.Identity()]))
         self.final_block = Block(dim, dim)
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)
 
-    def forward(self, c, x, mask, mu, t):
+    def forward(self, x, mask, mu, c, t):
         t = self.time_pos_emb(t, scale=self.pe_scale)
         t = self.mlp(t)
 
@@ -167,30 +149,21 @@ def forward(self, c, x, mask, mu, t):
         x = torch.stack([mu, x, c], 1)
         mask = mask.unsqueeze(1)
 
-        hiddens = []
-        masks = [mask]
         for resnet1, resnet2, attn, downsample in self.downs:
-            mask_down = masks[-1]
-            x = resnet1(x, mask_down, t)
-            x = resnet2(x, mask_down, t)
+            x = resnet1(x, mask, t)
+            x = resnet2(x, mask, t)
             x = attn(x)
-            hiddens.append(x)
-            x = downsample(x * mask_down)
-            masks.append(mask_down[:, :, :, ::2])
+            x = downsample(x * mask)
 
-        masks = masks[:-1]
-        mask_mid = masks[-1]
-        x = self.mid_block1(x, mask_mid, t)
+        x = self.mid_block1(x, mask, t)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, mask_mid, t)
+        x = self.mid_block2(x, mask, t)
 
         for resnet1, resnet2, attn, upsample in self.ups:
-            mask_up = masks.pop()
-            x = torch.cat((x, hiddens.pop()), dim=1)
-            x = resnet1(x, mask_up, t)
-            x = resnet2(x, mask_up, t)
+            x = resnet1(x, mask, t)
+            x = resnet2(x, mask, t)
             x = attn(x)
-            x = upsample(x * mask_up)
+            x = upsample(x * mask)
 
         x = self.final_block(x, mask)
         output = self.final_conv(x * mask)
@@ -198,69 +171,102 @@ def forward(self, c, x, mask, mu, t):
         return (output * mask).squeeze(1)
 
 
-def get_noise(t, beta_init, beta_term, cumulative=False):
-    if cumulative:
-        noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
-    else:
-        noise = beta_init + (beta_term - beta_init)*t
-    return noise
-
-
 class Diffusion(BaseModule):
-    def __init__(self, n_mels, dim, c_dim, beta_min=0.05, beta_max=20, pe_scale=1000):
+    def __init__(self, n_feat, n_cond, dim, beta_min=0.05, beta_max=20, pe_scale=1000):
         super(Diffusion, self).__init__()
-        self.n_mels = n_mels
+        self.estimator = GradLogPEstimator2d(n_feat, n_cond, dim, pe_scale=pe_scale)
+        self.n_feat = n_feat
         self.beta_min = beta_min
         self.beta_max = beta_max
-        self.estimator = GradLogPEstimator2d(dim, c_dim, n_mels, pe_scale=pe_scale)
 
-    def forward_diffusion(self, mel, mask, mu, t):
-        time = t.unsqueeze(-1).unsqueeze(-1)
-        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
-        mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
-        variance = 1.0 - torch.exp(-cum_noise)
-        z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device,
-                        requires_grad=False)
-        xt = mean + z * torch.sqrt(variance)
-        return xt * mask, z * mask
+    def get_beta(self, t):
+        beta = self.beta_min + (self.beta_max - self.beta_min) * t
+        return beta
+
+    def get_gamma(self, s, t, p=1.0, use_torch=False):
+        beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s)
+        beta_integral *= (t - s)
+        if use_torch:
+            gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1)
+        else:
+            gamma = math.exp(-0.5 * p * beta_integral)
+        return gamma
+
+    def get_mu(self, s, t):
+        a = self.get_gamma(s, t)
+        b = 1.0 - self.get_gamma(0, s, p=2.0)
+        c = 1.0 - self.get_gamma(0, t, p=2.0)
+        return a * b / c
+
+    def get_nu(self, s, t):
+        a = self.get_gamma(0, s)
+        b = 1.0 - self.get_gamma(s, t, p=2.0)
+        c = 1.0 - self.get_gamma(0, t, p=2.0)
+        return a * b / c
+
+    def get_sigma(self, s, t):
+        a = 1.0 - self.get_gamma(0, s, p=2.0)
+        b = 1.0 - self.get_gamma(s, t, p=2.0)
+        c = 1.0 - self.get_gamma(0, t, p=2.0)
+        return math.sqrt(a * b / c)
 
     @torch.no_grad()
-    def reverse_diffusion(self, c, z, mask, mu, n_timesteps, stoc=False):
+    def reverse_diffusion(self, z, mask, mu, mu_c, n_timesteps):
         h = 1.0 / n_timesteps
         xt = z * mask
+
         for i in range(n_timesteps):
-            t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
-                                                 device=z.device)
-            time = t.unsqueeze(-1).unsqueeze(-1)
-            noise_t = get_noise(time, self.beta_min, self.beta_max,
-                                cumulative=False)
-            if stoc:  # adds stochastic term
-                dxt_det = 0.5 * (mu - xt) - self.estimator(c, xt, mask, mu, t)
-                dxt_det = dxt_det * noise_t * h
-                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
-                                       requires_grad=False)
-                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
-                dxt = dxt_det + dxt_stoc
-            else:
-                dxt = 0.5 * (mu - xt - self.estimator(c, xt, mask, mu, t))
-                dxt = dxt * noise_t * h
+            t = 1.0 - i * h
+            time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            beta_t = self.get_beta(t)
+
+            kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0))
+            kappa /= (self.get_gamma(0, t) * beta_t * h)
+            kappa -= 1.0
+            omega = self.get_nu(t - h, t) / self.get_gamma(0, t)
+            omega += self.get_mu(t - h, t)
+            omega -= (0.5 * beta_t * h + 1.0)
+            sigma = self.get_sigma(t - h, t)
+
+            dxt = (mu - xt) * (0.5 * beta_t * h + omega)
+            dxt -= (self.estimator(xt, mask, mu, mu_c, time)) * (1.0 + kappa) * (beta_t * h)
+            dxt += torch.randn_like(z, device=z.device) * sigma
             xt = (xt - dxt) * mask
+
         return xt
 
     @torch.no_grad()
-    def forward(self, c, z, mask, mu, n_timesteps, stoc=False):
-        return self.reverse_diffusion(c, z, mask, mu, n_timesteps, stoc)
+    def forward(self, z, mask, mu, mu_c, n_timesteps):
+        return self.reverse_diffusion(z, mask, mu, mu_c, n_timesteps)
+
+    # train: mel means f0 ground truth
+    def get_noise(self, t, beta_init, beta_term, cumulative=False):
+        if cumulative:
+            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
+        else:
+            noise = beta_init + (beta_term - beta_init)*t
+        return noise
+
+    def forward_diffusion(self, mel, mask, mu, t):
+        time = t.unsqueeze(-1).unsqueeze(-1)
+        cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True)
+        mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
+        variance = 1.0 - torch.exp(-cum_noise)
+        z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device,
+                        requires_grad=False)
+        xt = mean + z * torch.sqrt(variance)
+        return xt * mask, z * mask
 
-    def loss_t(self, c, mel, mask, mu, t):
+    def loss_t(self, mel, mask, mu, mu_c, t):
         xt, z = self.forward_diffusion(mel, mask, mu, t)
         time = t.unsqueeze(-1).unsqueeze(-1)
-        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
-        noise_estimation = self.estimator(c, xt, mask, mu, t)
+        cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True)
+        noise_estimation = self.estimator(xt, mask, mu, mu_c, t)
         noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
-        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_mels)
+        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feat)
         return loss, xt
 
-    def compute_loss(self, c, mel, mask, mu, offset=1e-5):
+    def compute_loss(self, mel, mask, mu, mu_c, offset=1e-5):
         t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device,
                        requires_grad=False)
         t = torch.clamp(t, offset, 1.0 - offset)
-        return self.loss_t(c, mel, mask, mu, t)
+        return self.loss_t(mel, mask, mu, mu_c, t)
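
Reviewer note: since this patch changes every public call signature, here is a minimal usage sketch of the new API. All sizes below (n_feat, n_cond, dim, batch size, frame count) are illustrative assumptions, not values from this repo's configs; the semantics of mu (prior mean the SDE decays toward) and mu_c (conditioning sequence consumed by self.cond) are inferred from the code above.

    import torch
    from pitch.diffusion import Diffusion  # assumes the repo root is on PYTHONPATH

    n_feat, n_cond, dim = 1, 256, 64  # assumed sizes, for illustration only
    B, L = 4, 200                     # assumed batch size and frame count

    model = Diffusion(n_feat, n_cond, dim)

    f0   = torch.randn(B, n_feat, L)  # ground-truth pitch (the "mel" argument of loss_t)
    mu   = torch.randn(B, n_feat, L)  # prior mean the forward SDE decays toward
    mu_c = torch.randn(B, n_cond, L)  # conditioning features with n_cond channels
    mask = torch.ones(B, 1, L)        # 1 = valid frame, 0 = padding

    # Training step: sample t ~ U(offset, 1 - offset), diffuse f0 to x_t,
    # and score-match the estimator output against the injected noise z.
    loss, xt = model.compute_loss(f0, mask, mu, mu_c)
    loss.backward()

    # Sampling: start from the terminal distribution N(mu, I) and integrate
    # the reverse SDE with the kappa/omega/sigma discretization added above.
    z = mu + torch.randn_like(mu)
    f0_hat = model(z, mask, mu, mu_c, n_timesteps=50)

Two behavioral notes for review: the estimator now runs at a single resolution (the U-Net's up/down-sampling and skip connections are gone, and the ModuleList slots are filled with torch.nn.Identity placeholders), and the fixed-step Euler/Euler-Maruyama sampler is replaced by what appears to be the maximum-likelihood reverse-SDE discretization of Popov et al., built from get_mu/get_nu/get_sigma. Also note that get_gamma and get_sigma call math.exp/math.sqrt, so pitch/diffusion.py must already import math.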