Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes after merge. operational different-song training #17

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Prev Previous commit
Next Next commit
adding feature and quality loss module
  • Loading branch information
csteinmetz1 committed Jan 29, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 186841b5cef3b45ea1437fbc61230d3a526d7f70
115 changes: 90 additions & 25 deletions mst/loss.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from mst.fx_encoder import FXencoder

from mst.modules import SpectrogramEncoder
from mst.quality_system import QualityEstimationSystem


def compute_mid_side(x: torch.Tensor):
@@ -414,7 +415,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):


class FX_encoder_loss(torch.nn.Module):
def __init__(self, distance: Callable = torch.nn.functional.mse_loss, audiofeatures = True, weights: list[float]= [1.0],):
def __init__(
self,
distance: Callable = torch.nn.functional.mse_loss,
audiofeatures=True,
weights: list[float] = [1.0],
):
super().__init__()
self.distance = distance
config_path = "/homes/ssv02/Diff-MST/configs/models/fx_encoder_mst.yaml"
@@ -423,42 +429,42 @@ def __init__(self, distance: Callable = torch.nn.functional.mse_loss, audiofeatu
self.config = yaml.safe_load(f)
checkpoint_path = "/homes/ssv02/Diff-MST/data/FXencoder_ps.pt"
self.ddp = True
#self.embed_distance = torch.nn.CosineEmbeddingLoss(reduction = 'mean')
# self.embed_distance = torch.nn.CosineEmbeddingLoss(reduction = 'mean')
self.embed_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

# load model
self.model = FXencoder(self.config["Effects_Encoder"]['default'])

# load model
self.model = FXencoder(self.config["Effects_Encoder"]["default"])

# load checkpoint
checkpoint = torch.load(checkpoint_path)

from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in checkpoint["model"].items():
# remove `module.` if the model was trained with DDP
name = k[7:] if self.ddp else k
new_state_dict[name] = v

# load params
self.model.load_state_dict(new_state_dict)
self.model.eval()

# freeze all parameters in model
for param in self.model.parameters():
param.requires_grad = False

def compute_fx_embeds(x: torch.Tensor):
embed = self.model(x)
return embed
#weights = [0.1,0.001,1.0,1.0,0.1,100.0]

# weights = [0.1,0.001,1.0,1.0,0.1,100.0]
self.weights = weights
self.transforms = []

if audiofeatures:
self.audiofeatures = audiofeatures
self.audiofeatures = audiofeatures

self.transforms = [
compute_rms,
compute_crest_factor,
@@ -468,8 +474,7 @@ def compute_fx_embeds(x: torch.Tensor):
]

self.transforms.append(compute_fx_embeds)



assert len(self.transforms) == len(self.weights)

def forward(self, input: torch.Tensor, target: torch.Tensor):
@@ -482,28 +487,89 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
# loss = self.distance(input_embed, target_embed)

# return loss

for transform, weight in zip(self.transforms, self.weights):
transform_name = "_".join(transform.__name__.split("_")[1:])
#print(transform_name)
# print(transform_name)
input_transform = transform(input)
target_transform = transform(target)
if transform_name == "fx_embeds":
val = 1-self.embed_similarity(input_transform, target_transform).mean().clamp(min=1e-8)
#print(val)
val = 1 - self.embed_similarity(
input_transform, target_transform
).mean().clamp(min=1e-8)
# print(val)
else:
val = torch.nn.functional.mse_loss(input_transform, target_transform)
#print(val)
# print(val)
losses[transform_name] = weight * val

return losses



class QualityLoss(torch.nn.Module):
def __init__(
self,
ckpt_path: str,
) -> None:
super().__init__()
# hard-coded model configuration
encoder = SpectrogramEncoder(
embed_dim=512,
n_inputs=1,
input_batchnorm=False,
encoder_batchnorm=False,
l2_norm=True,
)

self.model = QualityEstimationSystem.load_from_checkpoint(
ckpt_path, encoder=encoder
)
self.model.eval()
self.model.freeze()

def forward(self, input: torch.Tensor, *args, **kwargs):
"""Compute loss on stereo mixes using featues from quality model.

Args:
input: (bs, 2, seq_len)
"""
logits = self.model(input) # higher is better (high quality)
return -logits


class FeatureAndQualityLoss(torch.nn.Module):
def __init__(
self,
weights: List[float],
sample_rate: int,
quality_ckpt_path: str,
quality_weight: float = 1.0,
stem_separation: bool = False,
use_clap: bool = False,
):
super().__init__()
self.feature_loss = AudioFeatureLoss(
weights=weights,
sample_rate=sample_rate,
stem_separation=stem_separation,
use_clap=use_clap,
)
self.quality_loss = QualityLoss(quality_ckpt_path)
self.quality_weight = quality_weight

def forward(self, input: torch.Tensor, target: torch.Tensor):
feature_losses = self.feature_loss(input, target)
quality_loss = self.quality_loss(input)
feature_losses["quality"] = quality_loss * self.quality_weight
return feature_losses


# if __name__ == "__main__":
# import torchaudio
# path = "/import/c4dm-datasets-ext/mtg-jamendo_wav/02/1012002.wav"

# #input1, sr = torchaudio.load(path, channels_first = True, num_frames = 44100*10)

# input1= torch.zeros(2,44100*10)
# input2 = input1
# #input2, sr = torchaudio.load(path, channels_first = True, num_frames = 44100*10, frame_offset = 44100*10)
@@ -516,4 +582,3 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
# losses = loss(input1, input2)
# print(losses)
# print(sum(losses.values()))