From c32e79dc1397b58720b9b82e6ef9027b903ba7f9 Mon Sep 17 00:00:00 2001
From: danielle-hausler
Date: Wed, 27 Nov 2024 10:14:12 +0200
Subject: [PATCH 1/4] added wav2vec2 model and related conf/utils files.

---
 soundbay/conf/augmentations/wav2vec2.yaml | 20 +++++++++
 soundbay/conf/model/wav2vec2.yaml         |  9 ++++
 soundbay/conf/optim/wav2vec2.yaml         | 10 +++++
 soundbay/conf/preprocessors/wav2vec2.yaml |  2 +
 soundbay/conf/runs/wav2vec2.yaml          |  8 ++++
 soundbay/conf_dict.py                     |  5 ++-
 soundbay/models.py                        | 54 +++++++++++++++++++++++
 soundbay/utils/files_handler.py           | 40 +++++++++++++++++
 soundbay/utils/logging.py                 | 29 ++++++------
 9 files changed, 162 insertions(+), 15 deletions(-)
 create mode 100644 soundbay/conf/augmentations/wav2vec2.yaml
 create mode 100644 soundbay/conf/model/wav2vec2.yaml
 create mode 100644 soundbay/conf/optim/wav2vec2.yaml
 create mode 100644 soundbay/conf/preprocessors/wav2vec2.yaml
 create mode 100644 soundbay/conf/runs/wav2vec2.yaml
 create mode 100644 soundbay/utils/files_handler.py

diff --git a/soundbay/conf/augmentations/wav2vec2.yaml b/soundbay/conf/augmentations/wav2vec2.yaml
new file mode 100644
index 00000000..b2221317
--- /dev/null
+++ b/soundbay/conf/augmentations/wav2vec2.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+_augmentations:
+  time_stretch:
+    _target_: audiomentations.TimeStretch
+    min_rate: 0.9
+    max_rate: 1.1
+    p: 0
+  time_masking:
+    _target_: audiomentations.TimeMask
+    min_band_part: 0.05
+    max_band_part: 0.2
+    p: 0
+  frequency_masking:
+    _target_: audiomentations.BandStopFilter
+
+    min_center_freq: ${data.min_freq}
+    max_center_freq: ${data.max_freq}
+    min_bandwidth_fraction: 0.05
+    max_bandwidth_fraction: 0.2
+    p: 0
\ No newline at end of file
diff --git a/soundbay/conf/model/wav2vec2.yaml b/soundbay/conf/model/wav2vec2.yaml
new file mode 100644
index 00000000..0f749dad
--- /dev/null
+++ b/soundbay/conf/model/wav2vec2.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+model:
+  criterion:
+    _target_: torch.nn.CrossEntropyLoss
+  model:
+    _target_: models.WAV2VEC2
+    num_classes: 2
+    pretrained: True
+    freeze_encoder: False
\ No newline at end of file
diff --git a/soundbay/conf/optim/wav2vec2.yaml b/soundbay/conf/optim/wav2vec2.yaml
new file mode 100644
index 00000000..cbc8068e
--- /dev/null
+++ b/soundbay/conf/optim/wav2vec2.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+optim:
+  epochs: 100
+  optimizer:
+    _target_: torch.optim.Adam
+    lr: 0.001
+  scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    gamma: 0.995
+  freeze_layers_for_finetune: True
\ No newline at end of file
diff --git a/soundbay/conf/preprocessors/wav2vec2.yaml b/soundbay/conf/preprocessors/wav2vec2.yaml
new file mode 100644
index 00000000..ed294921
--- /dev/null
+++ b/soundbay/conf/preprocessors/wav2vec2.yaml
@@ -0,0 +1,2 @@
+# @package _global_
+_preprocessors: []
\ No newline at end of file
diff --git a/soundbay/conf/runs/wav2vec2.yaml b/soundbay/conf/runs/wav2vec2.yaml
new file mode 100644
index 00000000..996d4f28
--- /dev/null
+++ b/soundbay/conf/runs/wav2vec2.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - ../data: defaults
+  - ../augmentations: wav2vec2
+  - ../preprocessors: wav2vec2
+  - ../model: wav2vec2
+  - ../optim: wav2vec2
+  - ../experiment: defaults
\ No newline at end of file
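A quick illustration (not part of the patch): these files are Hydra config groups, and every augmentation above ships with p: 0, i.e. disabled until explicitly overridden. Assuming omegaconf is installed and the snippet is run from the repo root, the new model group can be inspected like this:

    # Sketch: load and inspect the new model config group with OmegaConf (Hydra's config backend).
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("soundbay/conf/model/wav2vec2.yaml")  # the "# @package _global_" directive is a YAML comment here
    print(cfg["model"]["model"]["_target_"])     # -> models.WAV2VEC2
    print(cfg["model"]["model"]["num_classes"])  # -> 2
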
diff --git a/soundbay/conf_dict.py b/soundbay/conf_dict.py
index ee3d3307..07758235 100644
--- a/soundbay/conf_dict.py
+++ b/soundbay/conf_dict.py
@@ -3,7 +3,7 @@
 -------
 These dicts describe the allowed values of the soundbay framework
 '''
-from soundbay.models import ResNet1Channel, GoogleResNet50withPCEN, ChristophCNN, ResNet182D, Squeezenet2D
+from soundbay.models import ResNet1Channel, GoogleResNet50withPCEN, ChristophCNN, ResNet182D, Squeezenet2D, WAV2VEC2
 from soundbay.data import ClassifierDataset, InferenceDataset, NoBackGroundDataset
 import torch
 from audiomentations import PitchShift, BandStopFilter, TimeMask, TimeStretch
@@ -12,7 +12,8 @@
                'models.GoogleResNet50withPCEN': GoogleResNet50withPCEN,
                'models.ResNet182D': ResNet182D,
                'models.Squeezenet2D': Squeezenet2D,
-               'models.ChristophCNN': ChristophCNN}
+               'models.ChristophCNN': ChristophCNN,
+               'models.WAV2VEC2': WAV2VEC2}
 
 datasets_dict = {'soundbay.data.ClassifierDataset': ClassifierDataset,
                  'soundbay.data.NoBackGroundDataset': NoBackGroundDataset,
diff --git a/soundbay/models.py b/soundbay/models.py
index fefb1f9e..a173cfdf 100644
--- a/soundbay/models.py
+++ b/soundbay/models.py
@@ -1,11 +1,17 @@
 import importlib
+from typing import Union
+
 import torch
+import torchaudio
 import torch.nn as nn
 from torch import Tensor
 from torchvision.models.resnet import ResNet, BasicBlock, conv3x3, Bottleneck
 from torchvision.models.vgg import VGG
 from torchvision.models import squeezenet, ResNet18_Weights
 import torchvision.models as models
+from torchaudio.models import wav2vec2_model
+
+from soundbay.utils.files_handler import load_config
 
 
 class ResNet1Channel(ResNet):
@@ -364,3 +370,51 @@ def __init__(self, num_classes=2, pretrained=True):
     def forward(self, x):
         x = x.repeat(1, 3, 1, 1)
         return self.resnet(x)
+
+
+class WAV2VEC2(nn.Module):
+    def __init__(
+            self,
+            num_classes: int = 2,
+            config: Union[str, dict] = torchaudio.pipelines.WAV2VEC2_BASE._params,
+            path: str = f'https://download.pytorch.org/torchaudio/models/{torchaudio.pipelines.WAV2VEC2_BASE._path}',
+            pretrained: bool = True,
+            freeze_encoder: bool = False
+    ):
+        super(WAV2VEC2, self).__init__()
+        if isinstance(config, str):
+            config = load_config(config)
+        config['aux_num_out'] = config.get('aux_num_out', None)
+        embedding_dim = config['encoder_embed_dim']
+
+        self.freeze_encoder = freeze_encoder
+        self.wav2vec = wav2vec2_model(**config)
+        if pretrained:
+            # Load a pre-trained WAV2VEC2
+            self.wav2vec.load_state_dict(torch.hub.load_state_dict_from_url(path))
+        self.fc = nn.Linear(in_features=embedding_dim, out_features=num_classes)
+
+    def forward(self, x):
+        x = self.extract_features(x)
+        return self.fc(x)
+
+    def extract_features(self, x):
+        # This is separated from forward to allow standalone feature extraction.
+        # Sometimes, for a batch of samples, the raw input is [batch, 1, time].
+        if len(x.shape) > 2:
+            x = torch.squeeze(x, dim=1)
+        x = self.wav2vec.extract_features(x)[0]
+        # mean pooling over the layers: [layers, batch, time, features] -> [batch, time, features]
+        x = torch.stack(x, dim=0).mean(dim=0)
+        # mean pooling over the time: [batch, time, features] -> [batch, features]
+        return x.mean(dim=1)
+
+
+    def freeze_layers(self):
+        # To avoid overfitting, the feature extractor is frozen.
+        self.wav2vec.feature_extractor.requires_grad_(False)
+        # It is possible to freeze the encoder as well;
+        # note that extract_features uses the encoder.
+        if self.freeze_encoder:
+            for param in self.wav2vec.encoder.parameters():
+                param.requires_grad = False
diff --git a/soundbay/utils/files_handler.py b/soundbay/utils/files_handler.py
new file mode 100644
index 00000000..22144f50
--- /dev/null
+++ b/soundbay/utils/files_handler.py
@@ -0,0 +1,40 @@
+import json
+import yaml
+import urllib.request
+
+def load_config(filepath: str):
+    assert filepath.endswith(("json", "yaml", "yml")), "Only json and yaml files are supported."
+
+    is_url = filepath.startswith("http")
+    is_json = filepath.endswith("json")
+    is_yaml = filepath.endswith(("yaml", "yml"))
+
+    if is_url and is_json:
+        return load_json_from_url(filepath)
+    elif is_url and is_yaml:
+        return load_yaml_from_url(filepath)
+    elif is_json:
+        return load_json(filepath)
+    elif is_yaml:
+        return load_yaml(filepath)
+    else:
+        raise ValueError("File format not supported.")
+
+
+def load_json_from_url(url: str):
+    with urllib.request.urlopen(url) as response:
+        return json.load(response)
+
+def load_yaml_from_url(url: str):
+    with urllib.request.urlopen(url) as response:
+        return yaml.safe_load(response)
+
+def load_json(filepath: str):
+    with open(filepath, "r") as file:
+        return json.load(file)
+
+def load_yaml(filepath: str):
+    with open(filepath, "r") as file:
+        return yaml.safe_load(file)
+
+
diff --git a/soundbay/utils/logging.py b/soundbay/utils/logging.py
index e51d3260..d9bd710d 100644
--- a/soundbay/utils/logging.py
+++ b/soundbay/utils/logging.py
@@ -172,22 +172,25 @@ def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: to
         artifact_wav = torch.squeeze(raw_wav).detach().cpu().numpy()
         artifact_wav = artifact_wav / np.expand_dims(np.abs(artifact_wav).max(axis=1) + 1e-8, 1) * 0.5  # gain -6dB
         list_of_wavs_objects = [wandb.Audio(data_or_path=wav, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}', sample_rate=sample_rate) for wav, ind, lab, b_t, f_n in zip(artifact_wav,idx, label, meta['begin_time'], meta['org_file'])]
-
-        # Spectrograms batch
-        artifact_spec = torch.squeeze(audio).detach().cpu().numpy()
-        specs = []
-        for artifact_id in range(artifact_spec.shape[0]):
-            ax = plt.subplots(nrows=1, ncols=1)
-            specs.append(librosa.display.specshow(artifact_spec[artifact_id,...], ax=ax[1]))
-        plt.close('all')
-        del ax
-        list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
         log_wavs = {f'First batch {flag} original wavs': list_of_wavs_objects}
-        log_specs = {f'First batch {flag} augmented spectrogram\'s': list_of_specs_objects}
-        # Upload to W&B
+        # Spectrograms batch
+        if audio.dim() >= 4:  # If spectrogram preprocessing was not applied, the dimension is 3.
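Before the follow-up commits below, a minimal usage sketch of the WAV2VEC2 class this patch introduces (illustrative only: pretrained=False skips the checkpoint download, and the 4-sample batch and 1-second clip length are arbitrary):

    import torch
    from soundbay.models import WAV2VEC2

    model = WAV2VEC2(num_classes=2, pretrained=False)  # default config is WAV2VEC2_BASE (embedding dim 768)
    model.freeze_layers()                   # with freeze_encoder=False this freezes only the conv feature extractor
    batch = torch.randn(4, 1, 16000)        # [batch, 1, time] raw waveform, 1 s at 16 kHz
    logits = model(batch)                   # -> [4, 2]
    feats = model.extract_features(batch)   # -> [4, 768] pooled embeddings
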
+            artifact_spec = torch.squeeze(audio).detach().cpu().numpy()
+            specs = []
+            for artifact_id in range(artifact_spec.shape[0]):
+                ax = plt.subplots(nrows=1, ncols=1)
+                specs.append(librosa.display.specshow(artifact_spec[artifact_id,...], ax=ax[1]))
+            plt.close('all')
+            del ax
+            list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
+            log_specs = {f'First batch {flag} augmented spectrograms': list_of_specs_objects}
+            # Upload spectrograms to W&B
+            wandb.log(log_specs, commit=False)
+
+        # Upload WAVs to W&B
         wandb.log(log_wavs, commit=False)
-        wandb.log(log_specs, commit=False)
+
 
     @staticmethod
     def get_metrics_dict(label_list: Union[list, np.ndarray], pred_list: Union[list, np.ndarray],

From 379d7b024f9ab13133cc23f912524ca636eb60a0 Mon Sep 17 00:00:00 2001
From: Danielle Hausler <74451897+danielle-hausler@users.noreply.github.com>
Date: Tue, 3 Dec 2024 21:34:11 +0200
Subject: [PATCH 2/4] Update soundbay/models.py

Co-authored-by: Tomer Nahshon <33577556+Z30G0D@users.noreply.github.com>
---
 soundbay/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/soundbay/models.py b/soundbay/models.py
index a173cfdf..40d34d10 100644
--- a/soundbay/models.py
+++ b/soundbay/models.py
@@ -405,7 +405,7 @@ def extract_features(self, x):
             x = torch.squeeze(x, dim=1)
         x = self.wav2vec.extract_features(x)[0]
         # mean pooling over the layers: [layers, batch, time, features] -> [batch, time, features]
-        x = torch.stack(x, dim=0).mean(dim=0)
+        x = torch.stack(x, dim=0).mean(dim=(0,2))
         # mean pooling over the time: [batch, time, features] -> [batch, features]
         return x.mean(dim=1)
 

From 58816940c2b1c4936f42f0bf76b765caecd2427d Mon Sep 17 00:00:00 2001
From: Danielle Hausler <74451897+danielle-hausler@users.noreply.github.com>
Date: Tue, 3 Dec 2024 21:38:00 +0200
Subject: [PATCH 3/4] Update models.py

---
 soundbay/models.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/soundbay/models.py b/soundbay/models.py
index 40d34d10..32c925c9 100644
--- a/soundbay/models.py
+++ b/soundbay/models.py
@@ -9,7 +9,6 @@
 from torchvision.models.vgg import VGG
 from torchvision.models import squeezenet, ResNet18_Weights
 import torchvision.models as models
-from torchaudio.models import wav2vec2_model
 
 from soundbay.utils.files_handler import load_config
 
@@ -388,7 +387,7 @@ def __init__(
         embedding_dim = config['encoder_embed_dim']
 
         self.freeze_encoder = freeze_encoder
-        self.wav2vec = wav2vec2_model(**config)
+        self.wav2vec = torchaudio.models.wav2vec2_model(**config)
         if pretrained:
             # Load a pre-trained WAV2VEC2
             self.wav2vec.load_state_dict(torch.hub.load_state_dict_from_url(path))
@@ -404,10 +403,9 @@ def extract_features(self, x):
         if len(x.shape) > 2:
             x = torch.squeeze(x, dim=1)
         x = self.wav2vec.extract_features(x)[0]
-        # mean pooling over the layers: [layers, batch, time, features] -> [batch, time, features]
+        # mean pooling over the layers and time: [layers, batch, time, features] -> [batch, features]
         x = torch.stack(x, dim=0).mean(dim=(0,2))
-        # mean pooling over the time: [batch, time, features] -> [batch, features]
-        return x.mean(dim=1)
+        return x
 
 
     def freeze_layers(self):
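The pooling rework across PATCH 2 and PATCH 3 is easiest to follow as a shape walk-through (illustrative sizes: the base model has 12 transformer layers, 768 features, and roughly 49 frames per second of 16 kHz audio):

    import torch

    # One tensor per transformer layer, each [batch, time, features]:
    layer_outputs = [torch.randn(4, 49, 768) for _ in range(12)]
    stacked = torch.stack(layer_outputs, dim=0)  # [12, 4, 49, 768] = [layers, batch, time, features]
    pooled = stacked.mean(dim=(0, 2))            # mean over layers and time -> [4, 768]

PATCH 2 introduced mean(dim=(0,2)) but left the old return x.mean(dim=1) in place, which would have collapsed the feature dimension as well; PATCH 3 fixes this by returning the pooled tensor directly.
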
From 92d4ea24a886f1b16c5fe83c78b15d5b81864b34 Mon Sep 17 00:00:00 2001
From: Danielle Hausler <74451897+danielle-hausler@users.noreply.github.com>
Date: Tue, 3 Dec 2024 22:08:05 +0200
Subject: [PATCH 4/4] Update models.py

fixed indentation
---
 soundbay/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/soundbay/models.py b/soundbay/models.py
index cdd3b70c..4c3c46ca 100644
--- a/soundbay/models.py
+++ b/soundbay/models.py
@@ -416,7 +416,7 @@ def forward(self, x):
         return self.efficientnet(x)
 
 
- class WAV2VEC2(nn.Module):
+class WAV2VEC2(nn.Module):
     def __init__(
             self,
             num_classes: int = 2,
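Finally, a small sketch of the load_config helper introduced in PATCH 1, which dispatches on file extension and URL-versus-path (the remote URL below is a hypothetical placeholder):

    from soundbay.utils.files_handler import load_config

    cfg = load_config("soundbay/conf/preprocessors/wav2vec2.yaml")  # local YAML -> {'_preprocessors': []}
    # Remote JSON configs load the same way, e.g.:
    # cfg = load_config("https://example.com/wav2vec2_config.json")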