Skip to content

Commit

Permalink
Added type hints.
Browse files Browse the repository at this point in the history
  • Loading branch information
hendriks73 committed Oct 17, 2024
1 parent 82d9ef8 commit ab832fc
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 48 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ jobs:
sudo apt-get install libsndfile1
sudo apt-get install ffmpeg
python -m pip install --upgrade pip setuptools wheel
pip install ruff pytest
pip install ruff pytest mypy
pip install .[testing]
- name: Lint with ruff
run: |
ruff check tempocnn test
ruff format --check tempocnn test
- name: Type check with mypy
run: |
mypy --ignore-missing-imports --check-untyped-defs tempocnn test
- name: Test with pytest
run: |
coverage run --source ./tempocnn -m pytest --verbose --junitxml=pytest_report${{ matrix.python-version }}.xml
Expand Down
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changes
- Moved to TensorFlow 2.17.0 and Python 3.9/3.10/3.11.
- Made local cache version dependent.
- Migrated code to Pathlib.
- Added type hints.

0.0.7:
- Added DOIs to bibtex entries.
Expand Down
49 changes: 29 additions & 20 deletions tempocnn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import urllib.request
from pathlib import Path
from typing import Optional
from urllib.error import HTTPError

import numpy as np
Expand All @@ -12,7 +13,7 @@
logger = logging.getLogger("tempocnn.classifier")


def std_normalizer(data):
def std_normalizer(data: np.ndarray) -> np.ndarray:
"""
Normalizes data to zero mean and unit variance.
Used by Mazurka models.
Expand All @@ -29,7 +30,7 @@ def std_normalizer(data):
return data.astype(np.float16)


def max_normalizer(data):
def max_normalizer(data: np.ndarray) -> np.ndarray:
"""
Divides by max. Used as normalization by older models.
Expand All @@ -47,7 +48,7 @@ class TempoClassifier:
Classifier that can estimate musical tempo in different formats.
"""

def __init__(self, model_name="fcn"):
def __init__(self, model_name: str = "fcn"):
"""
Initializes this classifier with a Keras model.
Expand Down Expand Up @@ -94,7 +95,7 @@ def __init__(self, model_name="fcn"):
logger.debug(f"Loading model {model_name} from {file}")
self.model = load_model(file, compile=False)

def estimate(self, data):
def estimate(self, data: np.ndarray) -> np.ndarray:
"""
Estimate a tempo distribution.
Probabilities are indexed, starting with 30 BPM and ending with 286 BPM.
Expand All @@ -118,7 +119,9 @@ def estimate(self, data):
return self.model.predict(norm_data, norm_data.shape[0])

@staticmethod
def quad_interpol_argmax(y, x=None):
def quad_interpol_argmax(
y: np.ndarray, x: Optional[int] = None
) -> tuple[float, float]:
"""
Find argmax for quadratic interpolation around argmax of y.
Expand All @@ -127,16 +130,16 @@ def quad_interpol_argmax(y, x=None):
:return: float (index) of interpolated max, strength
"""
if x is None:
x = np.argmax(y)
x = int(np.argmax(y))
if x == 0 or x == y.shape[0] - 1:
return x, y[x]
return float(x), float(y[x])
z = np.polyfit([x - 1, x, x + 1], [y[x - 1], y[x], y[x + 1]], 2)
# find (float) x value for max
argmax = -z[1] / (2.0 * z[0])
height = z[2] - (z[1] ** 2.0) / (4.0 * z[0])
return argmax, height
return float(argmax), float(height)

def estimate_tempo(self, data, interpolate=False):
def estimate_tempo(self, data: np.ndarray, interpolate: bool = False) -> float:
"""
Estimates the predominant global tempo.
Expand All @@ -150,10 +153,12 @@ def estimate_tempo(self, data, interpolate=False):
if interpolate:
index, _ = self.quad_interpol_argmax(averaged_prediction)
else:
index = np.argmax(averaged_prediction)
index = int(np.argmax(averaged_prediction))
return self.to_bpm(index)

