-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
207 lines (169 loc) · 6.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# -*-coding:utf8;-*-
import math
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import (
    load_npz,
    isspmatrix_dok,
    save_npz
)

from constants import (
    FILES_PATH,
    INDEX_TYPES,
    MATCHING_ALGORITHMS,
    MANHATTAN_DISTANCE,
    METHODS,
    REQUIRE_INDEX_TYPE,
    SEARCH_METHODS,
    SERIALIZE_PITCH_VECTORS,
    THRESHOLD_FILENAME
)
from messages import (
    log_bare_exception_error,
    log_impossible_serialize_option_error,
    log_invalid_index_type_error,
    log_invalid_matching_algorithm_error,
    log_invalid_method_error,
    log_no_confidence_measurement_found_error,
    log_wrong_confidence_measurement_error
)
def get_confidence_measurement():
    """
    Read the confidence threshold stored in THRESHOLD_FILENAME.

    The file is the one written by train_confidence. Any failure is reported
    through the matching log helper instead of being raised.

    Returns:
        The file content parsed as a float, or None when the file is missing,
        its content is not a valid float, or any other error occurs.
    """
    threshold = None
    # Bound before the try block: file.read() itself can raise a ValueError
    # subclass (e.g. UnicodeDecodeError), and the handler below must still be
    # able to reference `content` without hitting a NameError.
    content = None
    try:
        with open(THRESHOLD_FILENAME, 'r') as file:
            content = file.read()
        threshold = float(content)
    except FileNotFoundError:
        log_no_confidence_measurement_found_error()
    except ValueError:
        log_wrong_confidence_measurement_error(content)
    except Exception as err:
        log_bare_exception_error(err)
    return threshold
def is_create_index_or_search_method(args):
    """
    Tell whether any of the given methods requires an index type.

    Arguments:
        args -- Iterable of method identifiers.

    Returns:
        True when at least one truthy entry of args is listed in
        REQUIRE_INDEX_TYPE, False otherwise.
    """
    index_methods = (
        candidate
        for candidate in args
        if candidate in REQUIRE_INDEX_TYPE
    )
    return any(index_methods)
def is_serialize_pitches_method(args):
    """
    Tell whether SERIALIZE_PITCH_VECTORS is among the given methods.
    """
    requested = SERIALIZE_PITCH_VECTORS in args
    return requested
def load_sparse_matrix(structure_name):
    """
    Load the sparse matrix stored as '<structure_name>.npz' under FILES_PATH.

    Returns:
        Whatever scipy's load_npz returns for that file.
    """
    return load_npz(f'{FILES_PATH}/{structure_name}.npz')
def percent(part, whole):
    """
    Return `part` percent of `whole` as a float.

    Both arguments are converted with float(), so numeric strings work too.

    Ex:
        percent(10, 1000)  # Ten percent of a thousand
        > 100.0
    """
    one_percent = float(whole) / 100
    return one_percent * float(part)
def save_graphic(values, xlabel, ylabel, title, show=False):
    """
    Plot a histogram of `values` and save it as '<title>.png' in FILES_PATH.

    Arguments:
        values -- Sequence of numbers to be binned.
        xlabel -- Label of the x axis.
        ylabel -- Label of the y axis.
        title -- Plot title; also used as the output file name.
        show -- When True, the plot is also displayed after being saved.
    """
    values_as_nparray = np.array(values)
    histogram, _bins, _patches = plt.hist(
        x=values_as_nparray,
        bins='auto',
        histtype='stepfilled',
        color='#0504aa',
        alpha=0.7,
        rwidth=0.85
    )
    plt.grid(axis='y', alpha=0.75)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    max_frequency = histogram.max()
    # Set a clean upper y-axis limit: round up to the next multiple of ten,
    # or add ten when the maximum is already a multiple of ten.
    if max_frequency % 10:
        y_max = np.ceil(max_frequency / 10) * 10
    else:
        y_max = max_frequency + 10
    plt.ylim(ymax=y_max)

    # Save BEFORE showing: once a blocking plt.show() returns the figure may
    # already be closed, and savefig would then write an empty image.
    plt.savefig(FILES_PATH + f"/{title}.png")
    if show:
        plt.show()
    # Release the pyplot figure so a later call starts from clean axes
    # instead of drawing on top of this histogram.
    plt.close()
def print_confidence_measurements(confidence_measurements):
    """
    Print the confidence measurements of every query.

    Arguments:
        confidence_measurements -- Mapping of query name to a sequence of
            candidate/measure entries; each entry is printed as-is.
    """
    banner = '*' * 80
    print(banner)
    for query_name, entries in confidence_measurements.items():
        print('Query: ', query_name)
        # Use the singular form only when there is exactly one candidate.
        suffix = 's' if len(entries) != 1 else ''
        print('Candidate{0} confidence measurement{0}:'.format(suffix))
        for entry in entries:
            print('\t', entry)
    print(banner)
