Skip to content

Commit

Permalink
Migrated to Pathlib.
Browse files Browse the repository at this point in the history
  • Loading branch information
hendriks73 committed Oct 16, 2024
1 parent 50e0b7d commit 82d9ef8
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Changes
0.0.8:
- Moved to TensorFlow 2.17.0 and Python 3.9/3.10/3.11.
- Made local cache version dependent.
- Migrated code to Pathlib.

0.0.7:
- Added DOIs to bibtex entries.
Expand Down
38 changes: 15 additions & 23 deletions tempocnn/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import sys
from os import listdir, makedirs
from os.path import isfile, join, basename, exists, splitext, dirname
from pathlib import Path

import jams
import librosa
Expand Down Expand Up @@ -192,27 +191,24 @@ def _write_tempo_result(
:param create_jam: create JAM or not
"""

file_dir = dirname(input_file)
file_name = basename(input_file)
if output_dir is not None:
file_dir = output_dir
file_dir = Path(input_file).parent if output_dir is None else Path(output_dir)

# determine output_file name
output_file = None
if create_jam:
base, file_extension = splitext(file_name)
output_file = join(file_dir, base + ".jams")
output_file = file_dir / Path(input_file).with_suffix(".jams").name
elif append_extension is not None:
output_file = join(file_dir, file_name + append_extension)
output_file = file_dir / Path(f"{input_file}{append_extension}").name
elif replace_extension is not None:
base, file_extension = splitext(file_name)
output_file = join(file_dir, base + replace_extension)
if not replace_extension.startswith("."):
replace_extension = f".{replace_extension}"
output_file = file_dir / Path(input_file).with_suffix(replace_extension).name
elif output_list is not None and index < len(output_list):
output_file = output_list[index]
output_file = Path(output_list[index])

# actually writing the output
if create_jam:
result.save(output_file)
result.save(str(output_file))
elif output_file is None:
print("\n" + result)
else:
Expand All @@ -225,7 +221,7 @@ def _create_tempo_jam(input_file, model, s1, t1, t2):
y, sr = librosa.load(input_file)
track_duration = librosa.get_duration(y=y, sr=sr)
result.file_metadata.duration = track_duration
result.file_metadata.identifiers = {"file": basename(input_file)}
result.file_metadata.identifiers = {"file": Path(input_file).name}
tempo_a = jams.Annotation(namespace="tempo", time=0, duration=track_duration)
tempo_a.annotation_metadata = jams.AnnotationMetadata(
version=package_version,
Expand Down Expand Up @@ -536,9 +532,9 @@ def greekfolk():
# parse arguments
args = parser.parse_args()

if not exists(args.output):
if not Path(args.output).exists():
print("Creating output dir: " + args.output)
makedirs(args.output)
Path(args.output).mkdir(parents=True, exist_ok=True)

# load models
print("Loading models...")
Expand All @@ -547,22 +543,18 @@ def greekfolk():

print("Processing file(s)...")

wav_files = [
join(args.input, f)
for f in listdir(args.input)
if f.endswith(".wav") and isfile(join(args.input, f))
]
wav_files = [f for f in Path(args.input).glob("*.wav") if f.is_file()]
if len(wav_files) == 0:
print("No .wav files found in " + args.input)
for input_file in wav_files:
print("Analyzing: " + input_file)
print(f"Analyzing: {input_file}")
meter_features = read_features(input_file, frames=512, hop_length=256)
meter_result = str(meter_classifier.estimate_meter(meter_features))

tempo_features = read_features(input_file)
tempo_result = str(round(tempo_classifier.estimate_tempo(tempo_features), 1))

output_file = join(args.output, basename(input_file).replace(".wav", ".txt"))
output_file = Path(args.output, input_file.with_suffix(".txt").name)
with open(output_file, mode="w") as f:
f.write(tempo_result + "\t" + meter_result + "\n")
print("\nDone")
Expand Down
6 changes: 3 additions & 3 deletions test/test_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
# import os
from pathlib import Path
from unittest.mock import patch

import numpy as np
Expand All @@ -24,8 +25,7 @@ def test_data():

@pytest.fixture
def test_track():
dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(dir, "data", "drumtrack.mp3")
return Path(__file__).absolute().parent / "data" / "drumtrack.mp3"


@pytest.mark.parametrize(
Expand Down
25 changes: 12 additions & 13 deletions test/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from os.path import dirname
from pathlib import Path

import jams

import pytest
Expand All @@ -11,8 +11,7 @@

@pytest.fixture
def test_track():
dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(dir, "data", "drumtrack.mp3")
return Path(__file__).absolute().parent / "data" / "drumtrack.mp3"


@pytest.mark.parametrize("entry_point", entry_points)
Expand Down Expand Up @@ -45,7 +44,7 @@ def test_tempogram(script_runner, test_track):
ret = script_runner.run(["tempogram", "-p", test_track])
assert ret.success
assert "Loading model" in ret.stdout
assert os.path.exists(test_track + ".png")
assert Path(f"{test_track}.png").exists()


def test_meter(script_runner, test_track):
Expand All @@ -56,7 +55,7 @@ def test_meter(script_runner, test_track):


def test_greekfolk(script_runner, tmpdir, test_track):
ret = script_runner.run(["greekfolk", dirname(test_track), str(tmpdir)])
ret = script_runner.run(["greekfolk", str(test_track.parent), str(tmpdir)])
assert ret.success
assert "No .wav files found in" in ret.stdout

Expand Down Expand Up @@ -87,9 +86,9 @@ def test_tempo_jams(script_runner, test_track):
assert ret.success
assert "Loading model" in ret.stdout
assert "Processing file" in ret.stdout
jams_file = test_track.replace(".mp3", ".jams")
assert os.path.exists(jams_file)
jam = jams.load(jams_file)
jams_file = test_track.with_suffix(".jams")
assert jams_file.exists()
jam = jams.load(str(jams_file))

annotation = jam.annotations[0]
assert annotation.duration == pytest.approx(15.046, abs=0.001)
Expand All @@ -114,17 +113,17 @@ def test_tempo_jams_and_mirex(script_runner, test_track):
def test_tempo_extension(script_runner, test_track):
ret = script_runner.run(["tempo", "-e", ".fancy_pants", "-i", test_track])
assert ret.success
extension_name = test_track + ".fancy_pants"
assert os.path.exists(extension_name)
extension_name = Path(f"{test_track}.fancy_pants")
assert extension_name.exists()
with open(extension_name, "r") as f:
assert "100" in f.read()


def test_tempo_replace_extension(script_runner, test_track):
ret = script_runner.run(["tempo", "-re", ".fancy_pants", "-i", test_track])
assert ret.success
extension_name = test_track.replace(".mp3", ".fancy_pants")
assert os.path.exists(extension_name)
extension_name = test_track.with_suffix(".fancy_pants")
assert extension_name.exists()
with open(extension_name, "r") as f:
assert "100" in f.read()

Expand Down
5 changes: 2 additions & 3 deletions test/test_feature.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path

import librosa
import pytest
Expand All @@ -8,8 +8,7 @@

@pytest.fixture
def test_track():
dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(dir, "data", "drumtrack.mp3")
return Path(__file__).absolute().parent / "data" / "drumtrack.mp3"


def test_init(test_track):
Expand Down

0 comments on commit 82d9ef8

Please sign in to comment.