Add NAS KWS model, Dynamic Augmentation and Automated Evaluation Notebook #280

Merged: 19 commits, Feb 1, 2024
Changes from all commits
261 changes: 165 additions & 96 deletions datasets/kws20.py

Large diffs are not rendered by default.

226 changes: 97 additions & 129 deletions datasets/msnoise.py

Large diffs are not rendered by default.

136 changes: 136 additions & 0 deletions datasets/signalmixer.py
@@ -0,0 +1,136 @@
#
# Copyright (c) 2018 Intel Corporation
# Portions Copyright (C) 2019-2023 Maxim Integrated Products, Inc.
# Portions Copyright (C) 2023-2024 Analog Devices, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Classes and functions used to create a noisy keyword spotting dataset.
"""
import numpy as np
import torch


class signalmixer:
"""
    Dataset wrapper that mixes a speech (keyword spotting) dataset with a noise
    dataset at a specified SNR level to create a noisy evaluation dataset.

    Args:
        signal_dataset (object): KWS dataset object.
        snr (int): Target SNR level of the mixed dataset, in dB.
        noise_kind (string): Kind of noise applied to the speech dataset;
            'WhiteNoise' generates synthetic white noise, any other value
            draws samples from noise_dataset.
        noise_dataset (object, optional): MSnoise dataset object; required
            unless noise_kind is 'WhiteNoise'.
"""

def __init__(self, signal_dataset, snr, noise_kind, noise_dataset=None):

self.signal_data = signal_dataset.data
self.signal_targets = signal_dataset.targets

if noise_kind != 'WhiteNoise':
self.noise_data = noise_dataset.data
self.noise_targets = noise_dataset.targets

            # load the entire noise test split in one batch via the dataset's __getitem__
self.noise_dataset_float = next(iter(torch.utils.data.DataLoader(
noise_dataset, batch_size=noise_dataset.dataset_len)))[0]

self.noise_rms = noise_dataset.rms

self.snr = snr
self.noise_kind = noise_kind

        # load the entire speech test split in one batch via the dataset's __getitem__
self.test_dataset_float = next(iter(torch.utils.data.DataLoader(
signal_dataset, batch_size=signal_dataset.data.shape[0])))[0]

if noise_kind == 'WhiteNoise':
self.mixed_signal = self.white_noise_mixer()
else:
self.mixed_signal = self.snr_mixer()

def __getitem__(self, index):

inp = self.mixed_signal[index].type(torch.FloatTensor)
target = int(self.signal_targets[index])
return inp, target

def __len__(self):
return len(self.mixed_signal)

def snr_mixer(self):
        '''Creates the mixed signal dataset at the requested SNR using the noise dataset.
        '''
clean = self.test_dataset_float
noise = self.noise_dataset_float

idx = np.random.randint(0, noise.shape[0], clean.shape[0])
noise = noise[idx]
rms_noise = self.noise_rms[idx]

snr = self.snr

rmsclean = torch.sqrt(torch.mean(clean.reshape(
clean.shape[0], -1)**2, 1, keepdims=True)).unsqueeze(1)
scalarclean = 1 / rmsclean
clean = clean * scalarclean

scalarnoise = 1 / rms_noise.reshape(-1, 1, 1)
noise = noise * scalarnoise

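        # Both signals are RMS-normalized above; scaling the clean speech by
        # 10**(snr/20) then sets the clean-to-noise amplitude ratio to the target SNR in dB.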
cleanfactor = 10**(snr/20)
noisyspeech = cleanfactor * clean + noise
noisyspeech = noisyspeech / (torch.tensor(scalarnoise) + cleanfactor * scalarclean)

        # flattened length of one sample: channels * samples per channel, e.g. 128 * 128 = 16384
        speech_shape = noisyspeech[0].shape[0]*noisyspeech[0].shape[1]
max_mixed = torch.max(abs(noisyspeech.reshape(
noisyspeech.shape[0], speech_shape)), 1, keepdims=True).values

noisyspeech = noisyspeech * (1 / max_mixed).unsqueeze(1)
return noisyspeech

    def white_noise_mixer(self):
        '''Creates the mixed signal dataset at the requested SNR using white noise.
        '''
clean = self.test_dataset_float
snr = self.snr

mean = 0
std = 1
noise = np.random.normal(mean, std, clean.shape)
noise = torch.tensor(noise, dtype=torch.float32)

rmsclean = (torch.mean(clean.reshape(
clean.shape[0], -1)**2, 1, keepdims=True)**0.5).unsqueeze(1)
scalarclean = 1 / rmsclean
clean = clean * scalarclean

rmsnoise = (torch.mean(noise.reshape(
noise.shape[0], -1)**2, 1, keepdims=True)**0.5).unsqueeze(1)
scalarnoise = 1 / rmsnoise
noise = noise * scalarnoise

cleanfactor = 10**(snr/20)
noisyspeech = cleanfactor * clean + noise
noisyspeech = noisyspeech / (scalarnoise + cleanfactor * scalarclean)

        # scaling to ~[-1, 1]
        speech_shape = noisyspeech[0].shape[0]*noisyspeech[0].shape[1]
        max_mixed = torch.max(abs(noisyspeech.reshape(
            noisyspeech.shape[0], speech_shape)), 1, keepdims=True).values
noisyspeech = noisyspeech * (1 / max_mixed).unsqueeze(1)

return noisyspeech
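
A minimal smoke-test sketch for the new class with white noise. The _FakeSpeechSet stand-in below is hypothetical and illustrative only; real evaluations would pass the KWS20 test split from datasets/kws20.py, and any noise_kind other than 'WhiteNoise' additionally requires an MSnoise dataset object. The import assumes the repository root is on the Python path:

import torch

from datasets.signalmixer import signalmixer


class _FakeSpeechSet(torch.utils.data.Dataset):
    """Hypothetical stand-in exposing .data/.targets like the KWS test split."""
    def __init__(self, n=8):
        self.data = torch.randn(n, 128, 128)            # (dataset size, channels, samples)
        self.targets = torch.zeros(n, dtype=torch.long)  # dummy keyword labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], int(self.targets[index])


speech_test = _FakeSpeechSet()
# White-noise mixing needs no noise_dataset:
mixed = signalmixer(signal_dataset=speech_test, snr=5, noise_kind='WhiteNoise')
inp, target = mixed[0]    # noisy sample scaled to roughly [-1, 1], and its label
print(inp.shape)          # torch.Size([128, 128])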
107 changes: 107 additions & 0 deletions models/ai85net-kws20-nas.py
@@ -0,0 +1,107 @@
###################################################################################################
#
# Copyright (C) 2023-2024 Analog Devices, Inc. All Rights Reserved.
#
# Analog Devices, Inc. Default Copyright Notice:
# https://www.analog.com/en/about-adi/legal-and-risk-oversight/intellectual-property/copyright-notice.html
#
###################################################################################################
#
# Copyright (C) 2021-2023 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
Keyword spotting network for AI85
"""
from torch import nn

import ai8x


class AI85KWS20NetNAS(nn.Module):
"""
    KWS20 NAS audio network, found via neural architecture search.
    It significantly outperforms the earlier KWS20 networks (v1, v2, v3),
    at the cost of a higher parameter count and slightly increased latency.
"""

# num_classes = n keywords + 1 unknown
def __init__(
self,
num_classes=21,
num_channels=128,
dimensions=(128, 1), # pylint: disable=unused-argument
bias=True,
**kwargs
):
super().__init__()
self.conv1_1 = ai8x.FusedConv1dBNReLU(num_channels, 128, 1, stride=1, padding=0,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv1_2 = ai8x.FusedConv1dBNReLU(128, 64, 3, stride=1, padding=1,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv1_3 = ai8x.FusedConv1dBNReLU(64, 128, 3, stride=1, padding=1,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv2_1 = ai8x.FusedMaxPoolConv1dBNReLU(128, 128, 3, stride=1, padding=1,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv2_2 = ai8x.FusedConv1dBNReLU(128, 64, 1, stride=1, padding=0,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv2_3 = ai8x.FusedConv1dBNReLU(64, 128, 1, stride=1, padding=0,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv3_1 = ai8x.FusedMaxPoolConv1dBNReLU(128, 128, 3, stride=1, padding=1,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv3_2 = ai8x.FusedConv1dBNReLU(128, 64, 5, stride=1, padding=2,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv4_1 = ai8x.FusedMaxPoolConv1dBNReLU(64, 128, 5, stride=1, padding=2,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv4_2 = ai8x.FusedConv1dBNReLU(128, 128, 1, stride=1, padding=0,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv5_1 = ai8x.FusedMaxPoolConv1dBNReLU(128, 128, 5, stride=1, padding=2,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv5_2 = ai8x.FusedConv1dBNReLU(128, 64, 3, stride=1, padding=1,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv6_1 = ai8x.FusedMaxPoolConv1dBNReLU(64, 64, 5, stride=1, padding=2,
bias=bias, batchnorm="NoAffine", **kwargs)
self.conv6_2 = ai8x.FusedConv1dBNReLU(64, 128, 1, stride=1, padding=0,
bias=bias, batchnorm="NoAffine", **kwargs)
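        # With the default 2x max pooling in each of the five FusedMaxPoolConv1d stages,
        # the 128-sample input is reduced to 4 time steps; 128 channels * 4 = 512 fc inputs.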
self.fc = ai8x.Linear(512, num_classes, bias=bias, wide=True, **kwargs)

def forward(self, x): # pylint: disable=arguments-differ
"""Forward prop"""
# Run CNN
x = self.conv1_1(x)
x = self.conv1_2(x)
x = self.conv1_3(x)
x = self.conv2_1(x)
x = self.conv2_2(x)
x = self.conv2_3(x)
x = self.conv3_1(x)
x = self.conv3_2(x)
x = self.conv4_1(x)
x = self.conv4_2(x)
x = self.conv5_1(x)
x = self.conv5_2(x)
x = self.conv6_1(x)
x = self.conv6_2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x


def ai85kws20netnas(pretrained=False, **kwargs):
"""
    Constructs an AI85KWS20NetNAS model.
"""
assert not pretrained
return AI85KWS20NetNAS(**kwargs)


models = [
{
'name': 'ai85kws20netnas',
'min_input': 1,
'dim': 1,
},
]
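
A quick forward-pass sketch for the new model. Loading the hyphenated file via importlib below is purely illustrative (the training and evaluation scripts discover the model through the models list above), the ai8x.set_device() arguments are assumed from the usual ai8x-training setup, and the repository root is assumed to be the working directory:

import importlib.util

import torch

import ai8x

# Configure the ai8x operators for the AI85/MAX78000 target before building the model
# (argument values assumed; match them to the actual training configuration).
ai8x.set_device(85, False, False)  # device id, simulate, round_avg

# The file name contains dashes, so load it by path instead of a normal import.
spec = importlib.util.spec_from_file_location('ai85net_kws20_nas',
                                              'models/ai85net-kws20-nas.py')
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model = mod.ai85kws20netnas(num_classes=21, num_channels=128, bias=True)
x = torch.randn(1, 128, 128)   # (batch, channels, samples), as produced by kws20.py
print(model(x).shape)          # expected: torch.Size([1, 21])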
6 changes: 4 additions & 2 deletions notebooks/Bayer2RGB_Evaluation.ipynb
@@ -22,8 +22,10 @@
"source": [
"###################################################################################################\n",
"#\n",
"# Copyright © 2023 Analog Devices, Inc. All Rights Reserved.\n",
"# This software is proprietary and confidential to Analog Devices, Inc. and its licensors.\n",
"# Copyright (C) 2023-2024 Analog Devices, Inc. All Rights Reserved.\n",
"#\n",
"# Analog Devices, Inc. Default Copyright Notice:\n",
"# https://www.analog.com/en/about-adi/legal-and-risk-oversight/intellectual-property/copyright-notice.html\n",
"#\n",
"###################################################################################################import cv2\n",
"import importlib\n",