def estimate_mirex(self, data, interpolate=False):
def estimate_mirex(
self, data: np.ndarray, interpolate: bool = False
) -> tuple[float, float, float]:
"""
Estimates the two dominant tempi along with a salience value.
Expand All @@ -165,8 +170,8 @@ def estimate_mirex(self, data, interpolate=False):

prediction = self.estimate(data)

def find_index_peaks(distribution):
p = []
def find_index_peaks(distribution: np.ndarray) -> list[tuple[float, float]]:
p: list[tuple[float, float]] = []
last_index = 0
for index in range(256):
height = distribution[index]
Expand All @@ -181,14 +186,18 @@ def find_index_peaks(distribution):
) = self.quad_interpol_argmax(distribution, x=index)
p.append((interpolated_index, interpolated_height))
else:
p.append((index, height))
p.append((float(index), float(height)))
last_index = index
# sort peaks by height, descending
return sorted(p, key=lambda element: element[1], reverse=True)

averaged_prediction = np.average(prediction, axis=0)
peaks = find_index_peaks(averaged_prediction)

s1: float
t1: float
t2: float

if len(peaks) == 0:
s1 = 1.0
t1 = 0.0
Expand Down Expand Up @@ -226,7 +235,7 @@ class MeterClassifier:
Classifier that can estimate musical meter
"""

def __init__(self, model_name="fcn"):
def __init__(self, model_name: str = "fcn"):
"""
Initializes this classifier with a Keras model.
Expand Down Expand Up @@ -254,7 +263,7 @@ def __init__(self, model_name="fcn"):
raise e
self.model = load_model(file)

def estimate(self, data):
def estimate(self, data: np.ndarray) -> np.ndarray:
"""
Estimate a meter distribution.
Probabilities are indexed, starting with 2. Only the meter numerator is given (e.g. 2 for 2/4).
Expand All @@ -277,7 +286,7 @@ def estimate(self, data):
norm_data = self.normalize(data)
return self.model.predict(norm_data, norm_data.shape[0])

