Implement dataset and update tests #5

Merged
merged 1 commit into from
Feb 17, 2024
Binary file added amt/assets/mel_filters.npz
Binary file not shown.
178 changes: 178 additions & 0 deletions amt/audio.py
@@ -0,0 +1,178 @@
"""Contains code taken from https://github.com/openai/whisper"""

import os
import torch
import numpy as np
import torch.nn.functional as F

from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

from amt.config import load_config

# hard-coded audio hyperparameters
config = load_config()["audio"]
SAMPLE_RATE = config["sample_rate"]
N_FFT = config["n_fft"]
HOP_LENGTH = config["hop_len"]
CHUNK_LENGTH = config["chunk_len"]
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary

Parameters
----------
file: str
The audio file to open

sr: int
The sample rate to resample the audio if necessary

Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""

# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]

# chat-gpt says that this will work for reading mp3 ?? not tested
# cmd = [
# "ffmpeg",
# "-nostdin",
# "-threads", "0",
# "-i", file,
# "-ac", "1",
# "-ar", str(sr),
# "-"
# ]

# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(
array, [pad for sizes in pad_widths[::-1] for pad in sizes]
)
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)

return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:

np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

filters_path = os.path.join(
os.path.dirname(__file__), "assets", "mel_filters.npz"
)
with np.load(filters_path, allow_pickle=False) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of

Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

n_mels: int
The number of Mel-frequency filters, only 80 is supported

padding: int
Number of zero samples to pad to the right

device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT

Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(
audio, N_FFT, HOP_LENGTH, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2

filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

return log_spec
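
An illustrative sketch (not part of the diff) of how the helpers above fit together; the file path is a placeholder:

from amt.audio import load_audio, log_mel_spectrogram, pad_or_trim, N_FRAMES

audio = load_audio("example.wav")        # mono float32 waveform at 16 kHz
mel = log_mel_spectrogram(audio)         # torch.Tensor of shape (80, n_frames)
mel = pad_or_trim(mel, length=N_FRAMES)  # (80, 3000): one 30-second chunk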
140 changes: 140 additions & 0 deletions amt/data.py
@@ -0,0 +1,140 @@
import mmap
import os
import json
import jsonlines
import torch

from typing import Callable
from multiprocessing import Pool

from aria.data.midi import MidiDict
from amt.tokenizer import AmtTokenizer
from amt.config import load_config
from amt.audio import (
    log_mel_spectrogram,
    pad_or_trim,
    N_FRAMES,
)

config = load_config()["data"]
STRIDE_FACTOR = config["stride_factor"]


def get_features(audio_path: str, mid_path: str):
    """Returns a list of matched (log-mel spectrogram, tokenized sequence)
    pairs, or None on failure.
    """
    tokenizer = AmtTokenizer()

    if not os.path.isfile(audio_path) or not os.path.isfile(mid_path):
        return None

    try:
        midi_dict = MidiDict.from_midi(mid_path)
        log_spec = log_mel_spectrogram(audio=audio_path)
    except Exception as e:
        print(f"Failed to convert files into features: {e}")
        return None

    _, total_frames = log_spec.shape
    res = []
    for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
        audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
        mid_feature = tokenizer._tokenize_midi_dict(
            midi_dict=midi_dict,
            start_ms=start_frame * 10,
            end_ms=(start_frame + N_FRAMES) * 10,
        )
        res.append((audio_feature, mid_feature))

    return res


def get_features_mp(args):
    """Multiprocessing wrapper for get_features"""
    res = get_features(*args)
    if res is None:
        return False, None
    else:
        return True, res


class AmtDataset(torch.utils.data.Dataset):
    def __init__(self, load_path: str):
        self.tokenizer = AmtTokenizer(return_tensors=True)
        self.aug_fn = self.tokenizer.export_msg_mixup()
        self.file_buff = open(load_path, mode="r")
        self.file_mmap = mmap.mmap(
            self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
        )
        self.index = self._build_index()

    def close(self):
        if self.file_buff:
            self.file_buff.close()
        if self.file_mmap:
            self.file_mmap.close()

    def __del__(self):
        self.close()

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx: int):
        def _format(tok):
            # This is required because json formats tuples into lists
            if isinstance(tok, list):
                return tuple(tok)
            return tok

        self.file_mmap.seek(self.index[idx])

        # FIXME: this isn't going to load properly
        spec, seq = json.loads(self.file_mmap.readline())  # Load data from line

        spec = torch.tensor(spec)  # Format spectrogram into tensor
        seq = [_format(tok) for tok in seq]  # Format seq
        seq = self.aug_fn(seq)  # Data augmentation

        src = seq
        tgt = seq[1:] + [self.tokenizer.pad_tok]

        return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt)

    def _build_index(self):
        self.file_mmap.seek(0)
        index = []
        while True:
            pos = self.file_mmap.tell()
            line_buffer = self.file_mmap.readline()
            if line_buffer == b"":
                break
            else:
                index.append(pos)

        return index

    @classmethod
    def build(
        cls,
        matched_load_paths: list[tuple[str, str]],
        save_path: str,
        audio_aug_hook: Callable | None = None,
    ):
        def _get_features(_matched_load_paths: list):
            with Pool(4) as pool:
                results = pool.imap(get_features_mp, _matched_load_paths)
                num_paths = len(_matched_load_paths)
                for idx, (success, res) in enumerate(results):
                    if idx % 50 == 0 and idx != 0:
                        print(f"Processed audio-mid pairs: {idx}/{num_paths}")

                    if not success:
                        continue
                    for _audio_feature, _mid_feature in res:
                        yield _audio_feature.tolist(), _mid_feature

        with jsonlines.open(save_path, mode="w") as writer:
            for audio_feature, mid_feature in _get_features(matched_load_paths):
                writer.write([audio_feature, mid_feature])
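
A minimal sketch (not part of the diff) of building and reading the dataset; the paths are placeholders and assume matched audio/MIDI pairs exist on disk:

from amt.data import AmtDataset

# Serialize features to a JSON-lines file, then index it for random access.
# On platforms that spawn worker processes, run this under
# `if __name__ == "__main__":`.
AmtDataset.build(
    matched_load_paths=[("example.wav", "example.mid")],
    save_path="features.jsonl",
)
dataset = AmtDataset("features.jsonl")
spec, src, tgt = dataset[0]  # spectrogram, encoded source seq, encoded target seq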
18 changes: 13 additions & 5 deletions amt/tokenizer.py
@@ -115,17 +115,24 @@ def _tokenize_midi_dict(
             if note_end_ms <= start_ms or note_start_ms >= end_ms:  # Skip
                 continue
             elif (
-                note_start_ms <= start_ms and _pitch not in prev_notes
+                note_start_ms < start_ms and _pitch not in prev_notes
             ):  # Add to prev notes
                 prev_notes.append(_pitch)
                 if note_end_ms < end_ms:
                     on_off_notes.append(
                         ("off", _pitch, rel_note_end_ms_q, None)
                     )
             else:  # Add to on_off_msgs
-                # Skip notes with no duration
+                # Skip notes with no duration or duplicate notes
                 if rel_note_start_ms_q == rel_note_end_ms_q:
                     continue
+                elif (
+                    "on",
+                    _pitch,
+                    rel_note_start_ms_q,
+                    velocity_q,
+                ) in on_off_notes:
+                    continue

                 on_off_notes.append(
                     ("on", _pitch, rel_note_start_ms_q, velocity_q)
@@ -190,7 +197,7 @@ def _detokenize_midi_dict(self, tokenized_seq: list, len_ms: int):
         for tok_1, tok_2, tok_3 in zip(
             tokenized_seq[:],
             tokenized_seq[1:],
-            tokenized_seq[2:],
+            tokenized_seq[2:] + [(None, None)],
         ):
             tok_1_type, tok_1_data = tok_1
             tok_2_type, tok_2_data = tok_2
@@ -210,7 +217,7 @@ def _detokenize_midi_dict(self, tokenized_seq: list, len_ms: int):
                 # Process note and add to note msgs
                 note_to_close = notes_to_close.pop(tok_1_data, None)
                 if note_to_close is None:
-                    print("No 'on' token corresponding to 'off' token")
+                    print(f"No 'on' token corresponding to 'off' token")
                     continue
                 else:
                     _pitch = tok_1_data
@@ -267,6 +274,7 @@ def export_msg_mixup(self):
         def msg_mixup(src: list):
             # Reorder prev tokens
             res = []
+            idx = 0
             for idx, tok in enumerate(src):
                 tok_type, tok_data = tok
                 if tok_type != "prev":
@@ -279,7 +287,7 @@ def msg_mixup(src: list):
             for tok_1, tok_2, tok_3 in zip(
                 src[idx:],
                 src[idx + 1 :],
-                src[idx + 2 :],
+                src[idx + 2 :] + [(None, None)],
             ):
                 tok_1_type, tok_1_data = tok_1
                 tok_2_type, tok_2_data = tok_2
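
A small illustration (with made-up tokens, not part of the diff) of why the (None, None) sentinel is appended in the two zips above: it extends the iteration by one step, so the second-to-last token is still visited as tok_1:

seq = [("on", 60), ("onset", 0), ("off", 60)]
print(len(list(zip(seq, seq[1:], seq[2:]))))                   # 1
print(len(list(zip(seq, seq[1:], seq[2:] + [(None, None)]))))  # 2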
11 changes: 10 additions & 1 deletion config/config.json
@@ -8,5 +8,14 @@
       "num_steps": 3000,
       "step": 10
     }
-  }
+  },
+  "audio": {
+    "sample_rate": 16000,
+    "n_fft": 400,
+    "hop_len": 160,
+    "chunk_len": 30
+  },
+  "data": {
+    "stride_factor": 1
+  }
 }
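
The new keys are consumed via load_config at import time in amt/audio.py and amt/data.py; a quick sanity-check sketch (assuming load_config returns this JSON as a dict, as the usage above suggests):

from amt.config import load_config

audio_cfg = load_config()["audio"]
print(audio_cfg["sample_rate"] // audio_cfg["hop_len"])  # 100 frames/s, i.e. 10 ms per frame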