
Commit
Update some utils
leng-yue committed Sep 21, 2023
1 parent 25d19e9 commit 492b0c7
Showing 5 changed files with 75 additions and 27 deletions.
5 changes: 3 additions & 2 deletions configs/_base_/trainers/base.py
@@ -13,9 +13,10 @@
     log_every_n_steps=10,
     val_check_interval=5000,
     check_val_every_n_epoch=None,
-    max_steps=1_000_000,
+    max_steps=2_000_000,
+    # Warning: If you are training the model with fs2 (and see nan), you should either use bf16 or fp32
-    precision="16-mixed",
+    precision="32",
     accumulate_grad_batches=4,
     callbacks=[
         ModelCheckpoint(
             filename="{epoch}-{step}-{valid_loss:.4f}",
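Note: this change doubles the step budget and trades mixed precision for full fp32, per the new warning about NaNs with fs2. A minimal sketch of how a trainer dict like this is typically consumed, assuming PyTorch Lightning (the import and variable names are illustrative, not taken from this repo):

    import pytorch_lightning as pl

    # "16-mixed" can overflow to NaN on some losses; "32" is the safe (slower)
    # choice, and "bf16-mixed" a cheaper middle ground per the warning above.
    trainer = pl.Trainer(
        max_steps=2_000_000,
        precision="32",
        accumulate_grad_batches=4,
        val_check_interval=5000,
        check_val_every_n_epoch=None,
        log_every_n_steps=10,
    )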
8 changes: 7 additions & 1 deletion configs/tts_baseline.py
@@ -82,7 +82,7 @@
     mel_encoder=dict(
         type="TransformerEncoder",
         input_size=bert_dim,
-        output_size=mel_channels,
+        output_size=mel_channels * 2,
         hidden_size=bert_dim,
         num_layers=4,
     ),
@@ -114,6 +114,12 @@
     ),
 )
 
+dataloader = dict(
+    train=dict(
+        batch_size=5,
+    ),
+)
+
 preprocessing = dict(
     text_features_extractor=dict(
         type="BertTokenizer",
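Note: doubling output_size to mel_channels * 2 makes the mel encoder predict both a mean and a log-standard-deviation per mel channel; grad_tts.py below splits the output in half along the channel dimension. A minimal sketch of the resulting shape flow (all names and sizes here are illustrative):

    import torch

    b, t, mel_channels = 2, 100, 128
    # Encoder output doubled along the feature axis, then moved to [b, 2*mel, t]
    stats_p = torch.randn(b, t, mel_channels * 2).transpose(1, 2)
    mu_p, logs_p = torch.split(stats_p, stats_p.shape[1] // 2, dim=1)
    assert mu_p.shape == (b, mel_channels, t) and logs_p.shape == (b, mel_channels, t)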
68 changes: 51 additions & 17 deletions fish_diffusion/archs/diffsinger/grad_tts.py
@@ -23,6 +23,10 @@ def __init__(self, model_config):
         if getattr(model_config, "speaker_encoder", None):
             self.speaker_encoder = ENCODERS.build(model_config.speaker_encoder)
 
+        self.current_mas_noise_scale = nn.Parameter(
+            torch.tensor(1e-2, dtype=torch.float32), requires_grad=False
+        )
+
     @staticmethod
     def get_mask_from_lengths(lengths, max_len=None):
         batch_size = lengths.shape[0]
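Note: a non-trainable nn.Parameter makes current_mas_noise_scale persist in checkpoints and follow the module across devices while staying out of gradient updates. A registered buffer would do the same without appearing in model.parameters(); a sketch of that alternative (an assumption, not what this commit does):

    # Equivalent persistence and device handling, but excluded from
    # model.parameters() and therefore from optimizer state.
    self.register_buffer("current_mas_noise_scale", torch.tensor(1e-2, dtype=torch.float32))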
@@ -116,7 +120,8 @@ def forward_features(
         if speaker_embed is not None:
             features += speaker_embed
 
-        mu_x = self.mel_encoder(features, src_masks).transpose(1, 2)
+        stats_p = self.mel_encoder(features, src_masks).transpose(1, 2)
+        mu_p, logs_p = torch.split(stats_p, stats_p.shape[1] // 2, dim=1)
         logw = self.duration_predictor(features.detach(), src_masks).transpose(1, 2)
         x_mask = src_masks.unsqueeze(1)
@@ -139,46 +144,65 @@ def forward_features(
         ).unsqueeze(1)
 
         # Align encoded text and get mu_y
-        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        z_p = mu_p + torch.randn_like(mu_p) * torch.exp(logs_p) * 0.667
+        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), z_p.transpose(1, 2))
         mel_features = mu_y[:, :y_max_length, :]
 
         return dict(
             features=mel_features,
             mel_masks=mel_masks,
         )
 
-        y = mel.transpose(1, 2)
+        z_p = mel.transpose(1, 2)
         y_mask = (~mel_masks).float().unsqueeze(1)
-        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
 
         # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
         with torch.no_grad():
-            const = -0.5 * math.log(2 * math.pi) * mel.shape[-1]
-            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
-            y_square = torch.matmul(factor.transpose(1, 2), y**2)
-            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
-            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
-            log_prior = y_square - y_mu_double + mu_square + const
+            # negative cross-entropy
+            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
+            neg_cent1 = torch.sum(
+                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
+            )  # [b, 1, t_s]
+            neg_cent2 = torch.matmul(
+                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
+            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+            neg_cent3 = torch.matmul(
+                z_p.transpose(1, 2), (mu_p * s_p_sq_r)
+            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
+            neg_cent4 = torch.sum(
+                -0.5 * (mu_p**2) * s_p_sq_r, [1], keepdim=True
+            )  # [b, 1, t_s]
+            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+
+            # Noise scale
+            epsilon = (
+                torch.std(neg_cent)
+                * torch.randn_like(neg_cent)
+                * self.current_mas_noise_scale
+            )
+            neg_cent = neg_cent + epsilon
 
-            attn = maximum_path(log_prior, attn_mask.squeeze(1))
-            attn = attn.detach()
+            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+            attn = maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
 
         # Compute loss between predicted log-scaled durations and those obtained from MAS
-        logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
+        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
         dur_loss = torch.sum((logw - logw_) ** 2) / torch.sum(src_masks)
 
         # Align encoded text with mel-spectrogram and get mu_y segment
-        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        mu_p = torch.matmul(attn.squeeze(1), mu_p.transpose(1, 2)).transpose(1, 2)
+        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+        z_p_hat = mu_p + torch.randn_like(mu_p) * torch.exp(logs_p) * 0.667
 
         # Compute loss between aligned encoder outputs and mel-spectrogram
         prior_loss = torch.sum(
-            0.5 * ((y - mu_y.transpose(1, 2)) ** 2 + math.log(2 * math.pi)) * y_mask
+            0.5 * ((z_p - z_p_hat) ** 2 + math.log(2 * math.pi)) * y_mask
         )
-        prior_loss = prior_loss / (torch.sum(y_mask) * mel.shape[-1])
+        prior_loss = prior_loss / (torch.sum(y_mask) * z_p.shape[-2])
 
         return dict(
-            features=mu_y,
+            features=z_p_hat.transpose(1, 2),
             mel_masks=mel_masks,
             loss=dur_loss + prior_loss,
             metrics={
@@ -232,6 +256,16 @@ def forward(
             metrics.update(features["metrics"])
         output_dict["metrics"] = metrics
 
+        if self.training:
+            output_dict["metrics"]["noise_scale"] = float(self.current_mas_noise_scale)
+
+            # Update MAS noise scale
+            self.current_mas_noise_scale -= 2e-6
+            if self.current_mas_noise_scale < 0.0:
+                self.current_mas_noise_scale -= (
+                    self.current_mas_noise_scale
+                )  # clip to 0
+
         # For validation
         output_dict["features"] = features["features"]
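Note: the new neg_cent1..neg_cent4 terms are an algebraic expansion of the diagonal-Gaussian log-likelihood neg_cent[b, i, j] = sum_d log N(z_p[b, d, i]; mu_p[b, d, j], exp(logs_p[b, d, j])), replacing the old fixed-variance log_prior, and the noise injected into neg_cent anneals linearly from 1e-2 to 0 over the first 5,000 training steps (1e-2 / 2e-6 per step). A self-contained check of the expansion against a brute-force log-prob, with illustrative shapes (all names below are local to the sketch):

    import math
    import torch

    b, d, t_s, t_t = 2, 4, 5, 7  # batch, mel channels, text frames, mel frames
    mu_p = torch.randn(b, d, t_s)
    logs_p = 0.1 * torch.randn(b, d, t_s)
    z_p = torch.randn(b, d, t_t)

    # Expansion as in the commit
    s_p_sq_r = torch.exp(-2 * logs_p)  # 1 / sigma^2
    neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)
    neg_cent2 = torch.matmul(-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r)
    neg_cent3 = torch.matmul(z_p.transpose(1, 2), mu_p * s_p_sq_r)
    neg_cent4 = torch.sum(-0.5 * (mu_p**2) * s_p_sq_r, [1], keepdim=True)
    neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4  # [b, t_t, t_s]

    # Brute-force reference via torch.distributions
    ref = torch.zeros(b, t_t, t_s)
    for i in range(t_t):
        for j in range(t_s):
            normal = torch.distributions.Normal(mu_p[:, :, j], torch.exp(logs_p[:, :, j]))
            ref[:, i, j] = normal.log_prob(z_p[:, :, i]).sum(dim=1)

    assert torch.allclose(neg_cent, ref, atol=1e-4)

This decomposition, and the 0.667 sampling temperature used for z_p, resemble VITS-style MAS implementations.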
19 changes: 13 additions & 6 deletions fish_diffusion/utils/viz.py
@@ -57,12 +57,19 @@ def viz_synth_sample(
     wav_reconstruction = vocoder.spec2wav(mel_target, pitch)
     wav_prediction = vocoder.spec2wav(mel_prediction, pitch)
 
-    wav_reconstruction = loudness_norm(
-        wav_reconstruction.cpu().float().numpy(), 44100, block_size=0.1
-    )
-    wav_prediction = loudness_norm(
-        wav_prediction.cpu().float().numpy(), 44100, block_size=0.1
-    )
+    try:
+        wav_reconstruction = loudness_norm(
+            wav_reconstruction.cpu().float().numpy(), 44100, block_size=0.1
+        )
+    except:
+        wav_reconstruction = wav_reconstruction.cpu().float().numpy()
+
+    try:
+        wav_prediction = loudness_norm(
+            wav_prediction.cpu().float().numpy(), 44100, block_size=0.1
+        )
+    except:
+        wav_prediction = wav_prediction.cpu().float().numpy()
 
     wav_reconstruction = torch.from_numpy(wav_reconstruction)
     wav_prediction = torch.from_numpy(wav_prediction)
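Note: the try/except keeps validation logging alive when loudness normalization fails (e.g., on very short or near-silent clips), at the cost of a bare except that also catches KeyboardInterrupt. The same pattern as a narrower helper (a sketch; the helper name is invented, and loudness_norm is assumed to be the function viz.py already imports):

    import numpy as np

    def try_loudness_norm(wav: np.ndarray, sr: int = 44100, block_size: float = 0.1) -> np.ndarray:
        # Fall back to the unnormalized waveform instead of failing the whole
        # validation step; Exception is narrower than a bare except.
        try:
            return loudness_norm(wav, sr, block_size=block_size)
        except Exception:
            return wav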
2 changes: 1 addition & 1 deletion tools/preprocessing/extract_features.py
@@ -214,7 +214,7 @@ def safe_process(args, config, audio_path: Path):
         return aug_count + 1
     except Exception as e:
         logger.error(f"Error processing {audio_path}")
-        logger.exception(e)
+        # logger.exception(e)
 
 
 def parse_args():
