Skip to content

Commit

Permalink
Merge branch 'master' into ci-tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
domkirke authored Dec 18, 2023
2 parents 3df70c9 + e9e00f9 commit 2a342af
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 21 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
include rave/configs/*.gin
include rave/configs/augmentations/*.gin
include requirements.txt
16 changes: 7 additions & 9 deletions rave/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import os
import pathlib
from pathlib import Path

import cached_conv as cc
import gin
import torch

from .version import __version__

gin.add_config_file_search_path(os.path.dirname(__file__))
gin.add_config_file_search_path(
os.path.join(
os.path.dirname(__file__),
'configs',
))
BASE_PATH: Path = Path(__file__).parent

gin.add_config_file_search_path(BASE_PATH)
gin.add_config_file_search_path(BASE_PATH.joinpath('configs'))
gin.add_config_file_search_path(BASE_PATH.joinpath('configs', 'augmentations'))


def __safe_configurable(name):
try:
Expand Down
20 changes: 11 additions & 9 deletions rave/descript_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def WNConv2d(*args, **kwargs):

class MPD(nn.Module):

def __init__(self, period):
def __init__(self, period, n_channels: int = 1):
super().__init__()
self.period = period
self.convs = nn.ModuleList([
WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(n_channels, 32, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
Expand Down Expand Up @@ -68,10 +68,10 @@ def forward(self, x):

class MSD(nn.Module):

def __init__(self, scale: int):
def __init__(self, scale: int, n_channels: int = 1):
super().__init__()
self.convs = nn.ModuleList([
WNConv1d(1, 16, 15, 1, padding=7),
WNConv1d(n_channels, 16, 15, 1, padding=7),
WNConv1d(16, 64, 41, 4, groups=4, padding=20),
WNConv1d(64, 256, 41, 4, groups=16, padding=20),
WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
Expand Down Expand Up @@ -123,6 +123,7 @@ def __init__(
hop_factor: float = 0.25,
sample_rate: int = 44100,
bands: list = BANDS,
n_channels: int = 1
):
super().__init__()

Expand All @@ -136,7 +137,7 @@ def __init__(

ch = 32
convs = lambda: nn.ModuleList([
WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
WNConv2d(2 * n_channels, ch, (3, 9), (1, 1), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
Expand All @@ -160,7 +161,7 @@ def __init__(

def spectrogram(self, x):
x = torch.view_as_real(self.stft(x))
x = rearrange(x, "b 1 f t c -> (b 1) c t f")
x = rearrange(x, "b c f t p -> b (c p) t f")
# Split into bands
x_bands = [x[..., b[0]:b[1]] for b in self.bands]
return x_bands
Expand Down Expand Up @@ -192,13 +193,14 @@ def __init__(
fft_sizes: list = [2048, 1024, 512],
sample_rate: int = 44100,
bands: list = BANDS,
n_channels: int = 1,
):
super().__init__()
discs = []
discs += [MPD(p) for p in periods]
discs += [MSD(r, sample_rate=sample_rate) for r in rates]
discs += [MPD(p, n_channels=n_channels) for p in periods]
discs += [MSD(r, sample_rate=sample_rate, n_channels=n_channels) for r in rates]
discs += [
MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes
MRD(f, sample_rate=sample_rate, bands=bands, n_channels=n_channels) for f in fft_sizes
]
self.discriminators = nn.ModuleList(discs)

Expand Down
4 changes: 2 additions & 2 deletions rave/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def __call__(self, x: torch.Tensor):
gain_factor = np.random.rand(1)[None, None][0] * (self.gain_range[1] - self.gain_range[0]) + self.gain_range[0]
amp_factor = np.power(10, gain_factor / 20)
x_amp = x * amp_factor
if (self.limit) and (x_amp.abs().max() > 1):
x_amp = x_amp / x_amp.abs().max()
if (self.limit) and (np.abs(x_amp).max() > 1):
x_amp = x_amp / np.abs(x_amp).max()
return x
else:
return x
Expand Down
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def main(argv):
)

# create model
model = rave.RAVE()
model = rave.RAVE(n_channels=FLAGS.channels)
if FLAGS.derivative:
model.integrator = rave.dataset.get_derivator_integrator(model.sr)[1]

Expand Down

0 comments on commit 2a342af

Please sign in to comment.