added wav2vec2 model and related conf/utils files #107

Merged 5 commits on Dec 3, 2024

Changes from 1 commit
20 changes: 20 additions & 0 deletions soundbay/conf/augmentations/wav2vec2.yaml
@@ -0,0 +1,20 @@
# @package _global_
_augmentations:
  time_stretch:
    _target_: audiomentations.TimeStretch
    min_rate: 0.9
    max_rate: 1.1
    p: 0
  time_masking:
    _target_: audiomentations.TimeMask
    min_band_part: 0.05
    max_band_part: 0.2
    p: 0
  frequency_masking:
    _target_: audiomentations.BandStopFilter
    min_center_freq: ${data.min_freq}
    max_center_freq: ${data.max_freq}
    min_bandwidth_fraction: 0.05
    max_bandwidth_fraction: 0.2
    p: 0
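Note that all three transforms are configured with `p: 0`, so they are effectively disabled until the probability is overridden. As a minimal sketch (not soundbay's actual loading code), Hydra resolves each `_target_` entry into an audiomentations transform roughly like this:

```python
import numpy as np
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Illustrative standalone config mirroring one entry above;
# p is raised from 0 here so the transform actually fires.
cfg = OmegaConf.create("""
_augmentations:
  time_stretch:
    _target_: audiomentations.TimeStretch
    min_rate: 0.9
    max_rate: 1.1
    p: 0.5
""")

augs = [instantiate(a) for a in cfg["_augmentations"].values()]
samples = np.random.randn(16000).astype(np.float32)  # 1 s of fake mono audio
for aug in augs:
    samples = aug(samples=samples, sample_rate=16000)
```

The same `_target_` pattern drives the model and optim configs below.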
9 changes: 9 additions & 0 deletions soundbay/conf/model/wav2vec2.yaml
@@ -0,0 +1,9 @@
# @package _global_
model:
  criterion:
    _target_: torch.nn.CrossEntropyLoss
  model:
    _target_: models.WAV2VEC2
    num_classes: 2
    pretrained: True
    freeze_encoder: False
10 changes: 10 additions & 0 deletions soundbay/conf/optim/wav2vec2.yaml
@@ -0,0 +1,10 @@
# @package _global_
optim:
  epochs: 100
  optimizer:
    _target_: torch.optim.Adam
    lr: 0.001
  scheduler:
    _target_: torch.optim.lr_scheduler.ExponentialLR
    gamma: 0.995
  freeze_layers_for_finetune: True
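Presumably `freeze_layers_for_finetune: True` is what triggers `WAV2VEC2.freeze_layers()` (added in models.py below). A hedged sketch of how these settings could fit together in a training loop; filtering to trainable parameters is an assumption, since passing all parameters works equally well once their gradients are disabled:

```python
import torch
from soundbay.models import WAV2VEC2

model = WAV2VEC2(num_classes=2, pretrained=False)  # skip the weight download
model.freeze_layers()  # freezes the conv feature extractor

# lr, gamma, and epochs mirror the config above.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)

for epoch in range(100):
    ...  # one training epoch
    scheduler.step()
```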
2 changes: 2 additions & 0 deletions soundbay/conf/preprocessors/wav2vec2.yaml
@@ -0,0 +1,2 @@
# @package _global_
_preprocessors: []
Collaborator:
If there are no preprocessors, why is this file needed instead of just using the defaults?

Collaborator (Author):
wav2vec2 works on the raw input (the waveform), whereas the defaults file preprocesses the data into a spectrogram.
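In other words, the spectrogram models consume a `[batch, 1, freq, time]` image, while WAV2VEC2 takes the `[batch, 1, time]` waveform and squeezes the channel dimension itself (shapes inferred from the model code below; sizes are illustrative):

```python
import torch

spectrogram_batch = torch.randn(4, 1, 128, 256)  # what the default preprocessors produce
waveform_batch = torch.randn(4, 1, 16000)        # what WAV2VEC2 expects

# Inside WAV2VEC2.extract_features the channel dim is squeezed away:
x = torch.squeeze(waveform_batch, dim=1)  # -> [batch, time]
```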

8 changes: 8 additions & 0 deletions soundbay/conf/runs/wav2vec2.yaml
@@ -0,0 +1,8 @@
# @package _global_
defaults:
  - ../data: defaults
  - ../augmentations: wav2vec2
  - ../preprocessors: wav2vec2
  - ../model: wav2vec2
  - ../optim: wav2vec2
  - ../experiment: defaults
5 changes: 3 additions & 2 deletions soundbay/conf_dict.py
@@ -3,7 +3,7 @@
-------
These dicts describe the allowed values of the soundbay framework
'''
-from soundbay.models import ResNet1Channel, GoogleResNet50withPCEN, ChristophCNN, ResNet182D, Squeezenet2D
+from soundbay.models import ResNet1Channel, GoogleResNet50withPCEN, ChristophCNN, ResNet182D, Squeezenet2D, WAV2VEC2
from soundbay.data import ClassifierDataset, InferenceDataset, NoBackGroundDataset
import torch
from audiomentations import PitchShift, BandStopFilter, TimeMask, TimeStretch
@@ -12,7 +12,8 @@
                'models.GoogleResNet50withPCEN': GoogleResNet50withPCEN,
                'models.ResNet182D': ResNet182D,
                'models.Squeezenet2D': Squeezenet2D,
-               'models.ChristophCNN': ChristophCNN}
+               'models.ChristophCNN': ChristophCNN,
+               'models.WAV2VEC2': WAV2VEC2}

 datasets_dict = {'soundbay.data.ClassifierDataset': ClassifierDataset,
                  'soundbay.data.NoBackGroundDataset': NoBackGroundDataset,
54 changes: 54 additions & 0 deletions soundbay/models.py
@@ -1,11 +1,17 @@
import importlib
from typing import Union

import torch
import torchaudio
import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import ResNet, BasicBlock, conv3x3, Bottleneck
from torchvision.models.vgg import VGG
from torchvision.models import squeezenet, ResNet18_Weights
import torchvision.models as models
from torchaudio.models import wav2vec2_model

from soundbay.utils.files_handler import load_config


class ResNet1Channel(ResNet):
@@ -364,3 +370,51 @@ def __init__(self, num_classes=2, pretrained=True):
    def forward(self, x):
        x = x.repeat(1, 3, 1, 1)
        return self.resnet(x)


class WAV2VEC2(nn.Module):
    def __init__(
        self,
        num_classes: int = 2,
        config: Union[str, dict] = torchaudio.pipelines.WAV2VEC2_BASE._params,
        path: str = f'https://download.pytorch.org/torchaudio/models/{torchaudio.pipelines.WAV2VEC2_BASE._path}',
        pretrained: bool = True,
        freeze_encoder: bool = False
    ):
        super(WAV2VEC2, self).__init__()
        if isinstance(config, str):
            config = load_config(config)
        config['aux_num_out'] = config.get('aux_num_out', None)
        embedding_dim = config['encoder_embed_dim']

        self.freeze_encoder = freeze_encoder
        self.wav2vec = wav2vec2_model(**config)
        if pretrained:
            # Load a pre-trained WAV2VEC2
            self.wav2vec.load_state_dict(torch.hub.load_state_dict_from_url(path))
        self.fc = nn.Linear(in_features=embedding_dim, out_features=num_classes)

    def forward(self, x):
        x = self.extract_features(x)
        return self.fc(x)

    def extract_features(self, x):
        # this is separated from forward to allow feature extraction.
        # sometimes for a batch of samples our raw input is [batch, 1, time]
        if len(x.shape) > 2:
            x = torch.squeeze(x, dim=1)
        x = self.wav2vec.extract_features(x)[0]
        # mean pooling over the layers: [layers, batch, time, features] -> [batch, time, features]
        x = torch.stack(x, dim=0).mean(dim=0)
        # mean pooling over the time: [batch, time, features] -> [batch, features]
        return x.mean(dim=1)

    def freeze_layers(self):
        # to avoid overfitting the feature extractor is frozen
        self.wav2vec.feature_extractor.requires_grad_(False)
        # it is possible to freeze the encoder as well
        # note that extract_features is using the encoder
        if self.freeze_encoder:
            for param in self.wav2vec.encoder.parameters():
                param.requires_grad = False
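A quick usage sketch for the new class, using random weights to skip the download; output shapes follow the pooling comments in `extract_features` (768 is the base pipeline's `encoder_embed_dim`):

```python
import torch
from soundbay.models import WAV2VEC2

model = WAV2VEC2(num_classes=2, pretrained=False).eval()
batch = torch.randn(4, 1, 16000)  # [batch, 1, time], 1 s at 16 kHz

with torch.no_grad():
    features = model.extract_features(batch)  # [4, 768]
    logits = model(batch)                     # [4, 2]
```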
40 changes: 40 additions & 0 deletions soundbay/utils/files_handler.py
@@ -0,0 +1,40 @@
import json
import yaml
import urllib.request  # urllib alone does not expose urlopen


def load_config(filepath: str):
    assert filepath.endswith(("json", "yaml", "yml")), "Only json and yaml files are supported."

    is_url = filepath.startswith("http")
    is_json = filepath.endswith("json")
    is_yaml = filepath.endswith(("yaml", "yml"))

    if is_url and is_json:
        return load_json_from_url(filepath)
    elif is_url and is_yaml:
        return load_yaml_from_url(filepath)
    elif is_json:
        return load_json(filepath)
    elif is_yaml:
        return load_yaml(filepath)
    else:
        raise ValueError("File format not supported.")


def load_json_from_url(url: str):
    with urllib.request.urlopen(url) as response:
        return json.load(response)

def load_yaml_from_url(url: str):
    with urllib.request.urlopen(url) as response:
        return yaml.safe_load(response)

def load_json(filepath: str):
    with open(filepath, "r") as file:
        return json.load(file)

def load_yaml(filepath: str):
    with open(filepath, "r") as file:
        return yaml.safe_load(file)


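A round-trip usage sketch for the new helper (the temporary file stands in for a real config path or URL):

```python
import json
import tempfile
from soundbay.utils.files_handler import load_config

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"encoder_embed_dim": 768, "aux_num_out": None}, f)

cfg = load_config(f.name)  # local JSON branch
assert cfg["encoder_embed_dim"] == 768
# Remote configs work too, e.g. load_config("https://.../wav2vec2_config.json")
```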
29 changes: 16 additions & 13 deletions soundbay/utils/logging.py
@@ -172,22 +172,25 @@ def upload_artifacts(self, audio: torch.Tensor, label: torch.Tensor, raw_wav: to
         artifact_wav = torch.squeeze(raw_wav).detach().cpu().numpy()
         artifact_wav = artifact_wav / np.expand_dims(np.abs(artifact_wav).max(axis=1) + 1e-8, 1) * 0.5  # gain -6dB
         list_of_wavs_objects = [wandb.Audio(data_or_path=wav, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}', sample_rate=sample_rate) for wav, ind, lab, b_t, f_n in zip(artifact_wav,idx, label, meta['begin_time'], meta['org_file'])]

-        # Spectrograms batch
-        artifact_spec = torch.squeeze(audio).detach().cpu().numpy()
-        specs = []
-        for artifact_id in range(artifact_spec.shape[0]):
-            ax = plt.subplots(nrows=1, ncols=1)
-            specs.append(librosa.display.specshow(artifact_spec[artifact_id,...], ax=ax[1]))
-            plt.close('all')
-            del ax
-        list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
         log_wavs = {f'First batch {flag} original wavs': list_of_wavs_objects}
-        log_specs = {f'First batch {flag} augmented spectrogram\'s': list_of_specs_objects}

-        # Upload to W&B
+        # Spectrograms batch
+        if audio.dim() >= 4:  # In case that spectrogram preprocessing was not applied the dimension is 3.
+            artifact_spec = torch.squeeze(audio).detach().cpu().numpy()
+            specs = []
+            for artifact_id in range(artifact_spec.shape[0]):
+                ax = plt.subplots(nrows=1, ncols=1)
+                specs.append(librosa.display.specshow(artifact_spec[artifact_id,...], ax=ax[1]))
+                plt.close('all')
+                del ax
+            list_of_specs_objects = [wandb.Image(data_or_path=spec, caption=f'{flag}_label{lab}_i{ind}_{round(b_t/data_sample_rate,2)}sec_{f_n}') for spec, ind, lab, b_t, f_n in zip(specs,idx, label, meta['begin_time'], meta['org_file'])]
+            log_specs = {f'First batch {flag} augmented spectrogram\'s': list_of_specs_objects}
+            # Upload spectrograms to W&B
+            wandb.log(log_specs, commit=False)
+
+        # Upload WAVs to W&B
         wandb.log(log_wavs, commit=False)
-        wandb.log(log_specs, commit=False)


     @staticmethod
     def get_metrics_dict(label_list: Union[list, np.ndarray], pred_list: Union[list, np.ndarray],
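The new `audio.dim()` guard separates the two input layouts: a small sketch with illustrative sizes, assuming spectrogram batches arrive as `[batch, 1, freq, time]` and raw waveforms as `[batch, 1, time]`:

```python
import torch

spectrogram_batch = torch.randn(8, 1, 128, 256)  # dim() == 4 -> images are logged
waveform_batch = torch.randn(8, 1, 16000)        # dim() == 3 -> spectrogram logging skipped

for audio in (spectrogram_batch, waveform_batch):
    if audio.dim() >= 4:
        print("log spectrograms", tuple(audio.shape))
    else:
        print("skip spectrograms for raw waveforms", tuple(audio.shape))
```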