Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multiple updates #153

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions padertorch/contrib/je/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from math import ceil
from pathlib import Path
from typing import Callable
from functools import partial

import dataclasses
import numpy as np
Expand Down Expand Up @@ -33,6 +34,7 @@ class AudioReader:
storage_dir: str = None
preemphasis_factor: float = 0.
alignment_keys: list = None
cutoff_length: int = None

def __post_init__(self):
self.norm = None
Expand Down Expand Up @@ -66,10 +68,14 @@ def _load_source(self, filepath, start_sample=0, stop_sample=None):

def load(self, filepath, start_sample=0, stop_sample=None):
x, sr = self._load_source(filepath, start_sample, stop_sample)
# print(start_sample, stop_sample, len(x[0]))
if self.target_sample_rate != sr:
x = samplerate.resample(
x.T, self.target_sample_rate / sr, "sinc_fastest"
).T
if self.cutoff_length is not None:
x = x[..., :int(self.cutoff_length*self.target_sample_rate)]
# print(x.shape)
return x

def _prenormalize(self, audio):
Expand Down Expand Up @@ -168,15 +174,8 @@ def add_start_stop_samples(self, example):
if self.alignment_keys is not None:
for ali_key in self.alignment_keys:
if f'{ali_key}_start_times' in example or f'{ali_key}_stop_times' in example:
assert f'{ali_key}_start_times' in example and f'{ali_key}_stop_times' in example, example.keys()
example[f'{ali_key}_start_samples'] = [
int(self.target_sample_rate*t)
for t in example[f'{ali_key}_start_times']
]
example[f'{ali_key}_stop_samples'] = [
int(self.target_sample_rate*t)
for t in example[f'{ali_key}_stop_times']
]
example[f'{ali_key}_start_samples'] = (np.asanyarray(example[f'{ali_key}_start_times'])*self.target_sample_rate).astype(int)
example[f'{ali_key}_stop_samples'] = (np.asanyarray(example[f'{ali_key}_stop_times'])*self.target_sample_rate).astype(int)
return example

def __call__(self, example):
Expand Down Expand Up @@ -518,7 +517,7 @@ def stack(self, batch):

@dataclasses.dataclass
class ConcatenateArrays:
axis: int
axis: int = -1