def estimate_meter(self, data):
def estimate_meter(self, data: np.ndarray) -> int:
"""
Estimates the predominant global meter.
Expand All @@ -290,7 +299,7 @@ def estimate_meter(self, data):
return self._to_meter(index)


def _to_model_resource(model_name):
def _to_model_resource(model_name: str) -> str:
file = model_name
if not model_name.endswith(".h5"):
file = file + ".h5"
Expand All @@ -299,7 +308,7 @@ def _to_model_resource(model_name):
return file


def _extract_from_package(resource):
def _extract_from_package(resource: str) -> str:
# check local cache
cache_path = Path(Path.home(), ".tempocnn", package_version, resource)
if cache_path.exists():
Expand All @@ -326,7 +335,7 @@ def _extract_from_package(resource):
return str(cache_path)


def _load_model_from_github(resource):
def _load_model_from_github(resource: str):
url = f"https://raw.githubusercontent.com/hendriks73/tempo-cnn/main/tempocnn/{resource}"
logger.info(f"Attempting to download model file from main branch {url}")
try:
Expand Down
51 changes: 28 additions & 23 deletions tempocnn/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import sys
from pathlib import Path
from typing import Union, Optional

import jams
import librosa
Expand Down Expand Up @@ -98,10 +99,10 @@ def tempo():

output_format = parser.add_mutually_exclusive_group()
output_format.add_argument(
"--mirex", help="use MIREX format for output", action="store_true"
"--mirex", help="use MIREX format for output", action="store_true"
)
output_format.add_argument(
"--jams", help="use JAMS format for output", action="store_true"
"--jams", help="use JAMS format for output", action="store_true"
)

parser.add_argument(
Expand Down Expand Up @@ -132,6 +133,7 @@ def tempo():
create_jam = args.jams
create_mirex = args.mirex

result: Union[str, jams.JAMS]
if create_mirex or create_jam:
t1, t2, s1 = classifier.estimate_mirex(
features, interpolate=args.interpolate
Expand Down Expand Up @@ -169,14 +171,14 @@ def tempo():


def _write_tempo_result(
result,
input_file=None,
output_dir=None,
output_list=None,
index=0,
append_extension=None,
replace_extension=None,
create_jam=False,
result: Union[str, jams.JAMS],
input_file: str,
output_dir: Optional[str] = None,
output_list: Optional[list[str]] = None,
index: int = 0,
append_extension: Optional[str] = None,
replace_extension: Optional[str] = None,
create_jam: bool = False,
):
"""
Write the tempo analysis results to a file.
Expand Down Expand Up @@ -207,7 +209,7 @@ def _write_tempo_result(
output_file = Path(output_list[index])

# actually writing the output
if create_jam:
if create_jam and isinstance(result, jams.JAMS):
result.save(str(output_file))
elif output_file is None:
print("\n" + result)
Expand All @@ -216,7 +218,9 @@ def _write_tempo_result(
file_name.write(result + "\n")


def _create_tempo_jam(input_file, model, s1, t1, t2):
def _create_tempo_jam(
input_file: Union[str, Path], model: str, s1: float, t1: float, t2: float
) -> jams.JAMS:
result = jams.JAMS()
y, sr = librosa.load(input_file)
track_duration = librosa.get_duration(y=y, sr=sr)
Expand Down Expand Up @@ -378,7 +382,8 @@ def tempogram():
frame_length = (fft_hop_length / sr) * hop_length

fig = plt.figure()
fig.canvas.manager.set_window_title("tempogram: " + file)
if fig.canvas.manager is not None:
fig.canvas.manager.set_window_title("tempogram: " + file)
if args.png:
fig.set_size_inches(5, 2)

Expand Down Expand Up @@ -422,7 +427,7 @@ def tempogram():
print("\nDone")


def _norm_tempogram_frames(predictions=None, norm_frame=None):
def _norm_tempogram_frames(predictions: np.ndarray, norm_frame: str) -> np.ndarray:
norm_order = np.inf
if "max" == norm_frame.lower():
norm_order = np.inf
Expand All @@ -439,14 +444,14 @@ def _norm_tempogram_frames(predictions=None, norm_frame=None):


def _write_tempogram_as_csv(
predictions=None,
classifier=None,
file=None,
frame_length=None,
log_scale=False,
min_bpm=30,
max_bpm=256,
sharpen=False,
predictions: np.ndarray,
classifier: TempoClassifier,
file: str,
frame_length: int,
log_scale: bool = False,
min_bpm: int = 30,
max_bpm: int = 256,
sharpen: bool = False,
):
csv_file_name = file + ".csv"
if sharpen:
Expand Down Expand Up @@ -479,7 +484,7 @@ def _write_tempogram_as_csv(
)


def _get_tempogram_limits(log_scale):
def _get_tempogram_limits(log_scale: bool) -> tuple[int, int, int]:
if log_scale:
min_bpm = 50
max_bpm = 500
Expand Down
23 changes: 19 additions & 4 deletions tempocnn/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,24 @@
20 to 5000 Hz.
"""

import os
from pathlib import Path
from typing import Union, Any, BinaryIO

import audioread
import librosa as librosa
import numpy as np
import soundfile as sf


def read_features(file, frames=256, hop_length=128, zero_pad=False):
def read_features(
file: Union[
str, Path, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
],
frames: int = 256,
hop_length: int = 128,
zero_pad: bool = False,
) -> np.ndarray:
"""
Resample file to 11025 Hz, then transform using STFT with length 1024
and hop size 512. Convert resulting linear spectrum to mel spectrum
Expand Down Expand Up @@ -56,21 +69,23 @@ def read_features(file, frames=256, hop_length=128, zero_pad=False):
return _to_sliding_window(data, frames, hop_length)


def _ensure_length(data, length):
def _ensure_length(data: np.ndarray, length: int) -> np.ndarray:
padded_data = np.zeros((1, data.shape[1], length, 1), dtype=data.dtype)
padded_data[0, :, 0 : data.shape[2], 0] = data[0, :, :, 0]
return padded_data


def _add_zeros(data, zeros):
def _add_zeros(data: np.ndarray, zeros: int) -> np.ndarray:
padded_data = np.zeros(
(1, data.shape[1], data.shape[2] + zeros, 1), dtype=data.dtype
)
padded_data[0, :, zeros // 2 : data.shape[2] + (zeros // 2), 0] = data[0, :, :, 0]
return padded_data


def _to_sliding_window(data, window_length, hop_length):
def _to_sliding_window(
data: np.ndarray, window_length: int, hop_length: int
) -> np.ndarray:
total_frames = data.shape[2]
windowed_data = []
for offset in range(
Expand Down

0 comments on commit ab832fc

Please sign in to comment.