Skip to content

Commit

Permalink
redid maestro tests to mock all data, removed midi and wav files from…
Browse files Browse the repository at this point in the history
… repo. Also removed metadata csv file as it wasn't being used in the code.
  • Loading branch information
bgenchel-avail committed Jul 26, 2024
1 parent 62634f9 commit f2f9a39
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 1,297 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ test = [
"coverage>=5.0.2",
"pytest>=6.1.1",
"pytest-mock",
"wave",
"mido"
]
tf = [
"tensorflow>=2.4.1,<2.15.1; platform_system != 'Darwin'",
Expand Down
74 changes: 60 additions & 14 deletions tests/data/test_maestro.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import numpy as np
import os
import pathlib
from typing import List
import wave

from mido import MidiFile, MidiTrack, Message
from typing import List

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline

Expand All @@ -37,8 +39,7 @@
TRAIN_TRACK_ID = "2004/MIDI-Unprocessed_SMF_05_R1_2004_01_ORIG_MID--AUDIO_05_R1_2004_03_Track03_wav"
VALID_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav"
TEST_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_08_Track08_wav"

MOCK_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"


def create_mock_wav(output_fpath: str, duration_min: int) -> None:
Expand All @@ -61,31 +62,75 @@ def create_mock_wav(output_fpath: str, duration_min: int) -> None:
logging.info(f"Mock {duration_min}-minute WAV file '{output_fpath}' created successfully.")


def create_mock_midi(output_fpath: str) -> None:
    """Write a small single-track MIDI file to ``output_fpath``.

    The file holds two notes (MIDI note numbers 60 and 62), each emitted as a
    ``note_on`` message followed by a ``note_off`` message, all with velocity 64.

    Args:
        output_fpath: Destination path passed to ``MidiFile.save``.
    """
    # Create a new MIDI file with one track
    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)

    # Define a sequence of notes (time, type, note, velocity)
    notes = [
        (0, "note_on", 60, 64),  # C4
        (500, "note_off", 60, 64),
        (0, "note_on", 62, 64),  # D4
        (500, "note_off", 62, 64),
    ]

    # Add the notes to the track. The loop names avoid shadowing the builtin
    # ``type`` (and the conventional ``time`` module name).
    for delta_ticks, msg_type, note, velocity in notes:
        track.append(Message(msg_type, note=note, velocity=velocity, time=delta_ticks))

    # Save the MIDI file
    mid.save(output_fpath)

    logging.info(f"Mock MIDI file '{output_fpath}' created successfully.")

def test_maestro_to_tf_example(tmpdir: str) -> None:
    """Run ``MaestroToTfExample`` on mocked audio/MIDI and check one tfrecord is written.

    A mock 3-minute WAV and a mock MIDI file are generated under
    ``<tmpdir>/maestro/2004`` for ``TRAIN_TRACK_ID``; the pipeline output is
    written to a separate ``<tmpdir>/outputs`` directory so the generated
    inputs do not affect the file-count assertion.

    Args:
        tmpdir: Temporary directory for this test (presumably pytest's
            ``tmpdir`` fixture; it is only used via ``pathlib.Path(tmpdir)``).
    """
    mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
    mock_maestro_ext = mock_maestro_home / "2004"
    mock_maestro_ext.mkdir(parents=True, exist_ok=True)

    # Track IDs look like "<year>/<file stem>"; the stem names the mock files.
    track_stem = TRAIN_TRACK_ID.split("/")[1]
    create_mock_wav(str(mock_maestro_ext / (track_stem + ".wav")), 3)
    create_mock_midi(str(mock_maestro_ext / (track_stem + ".midi")))

    output_dir = pathlib.Path(tmpdir) / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)

    input_data: List[str] = [TRAIN_TRACK_ID]
    with TestPipeline() as p:
        (
            p
            | "Create PCollection of track IDs" >> beam.Create([input_data])
            | "Create tf.Example" >> beam.ParDo(MaestroToTfExample(str(mock_maestro_home), download=False))
            | "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(str(output_dir)))
        )

    # List the output directory once so all assertions look at the same snapshot.
    written = os.listdir(str(output_dir))
    assert len(written) == 1
    assert os.path.splitext(written[0])[-1] == ".tfrecord"
    with open(os.path.join(str(output_dir), written[0]), "rb") as fp:
        data = fp.read()
        assert len(data) != 0


def test_maestro_invalid_tracks(tmpdir: str) -> None:
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_wav(str(mock_maestro_ext / (VALID_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_wav(str(mock_maestro_ext / (TEST_TRACK_ID.split("/")[1] + ".wav")), 3)

input_data = [(TRAIN_TRACK_ID, "train"), (VALID_TRACK_ID, "validation"), (TEST_TRACK_ID, "test")]
split_labels = set([e[1] for e in input_data])
with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(MAESTRO_TEST_DATA_PATH))).with_outputs(*split_labels)
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(mock_maestro_home))).with_outputs(*split_labels)
)

for split in split_labels:
Expand All @@ -106,16 +151,19 @@ def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None:
not to store a large file in git, hence the variable name.
"""

mock_fpath = MAESTRO_TEST_DATA_PATH / "2004" / (MOCK_15M_TRACK_ID.split("/")[1] + ".wav")
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
mock_fpath = mock_maestro_ext / (GT_15M_TRACK_ID.split("/")[1] + ".wav")
create_mock_wav(str(mock_fpath), 16)

input_data = [(MOCK_15M_TRACK_ID, "train")]
input_data = [(GT_15M_TRACK_ID, "train")]
split_labels = set([e[1] for e in input_data])
with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(MAESTRO_TEST_DATA_PATH))).with_outputs(*split_labels)
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(mock_maestro_home))).with_outputs(*split_labels)
)

for split in split_labels:
Expand All @@ -129,8 +177,6 @@ def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None:
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
assert fp.read().strip() == ""

os.remove(mock_fpath)


def test_maestro_create_input_data() -> None:
data = create_input_data(str(MAESTRO_TEST_DATA_PATH))
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit f2f9a39

Please sign in to comment.