diff --git a/amt/assets/mel_filters.npz b/amt/assets/mel_filters.npz
new file mode 100644
index 0000000..28ea269
Binary files /dev/null and b/amt/assets/mel_filters.npz differ
diff --git a/amt/audio.py b/amt/audio.py
new file mode 100644
index 0000000..f9bc27e
--- /dev/null
+++ b/amt/audio.py
@@ -0,0 +1,178 @@
+"""Contains code taken from https://github.com/openai/whisper"""
+
+import os
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+from functools import lru_cache
+from subprocess import CalledProcessError, run
+from typing import Optional, Union
+
+from amt.config import load_config
+
+# hard-coded audio hyperparameters
+config = load_config()["audio"]
+SAMPLE_RATE = config["sample_rate"]
+N_FFT = config["n_fft"]
+HOP_LENGTH = config["hop_len"]
+CHUNK_LENGTH = config["chunk_len"]
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 frames in a mel spectrogram input
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolution has stride 2
+FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 10ms per audio frame
+TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token
+
+
+def load_audio(file: str, sr: int = SAMPLE_RATE):
+    """
+    Open an audio file and read it as a mono waveform, resampling as necessary.
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+
+    sr: int
+        The sample rate to resample the audio to, if necessary
+
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+
+    # This launches a subprocess to decode audio while down-mixing
+    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+    # Note that ffmpeg also decodes compressed formats such as mp3 with
+    # this same command, so no special-casing is needed for mp3 input.
+    # fmt: off
+    cmd = [
+        "ffmpeg",
+        "-nostdin",
+        "-threads", "0",
+        "-i", file,
+        "-f", "s16le",
+        "-ac", "1",
+        "-acodec", "pcm_s16le",
+        "-ar", str(sr),
+        "-"
+    ]
+    # fmt: on
+    try:
+        out = run(cmd, capture_output=True, check=True).stdout
+    except CalledProcessError as e:
+        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+    """
+    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+    """
+    if torch.is_tensor(array):
+        if array.shape[axis] > length:
+            array = array.index_select(
+                dim=axis, index=torch.arange(length, device=array.device)
+            )
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = F.pad(
+                array, [pad for sizes in pad_widths[::-1] for pad in sizes]
+            )
+    else:
+        if array.shape[axis] > length:
+            array = array.take(indices=range(length), axis=axis)
+
+        if array.shape[axis] < length:
+            pad_widths = [(0, 0)] * array.ndim
+            pad_widths[axis] = (0, length - array.shape[axis])
+            array = np.pad(array, pad_widths)
+
+    return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int) -> torch.Tensor:
+    """
+    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+    This allows decoupling the librosa dependency; the filters were saved using:
+
+        np.savez_compressed(
+            "mel_filters.npz",
+            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
+        )
+    """
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+    filters_path = os.path.join(
+        os.path.dirname(__file__), "assets", "mel_filters.npz"
+    )
+    with np.load(filters_path, allow_pickle=False) as f:
+        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+def log_mel_spectrogram(
+    audio: Union[str, np.ndarray, torch.Tensor],
+    n_mels: int = 80,
+    padding: int = 0,
+    device: Optional[Union[str, torch.device]] = None,
+):
+    """
+    Compute the log-Mel spectrogram of the given audio.
+
+    Parameters
+    ----------
+    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+        The path to an audio file, or a NumPy array or Tensor containing
+        the audio waveform at 16 kHz
+
+    n_mels: int
+        The number of Mel-frequency filters; only 80 and 128 are supported
+
+    padding: int
+        Number of zero samples to pad to the right
+
+    device: Optional[Union[str, torch.device]]
+        If given, the audio tensor is moved to this device before STFT
+
+    Returns
+    -------
+    torch.Tensor, shape = (n_mels, n_frames)
+        A Tensor that contains the Mel spectrogram
+    """
+    if not torch.is_tensor(audio):
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio = torch.from_numpy(audio)
+
+    if device is not None:
+        audio = audio.to(device)
+    if padding > 0:
+        audio = F.pad(audio, (0, padding))
+    window = torch.hann_window(N_FFT).to(audio.device)
+    stft = torch.stft(
+        audio, N_FFT, HOP_LENGTH, window=window, return_complex=True
+    )
+    magnitudes = stft[..., :-1].abs() ** 2
+
+    filters = mel_filters(audio.device, n_mels)
+    mel_spec = filters @ magnitudes
+
+    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+    log_spec = (log_spec + 4.0) / 4.0
+
+    return log_spec
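A minimal usage sketch of the helpers above (not part of the diff; the file
name "example.wav" is hypothetical), showing how a fixed-size log-mel chunk is
produced:

    from amt.audio import (
        N_FRAMES,
        N_SAMPLES,
        load_audio,
        log_mel_spectrogram,
        pad_or_trim,
    )

    audio = load_audio("example.wav")      # float32 mono waveform at 16 kHz
    audio = pad_or_trim(audio, N_SAMPLES)  # exactly 480000 samples (30 s)
    spec = log_mel_spectrogram(audio)      # torch.Tensor
    assert spec.shape == (80, N_FRAMES)    # 80 mel bins x 3000 frames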
diff --git a/amt/data.py b/amt/data.py
index e69de29..648fdcd 100644
--- a/amt/data.py
+++ b/amt/data.py
@@ -0,0 +1,140 @@
+import mmap
+import os
+import json
+import jsonlines
+import torch
+
+from typing import Callable
+from multiprocessing import Pool
+
+from aria.data.midi import MidiDict
+from amt.tokenizer import AmtTokenizer
+from amt.config import load_config
+from amt.audio import (
+    log_mel_spectrogram,
+    pad_or_trim,
+    N_FRAMES,
+)
+
+config = load_config()["data"]
+STRIDE_FACTOR = config["stride_factor"]
+
+
+def get_features(audio_path: str, mid_path: str):
+    """Returns a list of tuples of matched log-mel spectrograms and
+    tokenized sequences (torch.Tensor, list).
+    """
+    tokenizer = AmtTokenizer()
+
+    if not os.path.isfile(audio_path) or not os.path.isfile(mid_path):
+        return None
+
+    try:
+        midi_dict = MidiDict.from_midi(mid_path)
+        log_spec = log_mel_spectrogram(audio=audio_path)
+    except Exception as e:
+        print(f"Failed to convert files into features: {e}")
+        return None
+
+    _, total_frames = log_spec.shape
+    res = []
+    for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR):
+        audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES)
+        mid_feature = tokenizer._tokenize_midi_dict(
+            midi_dict=midi_dict,
+            start_ms=start_frame * 10,
+            end_ms=(start_frame + N_FRAMES) * 10,
+        )
+        res.append((audio_feature, mid_feature))
+
+    return res
+
+
+def get_features_mp(args):
+    """Multiprocessing wrapper for get_features"""
+    res = get_features(*args)
+    if res is None:
+        return False, None
+    else:
+        return True, res
+
+
+class AmtDataset(torch.utils.data.Dataset):
+    def __init__(self, load_path: str):
+        self.tokenizer = AmtTokenizer(return_tensors=True)
+        self.aug_fn = self.tokenizer.export_msg_mixup()
+        self.file_buff = open(load_path, mode="r")
+        self.file_mmap = mmap.mmap(
+            self.file_buff.fileno(), 0, access=mmap.ACCESS_READ
+        )
+        self.index = self._build_index()
+
+    def close(self):
+        if self.file_buff:
+            self.file_buff.close()
+        if self.file_mmap:
+            self.file_mmap.close()
+
+    def __del__(self):
+        self.close()
+
+    def __len__(self):
+        return len(self.index)
+
+    def __getitem__(self, idx: int):
+        def _format(tok):
+            # This is required because JSON serializes tuples into lists
+            if isinstance(tok, list):
+                return tuple(tok)
+            return tok
+
+        self.file_mmap.seek(self.index[idx])
+
+        # TODO: this isn't going to load properly
+        spec, seq = json.loads(self.file_mmap.readline())  # Load data from line
+
+        spec = torch.tensor(spec)  # Format spectrogram into tensor
+        seq = [_format(tok) for tok in seq]  # Format seq
+        seq = self.aug_fn(seq)  # Data augmentation
+
+        src = seq
+        tgt = seq[1:] + [self.tokenizer.pad_tok]
+
+        return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt)
+
+    def _build_index(self):
+        self.file_mmap.seek(0)
+        index = []
+        while True:
+            pos = self.file_mmap.tell()
+            line_buffer = self.file_mmap.readline()
+            if line_buffer == b"":
+                break
+            else:
+                index.append(pos)
+
+        return index
+
+    @classmethod
+    def build(
+        cls,
+        matched_load_paths: list[tuple[str, str]],
+        save_path: str,
+        audio_aug_hook: Callable | None = None,
+    ):
+        def _get_features(_matched_load_paths: list):
+            with Pool(4) as pool:
+                results = pool.imap(get_features_mp, _matched_load_paths)
+                num_paths = len(_matched_load_paths)
+                for idx, (success, res) in enumerate(results):
+                    if idx % 50 == 0 and idx != 0:
+                        print(f"Processed audio-mid pairs: {idx}/{num_paths}")
+
+                    if not success:
+                        continue
+                    for _audio_feature, _mid_feature in res:
+                        yield _audio_feature.tolist(), _mid_feature
+
+        with jsonlines.open(save_path, mode="w") as writer:
+            for audio_feature, mid_feature in _get_features(matched_load_paths):
+                writer.write([audio_feature, mid_feature])
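To make the windowing in get_features concrete, here is a worked sketch (not
part of the diff; the 90-second recording length is hypothetical). With
stride_factor = 1 the 30-second windows are disjoint; stride_factor = 2 would
overlap them by half. Frame indices convert to milliseconds at 10 ms per frame:

    N_FRAMES = 3000      # one 30 s window at 100 frames per second
    total_frames = 9000  # e.g. a 90-second recording

    for stride_factor in (1, 2):
        step = N_FRAMES // stride_factor
        starts_ms = [frame * 10 for frame in range(0, total_frames, step)]
        print(stride_factor, starts_ms)
    # 1 -> [0, 30000, 60000]
    # 2 -> [0, 15000, 30000, 45000, 60000, 75000]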
+ """ + tokenizer = AmtTokenizer() + + if not os.path.isfile(audio_path) or not os.path.isfile(mid_path): + return None + + try: + midi_dict = MidiDict.from_midi(mid_path) + log_spec = log_mel_spectrogram(audio=audio_path) + except Exception as e: + print("Failed to convert files into features") + return None + + _, total_frames = log_spec.shape + res = [] + for start_frame in range(0, total_frames, N_FRAMES // STRIDE_FACTOR): + audio_feature = pad_or_trim(log_spec[:, start_frame:], length=N_FRAMES) + mid_feature = tokenizer._tokenize_midi_dict( + midi_dict=midi_dict, + start_ms=start_frame * 10, + end_ms=(start_frame + N_FRAMES) * 10, + ) + res.append((audio_feature, mid_feature)) + + return res + + +def get_features_mp(args): + """Multiprocessing wrapper for get_features""" + res = get_features(*args) + if res is None: + return False, None + else: + return True, res + + +class AmtDataset(torch.utils.data.Dataset): + def __init__(self, load_path: str): + self.tokenizer = AmtTokenizer(return_tensors=True) + self.aug_fn = self.tokenizer.export_msg_mixup() + self.file_buff = open(load_path, mode="r") + self.file_mmap = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + self.index = self._build_index() + + def close(self): + if self.file_buff: + self.file_buff.close() + if self.file_mmap: + self.file_mmap.close() + + def __del__(self): + self.close() + + def __len__(self): + return len(self.index) + + def __getitem__(self, idx: int): + def _format(tok): + # This is required because json formats tuples into lists + if isinstance(tok, list): + return tuple(tok) + return tok + + self.file_mmap.seek(self.index[idx]) + + # This isn't going to load properly + spec, seq = json.loads(self.file_mmap.readline()) # Load data from line + + spec = torch.tensor(spec) # Format spectrogram into tensor + seq = [_format(tok) for tok in seq] # Format seq + seq = self.aug_fn(seq) # Data augmentation + + src = seq + tgt = seq[1:] + [self.tokenizer.pad_tok] + + return spec, self.tokenizer.encode(src), self.tokenizer.encode(tgt) + + def _build_index(self): + self.file_mmap.seek(0) + index = [] + while True: + pos = self.file_mmap.tell() + line_buffer = self.file_mmap.readline() + if line_buffer == b"": + break + else: + index.append(pos) + + return index + + @classmethod + def build( + cls, + matched_load_paths: list[tuple[str, str]], + save_path: str, + audio_aug_hook: Callable | None = None, + ): + def _get_features(_matched_load_paths: list): + with Pool(4) as pool: + results = pool.imap(get_features_mp, _matched_load_paths) + num_paths = len(_matched_load_paths) + for idx, (success, res) in enumerate(results): + if idx % 50 == 0 and idx != 0: + print(f"Processed audio-mid pairs: {idx}/{num_paths}") + + if success == False: + continue + for _audio_feature, _mid_feature in res: + yield _audio_feature.tolist(), _mid_feature + + with jsonlines.open(save_path, mode="w") as writer: + for audio_feature, mid_feature in _get_features(matched_load_paths): + writer.write([audio_feature, mid_feature]) diff --git a/amt/tokenizer.py b/amt/tokenizer.py index dcc8671..1af661c 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -115,7 +115,7 @@ def _tokenize_midi_dict( if note_end_ms <= start_ms or note_start_ms >= end_ms: # Skip continue elif ( - note_start_ms <= start_ms and _pitch not in prev_notes + note_start_ms < start_ms and _pitch not in prev_notes ): # Add to prev notes prev_notes.append(_pitch) if note_end_ms < end_ms: @@ -123,9 +123,16 @@ def _tokenize_midi_dict( ("off", _pitch, 
diff --git a/config/config.json b/config/config.json
index 25742d7..eec37f6 100644
--- a/config/config.json
+++ b/config/config.json
@@ -8,5 +8,14 @@
             "num_steps": 3000,
             "step": 10
         }
+    },
+    "audio": {
+        "sample_rate": 16000,
+        "n_fft": 400,
+        "hop_len": 160,
+        "chunk_len": 30
+    },
+    "data": {
+        "stride_factor": 1
     }
-}
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9fcef77..571a4d0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,4 @@ torchaudio
 accelerate
 mido
 tqdm
-jsonlines
\ No newline at end of file
+jsonlines
diff --git a/tests/test_data.py b/tests/test_data.py
new file mode 100644
index 0000000..cdfa39f
--- /dev/null
+++ b/tests/test_data.py
@@ -0,0 +1,45 @@
+import unittest
+import logging
+import os
+
+from amt.data import get_features, AmtDataset
+from amt.tokenizer import AmtTokenizer
+from aria.data.midi import MidiDict
+
+
+logging.basicConfig(level=logging.INFO)
+if not os.path.isdir("tests/test_results"):
+    os.mkdir("tests/test_results")
+
+
+# TODO: test this properly; turning the mel spectrogram back into audio is still problematic
+class TestDataGen(unittest.TestCase):
+    def test_feature_gen(self):
+        for log_spec, seq in get_features(
+            audio_path="tests/test_data/147.wav",
+            mid_path="tests/test_data/147.mid",
+        ):
+            print(log_spec.shape, len(seq))
+
+
+class TestAmtDataset(unittest.TestCase):
+    def test_build(self):
+        matched_paths = [("tests/test_data/147.wav", "tests/test_data/147.mid")]
+        AmtDataset.build(
+            matched_load_paths=matched_paths,
+            save_path="tests/test_results/dataset.jsonl",
+        )
+
+        dataset = AmtDataset("tests/test_results/dataset.jsonl")
+        tokenizer = AmtTokenizer()
+        for idx, (spec, src, tgt) in enumerate(dataset):
+            print(spec.shape, src.shape, tgt.shape)
+            decoded = tokenizer.decode(src)
+            mid = tokenizer._detokenize_midi_dict(
+                decoded, len_ms=30000
+            ).to_midi()
+            mid.save(f"tests/test_results/trunc_{idx}.mid")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_data/147.mid b/tests/test_data/147.mid
new file mode 100644
index 0000000..8cecba3
Binary files /dev/null and b/tests/test_data/147.mid differ
diff --git a/tests/test_data/147.mp3 b/tests/test_data/147.mp3
new file mode 100644
index 0000000..4fd9489
Binary files /dev/null and b/tests/test_data/147.mp3 differ
diff --git a/tests/test_data/147.wav b/tests/test_data/147.wav
new file mode 100644
index 0000000..fd566a0
Binary files /dev/null and b/tests/test_data/147.wav differ
diff --git a/tests/test_data/beethoven_sonata.mid b/tests/test_data/beethoven_sonata.mid
deleted file mode 100644
index 923a4d5..0000000
Binary files a/tests/test_data/beethoven_sonata.mid and /dev/null differ
diff --git a/tests/test_data/expressive.mid b/tests/test_data/expressive.mid
deleted file mode 100644
index 40d9e84..0000000
Binary files a/tests/test_data/expressive.mid and /dev/null differ
diff --git a/tests/test_data/pop.mid b/tests/test_data/pop.mid
deleted file mode 100644
index be83c69..0000000
Binary files a/tests/test_data/pop.mid and /dev/null differ
diff --git a/tests/test_data/pop_copy.mid b/tests/test_data/pop_copy.mid
deleted file mode 100644
index be83c69..0000000
Binary files a/tests/test_data/pop_copy.mid and /dev/null differ
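A usage sketch tying the new pieces together (not part of the diff; the paths,
batch size, and DataLoader wrapping are illustrative assumptions rather than
something the diff itself does):

    from torch.utils.data import DataLoader

    from amt.data import AmtDataset

    if __name__ == "__main__":  # multiprocessing guard; build() spawns a worker Pool
        AmtDataset.build(
            matched_load_paths=[("song.wav", "song.mid")],
            save_path="train.jsonl",
        )

        dataset = AmtDataset("train.jsonl")
        loader = DataLoader(dataset, batch_size=4, shuffle=True)
        for spec, src, tgt in loader:
            ...  # spec: (4, 80, 3000); src/tgt: batched token-id tensors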
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index 02b6bb1..e46e263 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -5,6 +5,7 @@ from amt.tokenizer import AmtTokenizer
 from aria.data.midi import MidiDict
 
+logging.basicConfig(level=logging.INFO)
 if os.path.isdir("tests/test_results") is False:
     os.mkdir("tests/test_results")
 
@@ -34,40 +35,64 @@ def _tokenize_detokenize(mid_name: str):
             for msg in _midi_dict.note_msgs:
                 logging.info(msg)
 
-        # _tokenize_detokenize(mid_name="arabesque.mid")
-        # _tokenize_detokenize(mid_name="bach.mid")
-        # _tokenize_detokenize(mid_name="beethoven_moonlight.mid")
+        _tokenize_detokenize(mid_name="basic.mid")
+        _tokenize_detokenize(mid_name="147.mid")
+        _tokenize_detokenize(mid_name="beethoven_moonlight.mid")
 
     def test_aug(self):
-        START = 5000
-        END = 15000
+        def aug(_midi_dict: MidiDict, _start_ms: int, _end_ms: int):
+            _tokenized_seq = tokenizer._tokenize_midi_dict(
+                midi_dict=_midi_dict,
+                start_ms=_start_ms,
+                end_ms=_end_ms,
+            )
+
+            aug_fn = tokenizer.export_msg_mixup()
+            _aug_tokenized_seq = aug_fn(_tokenized_seq)
+            self.assertEqual(len(_tokenized_seq), len(_aug_tokenized_seq))
+            return _tokenized_seq, _aug_tokenized_seq
+
+        DELTA_MS = 5000
         tokenizer = AmtTokenizer()
         midi_dict = MidiDict.from_midi("tests/test_data/bach.mid")
-        tokenized_seq = tokenizer._tokenize_midi_dict(
-            midi_dict=midi_dict,
-            start_ms=START,
-            end_ms=END,
-        )
+        __end_ms = midi_dict.note_msgs[-1]["data"]["end"]
+
+        for idx, __start_ms in enumerate(range(0, __end_ms, DELTA_MS)):
+            tokenized_seq, aug_tokenized_seq = aug(
+                midi_dict, __start_ms, __start_ms + DELTA_MS
+            )
+
+            self.assertEqual(
+                len(
+                    tokenizer._detokenize_midi_dict(
+                        tokenized_seq, DELTA_MS
+                    ).note_msgs
+                ),
+                len(
+                    tokenizer._detokenize_midi_dict(
+                        aug_tokenized_seq, DELTA_MS
+                    ).note_msgs
+                ),
+            )
 
-        aug_fn = tokenizer.export_msg_mixup()
-        aug_tokenized_seq = aug_fn(tokenized_seq)
-        logging.info(f"msg mixup: {tokenized_seq} \n -> {aug_tokenized_seq}")
+            if idx == 0:
+                logging.info(
+                    f"msg mixup: {tokenized_seq} ->\n{aug_tokenized_seq}"
+                )
 
-        _midi_dict = tokenizer._detokenize_midi_dict(tokenized_seq, END - START)
-        _mid = _midi_dict.to_midi()
-        _mid.save(f"tests/test_results/bach_orig.mid")
+            _midi_dict = tokenizer._detokenize_midi_dict(
+                tokenized_seq, DELTA_MS
+            )
+            _mid = _midi_dict.to_midi()
+            _mid.save(f"tests/test_results/bach_orig.mid")
 
-        _midi_dict = tokenizer._detokenize_midi_dict(
-            aug_tokenized_seq, END - START
-        )
-        _mid = _midi_dict.to_midi()
-        _mid.save(f"tests/test_results/bach_aug.mid")
+            _midi_dict = tokenizer._detokenize_midi_dict(
+                aug_tokenized_seq, DELTA_MS
+            )
+            _mid = _midi_dict.to_midi()
+            _mid.save(f"tests/test_results/bach_aug.mid")
 
 
 if __name__ == "__main__":
-    if os.path.isdir("tests/test_results") is False:
-        os.mkdir("tests/test_results")
-
-    logging.basicConfig(level=logging.INFO)
     unittest.main()
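The test suites above can be run from the repository root with the standard
unittest runner:

    python -m unittest discover -s tests -v

test_aug now checks two invariants of msg_mixup over every 5-second window of
the input MIDI: the augmented sequence has the same token count as the
original, and both detokenize to the same number of note messages.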