def __call__(self, example):
if isinstance(example, dict):
Expand Down Expand Up @@ -569,14 +568,30 @@ class Collate:
[1., 1.],
[1., 1.]]), 'b': ['0', '1']}
"""
leaf_op: callable = StackArrays()
leaf_op: dict = "stack"

def __call__(self, example):
example = nested_op(self.collate, *example, sequence_type=())
if isinstance(self.leaf_op, dict):
example = nested_op(lambda *x: list(x), *example, sequence_type=())
for key, leaf_op in self.leaf_op.items():
if leaf_op is None:
continue
leaf_op = self._get_leaf_op(leaf_op)
example[key] = nested_op(leaf_op, example[key], sequence_type=())
else:
leaf_op = partial(self._collate, leaf_op=self._get_leaf_op(self.leaf_op))
example = nested_op(leaf_op, *example, sequence_type=())
return example

def collate(self, *batch):
def _get_leaf_op(self, leaf_op):
if leaf_op == "stack":
leaf_op = StackArrays()
elif leaf_op == "concat" or leaf_op == "concatenate":
leaf_op = ConcatenateArrays()
return leaf_op

def _collate(self, *batch, leaf_op):
batch = list(batch)
if self.leaf_op is not None:
batch = self.leaf_op(batch)
if leaf_op is not None:
batch = leaf_op(batch)
return batch
153 changes: 119 additions & 34 deletions padertorch/contrib/je/modules/augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from math import ceil
import torch
import torch.nn.functional as F
from padertorch.contrib.je.modules.conv import Pad
Expand Down Expand Up @@ -125,7 +126,7 @@ def forward(self, *tensors, seq_len):
class Superpose(nn.Module):
"""
>>> x = torch.cumsum(torch.ones((3, 4, 5)), 0)
>>> y = torch.eye(3).float()
>>> y = np.array([[1,0,0],[1,1,1],[1,2,2]])
>>> superpose = Superpose(p=1.)
>>> superpose(x, seq_len=[3,4,5], targets=y)
"""
Expand All @@ -134,51 +135,59 @@ def __init__(self, p, scale_fn=None):
self.p = p
self.scale_fn = scale_fn

def forward(self, x, targets, seq_len=None):
def forward(self, x, seq_len=None, labels=None):
if self.training:
B = x.shape[0]
shuffle_idx = np.roll(np.arange(B, dtype=np.int), np.random.choice(B-1)+1)
lambda_ = np.random.binomial(1, self.p, B)
if seq_len is not None:
seq_len = np.maximum(seq_len, (lambda_ > 0.) * np.array(seq_len)[shuffle_idx])
lambda_ = torch.from_numpy(lambda_).float().to(x.device)
assert all(lambda_ >= 0.) and all(lambda_ <= 1.)
lambda_x = lambda_[(...,) + (x.dim() - 1) * (None,)]
lambda_ = torch.tensor(lambda_, dtype=torch.bool, device=x.device)
lambda_x = lambda_[(...,) + (x.dim() - 1) * (None,)].float()
x_shuffled = x[shuffle_idx]
if self.scale_fn is not None:
x = self.scale_fn(x)
x_shuffled = self.scale_fn(x_shuffled)
x = x + lambda_x * x_shuffled
if isinstance(targets, (list, tuple)):
targets = list(targets)
for i in range(len(targets)):
lambda_t = lambda_[(...,) + (targets[i].dim() - 1) * (None,)]
targets[i] = targets[i] + lambda_t * (1 - targets[i]) * targets[i][shuffle_idx]
elif targets is not None:
lambda_t = lambda_[(...,) + (targets.dim() - 1) * (None,)]
targets = targets + lambda_t * (1 - targets) * targets[shuffle_idx]
return x, targets, seq_len
# ToDo: fix for sparse targets
if isinstance(labels, (list, tuple)):
raise NotImplementedError
# targets = list(targets)
# for i in range(len(targets)):
# lambda_t = lambda_[(...,) + (targets[i].dim() - 1) * (None,)]
# targets[i] = targets[i] | (lambda_t & targets[i][shuffle_idx])
elif labels is not None:
raise NotImplementedError
# lambda_t = lambda_[(...,) + (targets.dim() - 1) * (None,)]
# targets = targets | (lambda_t & targets[shuffle_idx])
return x, seq_len, labels


class Mixup(nn.Module):
"""
>>> x = torch.cumsum(torch.ones((3, 4, 5)), 0)
>>> y = torch.eye(3).float()
>>> mixup = Mixup(p=1.)
>>> mixup(x, seq_len=[3,4,5], targets=y)
>>> y_sparse = torch.sparse_coo_tensor([[0,1,2],[0,1,2]],[1,1,1],(3,3)).float()
>>> mixup = Mixup(p=1., target_threshold=0.3)
>>> mixup.roll_targets(y,1)
>>> mixup.roll_targets(y_sparse,1).to_dense()
>>> mixup(x, seq_len=[3,4,5], targets=y_sparse)
"""
def __init__(self, p, alpha=1.):
def __init__(self, p, alpha=2., beta=1., target_threshold=None):
super().__init__()
self.p = p
self.alpha = alpha
self.beta = beta
self.target_threshold = target_threshold

def forward(self, x, targets=None, seq_len=None):
def forward(self, x, seq_len=None, targets=None):
if self.training:
B = x.shape[0]
shuffle_idx = np.roll(np.arange(B, dtype=np.int), np.random.choice(B))
shift = 1+np.random.choice(B-1)
shuffle_idx = np.roll(np.arange(B, dtype=np.int), shift)
lambda_ = np.maximum(
np.random.binomial(1, 1 - self.p, B),
np.random.beta(self.alpha, self.alpha, B)
np.random.beta(self.alpha, self.beta, B),
)
if seq_len is not None:
seq_len = np.maximum(seq_len, (lambda_ < 1.) * np.array(seq_len)[shuffle_idx])
Expand All @@ -189,29 +198,103 @@ def forward(self, x, targets=None, seq_len=None):
if isinstance(targets, (list, tuple)):
targets = list(targets)
for i in range(len(targets)):
targets[i] = targets[i].float()
rolled_targets = self.roll_targets(targets[i], shift)
lambda_t = lambda_[(...,) + (targets[i].dim() - 1) * (None,)]
targets[i] = lambda_t * targets[i] + (1. - lambda_t) * targets[i][shuffle_idx]
targets[i] = lambda_t * targets[i] + (1. - lambda_t) * rolled_targets
if self.target_threshold is not None:
targets[i] = self.threshold_targets(targets[i])
elif targets is not None:
targets = targets.float()
rolled_targets = self.roll_targets(targets, shift)
lambda_t = lambda_[(...,) + (targets.dim() - 1) * (None,)]
targets = lambda_t * targets + (1. - lambda_t) * targets[shuffle_idx]
return x, targets, seq_len
targets = lambda_t * targets + (1. - lambda_t) * rolled_targets
if self.target_threshold is not None:
targets = self.threshold_targets(targets)
return x, seq_len, targets

@staticmethod
def roll_targets(targets, shift):
B = targets.shape[0]
if targets.is_sparse:
targets = targets.coalesce()
targets_indices = targets.indices()
targets_indices[0] = (targets_indices[0]+shift)%B
return torch.sparse_coo_tensor(
indices=targets_indices,
values=targets.values(),
size=targets.shape,
device=targets.device
)
return torch.roll(targets, shift, 0)

def threshold_targets(self, targets):
if targets.is_sparse:
targets = targets.coalesce()
return torch.sparse_coo_tensor(
indices=targets.indices()[..., targets.values() > self.target_threshold],
values=torch.ones_like(targets.values()[targets.values() > self.target_threshold]),
size=targets.shape,
device=targets.device
)
return targets > self.target_threshold


class MixBack(nn.Module):
"""
>>> mixback = MixBack(.5)
>>> x1 = torch.ones((4,6,10))*torch.tensor([[1],[-1],[1],[-1],[1],[-1]])
>>> mixback(x1)
>>> x2 = torch.ones((3,6,9))*torch.tensor([[1],[-1],[1],[-1],[1],[-1]])
>>> mixback(x2)
"""
def __init__(self, max_mixback_scale, buffer_size=1, norm_axes=(-2,-1)):
super().__init__()
self.max_mixback_scale = max_mixback_scale
self.buffer_size = buffer_size
self.norm_axes = norm_axes
self._buffer = []

def reset(self):
self._buffer = []

def forward(self, x_input):
if self.training:
self._buffer = [x_input.detach().cpu()] + self._buffer
if len(self._buffer) > self.buffer_size:
b_in, *_, t_in = x_input.shape
x_mixback = self._buffer[-1][:b_in, ..., :t_in].to(x_input.device)
b_mix, *_, t_mix = x_mixback.shape
if b_mix < b_in or t_mix < t_in:
reps = len(x_mixback.shape) * [1]
reps[0] = ceil(b_in/b_mix)
reps[-1] = ceil(t_in/t_mix)
x_mixback = x_mixback.repeat(*reps)[:b_in, ..., :t_in]
x_mixback = (x_mixback - x_mixback.mean(self.norm_axes, keepdim=True)) / (x_mixback.std(self.norm_axes, keepdim=True)+1e-3)
scale = self.max_mixback_scale * torch.rand((b_in,) + (x_input.dim()-1)*(1,), device=x_input.device)
x_input = x_input + scale * x_mixback
self._buffer = self._buffer[:self.buffer_size]
return x_input


class Mask(nn.Module):
"""
>>> x = torch.ones((3, 4, 5))
>>> x = Mask(axis=-1, max_masked_rate=1., max_masked_steps=10)(x, seq_len=[1,2,3])
>>> x, _ = Mask(axis=-1, max_masked_rate=1., max_masked_steps=10, n_masks=2)(x, seq_len=[1,2,3])
>>> x.shape
"""
def __init__(self, axis, n_masks=1, max_masked_steps=None, max_masked_rate=1.):
def __init__(self, axis, n_masks=1, max_masked_steps=None, max_masked_rate=1., min_masked_steps=0, min_masked_rate=0.):
super().__init__()
self.axis = axis
self.n_masks = n_masks
self.max_masked_values = max_masked_steps
self.max_masked_rate = max_masked_rate
self.min_masked_values = min_masked_steps
self.min_masked_rate = min_masked_rate

def __call__(self, x, seq_len=None):
def __call__(self, x, seq_len=None, rng=None):
if not self.training:
return x
return x, torch.ones_like(x)
mask = torch.ones_like(x)
idx = torch.arange(x.shape[self.axis]).float()
axis = self.axis
Expand All @@ -223,19 +306,21 @@ def __call__(self, x, seq_len=None):
seq_len = x.shape[axis] * torch.ones(x.shape[0])
else:
seq_len = torch.Tensor(seq_len)
max_width = self.max_masked_rate/self.n_masks * seq_len
max_width = self.max_masked_rate * seq_len
if self.max_masked_values is not None:
max_width = torch.min(self.max_masked_values*torch.ones_like(max_width)/self.n_masks, max_width)
max_width = torch.min(self.max_masked_values*torch.ones_like(max_width), max_width)
max_width = torch.floor(max_width)
min_width = torch.min(self.min_masked_rate * seq_len, self.min_masked_values * torch.ones_like(max_width))
min_width = torch.floor(min_width)
width = min_width + torch.rand(x.shape[0], generator=rng) * (max_width - min_width + 1)
width = torch.floor(width/self.n_masks)
for i in range(self.n_masks):
width = torch.floor(torch.rand(x.shape[0]) * (max_width + 1))
max_onset = seq_len - width
onset = torch.floor(torch.rand(x.shape[0]) * (max_onset + 1))
width = width[(...,) + (x.dim()-1)*(None,)]
onset = torch.floor(torch.rand(x.shape[0], generator=rng) * (max_onset + 1))
onset = onset[(...,) + (x.dim()-1)*(None,)]
offset = onset + width
offset = onset + width[(...,) + (x.dim()-1)*(None,)]
mask = mask * ((idx < onset) + (idx >= offset)).float().to(x.device)
return x * mask
return x * mask, mask


class AdditiveNoise(nn.Module):
Expand Down
Loading