Commit

add utils
matiaslindgren committed Nov 7, 2020
1 parent 438d266 commit abc2a43
Showing 2 changed files with 85 additions and 6 deletions.
31 changes: 28 additions & 3 deletions lidbox/meta/__init__.py
@@ -2,6 +2,7 @@
Dataset metadata parsing/loading/preprocessing.
"""
from concurrent.futures import ThreadPoolExecutor
import collections
import itertools
import os

@@ -10,13 +11,26 @@
import pandas as pd


REQUIRED_META_COLUMNS = (
"path",
"label",
"split",
)


def verify_integrity(meta, use_threads=True):
"""
Check that
1. There are no NaN values.
2. All paths exist on disk.
3. All splits are disjoint by speaker id.
1. The metadata table contains all required columns.
2. There are no NaN values.
3. All audio filepaths exist on disk.
4. All splits/buckets are disjoint by speaker id.
This function throws an exception if verification fails, otherwise completes silently.
"""
missing_columns = set(REQUIRED_META_COLUMNS) - set(meta.columns)
assert missing_columns == set(), "{} missing columns in metadata: {}".format(len(missing_columns), sorted(missing_columns))

assert not meta.isna().any(axis=None), "NaNs in metadata"

if use_threads:
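
Only the first checks of the function body are visible in this hunk. As a rough illustration, not part of the commit, calling the new required-column check on a toy table might look like the sketch below; the paths and labels are made up, and a real table would also need existing audio files plus speaker metadata for the split-disjointness check.

import pandas as pd

from lidbox.meta import verify_integrity

# Hypothetical toy metadata with the three required columns.
meta = pd.DataFrame({
    "path": ["/data/wav/utt1.wav", "/data/wav/utt2.wav"],
    "label": ["fin", "swe"],
    "split": ["train", "test"],
})

# If e.g. "split" were dropped, the new assert fires first and names the missing column;
# the remaining checks (NaN values, paths on disk, speaker-disjoint splits) run after it.
verify_integrity(meta)
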
@@ -91,3 +105,14 @@ def update_sample_id(row):
samples.append(sample)

return pd.concat(samples).set_index("id", drop=True, verify_integrity=True)


def generate_label2target(meta):
"""
Generate a unique label-to-target mapping,
where integer targets are the enumeration of labels in lexicographic order.
"""
label2target = collections.OrderedDict(
(l, t) for t, l in enumerate(sorted(meta.label.unique())))
meta["target"] = np.array([label2target[l] for l in meta.label], np.int32)
return meta, label2target
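
For illustration only (not in the diff), the new mapping on a toy table enumerates the sorted unique labels and attaches an integer target column, roughly as follows.

import pandas as pd

from lidbox.meta import generate_label2target

# Hypothetical three-language table; labels are enumerated in lexicographic order.
meta = pd.DataFrame({"label": ["swe", "fin", "est", "fin"]})
meta, label2target = generate_label2target(meta)

print(dict(label2target))    # {'est': 0, 'fin': 1, 'swe': 2}
print(meta.target.tolist())  # [2, 1, 0, 1]
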
60 changes: 57 additions & 3 deletions lidbox/util.py
@@ -1,13 +1,22 @@
import itertools
"""
High-level utilities and wrappers on top of high-level APIs of other libraries.
"""
import numpy as np
import pandas as pd
import sklearn.metrics
import sklearn.preprocessing

import lidbox.metrics

def predict_with_model(model, test_ds, batch_size=1):

def predict_with_model(model, ds):
"""
Map callable model over all batches in ds, predicting values for each element at key 'input'.
"""
ids = []
predictions = []

for x in test_ds.batch(batch_size).as_numpy_iterator():
for x in ds.as_numpy_iterator():
ids.extend(id.decode("utf-8") for id in x["id"])
predictions.extend(p.numpy() for p in model(x["input"], training=False))

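With the batching argument removed, the dataset is expected to arrive already batched, and each element must expose 'id' and 'input' keys as the loop above assumes. A minimal sketch of what a call might look like, using a hypothetical dataset and model that are not part of this commit:

import tensorflow as tf

from lidbox.util import predict_with_model

# Hypothetical pre-batched dataset of utterance ids and feature matrices.
ds = tf.data.Dataset.from_tensor_slices({
    "id": ["utt1", "utt2", "utt3", "utt4"],
    "input": tf.random.normal([4, 100, 40]),
}).batch(2)

# Any callable Keras model producing one prediction row per input works here.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation="softmax"),
])

# How ids and predictions are paired in the return value is outside the shown hunk.
result = predict_with_model(model, ds)
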
@@ -28,6 +37,51 @@ def merge_chunk_predictions(chunk_predictions, merge_fn=np.mean):
.agg(lambda row: merge_fn(np.array(row), axis=0)))

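Only the final line of merge_chunk_predictions is visible here. As a generic illustration of the idea, and an assumption rather than the shown implementation, chunk-level prediction vectors could be grouped back to their parent utterance and averaged like this:

import numpy as np
import pandas as pd

# Hypothetical chunk ids of the form "<utterance>-<chunk index>".
chunk_pred = pd.DataFrame({
    "id": ["utt1-0", "utt1-1", "utt2-0"],
    "prediction": [np.array([0.9, 0.1]), np.array([0.7, 0.3]), np.array([0.2, 0.8])],
})

merged = (chunk_pred
          .assign(utt=chunk_pred.id.str.rpartition("-")[0])
          .groupby("utt")["prediction"]
          .apply(lambda rows: np.mean(np.stack(rows), axis=0)))

print(merged["utt1"])  # [0.8 0.2]
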

def classification_report(true_sparse, pred_dense, label2target, dense2sparse_fn=None, num_cavg_thresholds=100):
if dense2sparse_fn is None:
dense2sparse_fn = lambda pred: pred.argmax(axis=1)
pred_sparse = dense2sparse_fn(pred_dense)

report = sklearn.metrics.classification_report(
true_sparse,
pred_sparse,
labels=list(range(len(label2target))),
target_names=label2target,
output_dict=True,
zero_division=0)

cavg_thresholds = np.linspace(
pred_dense.min(),
pred_dense.max(),
num_cavg_thresholds)
cavg = lidbox.metrics.SparseAverageDetectionCost(len(label2target), cavg_thresholds)
cavg.update_state(true_sparse, pred_dense)
report["avg_detection_cost"] = float(cavg.result().numpy())

def to_dense(target):
v = np.zeros(len(label2target))
v[target] = 1
return v

true_dense = np.stack([to_dense(t) for t in true_sparse])

eer = np.zeros(len(label2target))
for l, label in enumerate(label2target):
# https://stackoverflow.com/a/46026962
fpr, tpr, _ = sklearn.metrics.roc_curve(true_dense[:,l], pred_dense[:,l])
fnr = 1 - tpr
eer[l] = fpr[np.nanargmin(np.absolute(fnr - fpr))]

report["avg_equal_error_rate"] = float(eer.mean())
for label, i in label2target.items():
report[label]["equal_error_rate"] = eer[i]

report["confusion_matrix"] = sklearn.metrics.confusion_matrix(true_sparse, pred_sparse)

# TODO convert to multi-level pandas.DataFrame by separating language metrics from summary metrics
return report

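A hedged end-to-end sketch of driving the new report, not part of the commit; the targets and scores below are fabricated, and the mapping would normally come from generate_label2target. The per-label equal error rate is read off the ROC curve at the point where false positive and false negative rates are closest, following the approximation linked in the code.

import numpy as np

from lidbox.util import classification_report

# Hypothetical 3-language evaluation: sparse integer targets and dense per-class scores.
label2target = {"est": 0, "fin": 1, "swe": 2}
true_sparse = np.array([0, 1, 2, 1, 0])
pred_dense = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.2, 0.7],
    [0.5, 0.4, 0.1],   # 'fin' sample scored highest as 'est'
    [0.7, 0.2, 0.1],
])

report = classification_report(true_sparse, pred_dense, label2target)
print(report["avg_detection_cost"], report["avg_equal_error_rate"])
print(report["fin"]["equal_error_rate"])
print(report["confusion_matrix"])
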

# TODO
# 1. load metadata
# 2. merge, preprocess, update, prepare, meta