def print_results(matching_algorithm, index_type, results, show_top_x):
    """
    Print the ranked results found for every query.

    Arguments:
        matching_algorithm -- Name of the matching algorithm, shown in the
            header line.
        index_type -- Name of the index type, shown in the header line.
        results -- Mapping of query name to an iterable of results.
        show_top_x -- Accepted for interface compatibility; it is currently
            NOT applied, so every result of each query is printed.
    """
    banner = '*' * 80
    print(banner)
    print(f'Results found by {matching_algorithm} in {index_type}')
    for query_name, ranked in results.items():
        print('Query: ', query_name)
        print('Results:')
        rank = 0
        for item in ranked:
            rank += 1
            print('\t{:03}. {}'.format(rank, item))
    print(banner)
def save_sparse_matrix(structure, structure_name):
    """
    Save a sparse matrix as '<structure_name>.npz' under FILES_PATH.

    A DOK matrix is converted to CSR before saving; the DOK type is NOT
    restored when the file is loaded back.
    """
    # save_npz does not support dok matrix
    matrix = structure.tocsr() if isspmatrix_dok(structure) else structure
    target = f'{FILES_PATH}/{structure_name}.npz'
    save_npz(target, matrix)
def train_confidence(all_confidence_measurements, results_mapping):
    """
    Derive the confidence threshold, store it in THRESHOLD_FILENAME and exit.

    The threshold is the highest measure among the first candidates that do
    NOT match the expected result of their query.

    Arguments:
        all_confidence_measurements -- Mapping of query name to a sequence of
            (candidate name, measure) pairs; only the first pair is used.
        results_mapping -- Mapping of query name to its correct result.

    Raises:
        ValueError -- When no query has an incorrect first candidate, so
            there is no measure to derive the threshold from.
        SystemExit -- Always, with status 0, once the threshold is saved.
    """
    confidence_training_data = []
    for query_name, candidates_and_measures in all_confidence_measurements.items():
        correct_result = results_mapping[query_name]
        first_candidate_name, first_candidate_measure = candidates_and_measures[0]
        if first_candidate_name != correct_result:
            confidence_training_data.append(first_candidate_measure)

    # max() on an empty list fails with an opaque message; report the actual
    # cause while keeping the same exception type.
    if not confidence_training_data:
        raise ValueError(
            'Cannot derive a confidence threshold: no query has an '
            'incorrect first candidate.'
        )

    threshold = max(confidence_training_data)
    print(
        f'Max confidence measure is: {threshold}.\n',
        f'Saving in file {THRESHOLD_FILENAME}'
    )
    with open(THRESHOLD_FILENAME, 'w') as file:
        file.write(str(threshold))
    print("WARN: Exiting program because 'train_confidence' is True")
    # sys.exit instead of the site-provided exit(), which is meant for the
    # interactive shell and is not guaranteed to exist in every runtime.
    sys.exit(0)
def unzip_pitch_contours(pitch_contour_segmentations):
    """
    Extract audio path and pitch vector for application of matching
    algorithms.

    Arguments:
        pitch_contour_segmentations -- Iterable of 4-item entries in the
            order (audio path, pitch vector, onsets, durations).

    Returns:
        An object-dtype numpy array holding one (audio path, pitch vector)
        row per entry.
    """
    pitch_vectors = []
    for pitch_contour_segmentation in pitch_contour_segmentations:
        audio_path, pitch_vector, _onsets, _durations = pitch_contour_segmentation
        pitch_vectors.append((audio_path, pitch_vector))
    # Rows mix a path with a variable-length vector; recent NumPy versions
    # refuse to build such ragged arrays unless dtype=object is explicit.
    return np.array(pitch_vectors, dtype=object)
def validate_program_args(**kwargs):
    """
    Validates the program args. If any of them is invalid, the program
    exits with status 1.

    Two checks are performed:
      * a search method run outside confidence training needs a readable
        confidence measurement (see get_confidence_measurement);
      * the SERIALIZE_PITCH_VECTORS method needs serialize options, and a
        missing value is logged before exiting.

    Arguments:
        kwargs {dict} -- Dict of program args. The keys 'method_name',
            'serialize_options' and 'is_training_confidence' are required.

    Raises:
        KeyError -- When one of the required keys is missing.
        SystemExit -- With status 1 when a check fails.
    """
    method_name = kwargs['method_name']
    serialize_options = kwargs['serialize_options']
    is_training_confidence = kwargs['is_training_confidence']

    if not is_training_confidence and method_name in SEARCH_METHODS:
        # get_confidence_measurement logs the failure reason itself.
        if get_confidence_measurement() is None:
            # sys.exit instead of the site-provided exit(), which is meant
            # for the interactive shell only.
            sys.exit(1)

    if method_name == SERIALIZE_PITCH_VECTORS and not serialize_options:
        log_impossible_serialize_option_error()
        sys.exit(1)