forked from Agerrr/Automated_Music_Transcription
-
Notifications
You must be signed in to change notification settings - Fork 0
/
highest_peak_method.py
196 lines (160 loc) · 6.42 KB
/
highest_peak_method.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import sys
import wave
import math
import scipy
import pylab
import scipy.io.wavfile as wav
import numpy
def getDuration(sound_file):
    """
    Returns the duration (in seconds) of a given sound file.

    sound_file -- anything accepted by wave.open() in read mode
                  (typically the path of a .wav file).

    The duration is the number of audio frames divided by the frame rate.
    """
    wr = wave.open(sound_file, 'r')
    try:
        return wr.getnframes() / float(wr.getframerate())
    finally:
        # Release the underlying file handle even if reading the header
        # values fails; the reader used to be left open.
        wr.close()
def getFrameRate(sound_file):
    """
    Returns the frame rate (frames per second) of a given sound file.

    sound_file -- anything accepted by wave.open() in read mode
                  (typically the path of a .wav file).
    """
    wr = wave.open(sound_file, 'r')
    try:
        return wr.getframerate()
    finally:
        # Release the underlying file handle; the reader used to be
        # left open.
        wr.close()
def is_Prime(n):
    """
    Tell whether a number is prime.

    The argument is first truncated to an integer and its sign is
    dropped, so -7 and 7.9 are both tested as 7.
    """
    candidate = abs(int(n))
    # 2 is the single even prime, so settle it up front.
    if candidate == 2:
        return True
    # 0 and 1 are not prime, and neither is any other even number.
    if candidate < 2 or candidate % 2 == 0:
        return False
    # Trial division by the odd numbers up to the square root suffices.
    limit = int(candidate ** 0.5) + 1
    return all(candidate % divisor for divisor in range(3, limit, 2))
def get_next_power_2(n):
    """
    Returns the largest power of 2 that is strictly smaller than n.

    When no such power exists (n <= 2) the result is 1. Note that an n
    which is itself a power of 2 yields n / 2, e.g. 8 -> 4.

    The result is always an int, so it can be used directly as a length
    (for instance as the size of an FFT).
    """
    power = 1
    while power < n:
        power *= 2
    # power is now the smallest power of 2 that is >= n. Floor division
    # keeps the result an integer on Python 3 as well, where "/" would
    # produce a float.
    if power > 1:
        return power // 2
    return 1
class Highest_Peaks_MIDI_Detector(object):
    """
    Class for MIDI notes detection given a .wav file.

    The whole file is transformed with a single FFT. The strongest
    spectral peak is repeatedly taken as a pitch candidate and its
    harmonics are removed, until no bins are left. Each candidate is
    then mapped to the nearest MIDI note number.
    """

    def __init__(self, wav_file):
        """
        wav_file -- path of the .wav file to analyse; it is only stored
                    here and read later by detect_MIDI_notes().
        """
        self.wav_file = wav_file
        # before: 0.005e+13 twinkle: 0.002e+14 scale: 0.005e+16
        # NOTE(review): THRESHOLD, HAN_WINDOW and HOP_SIZE are not read by
        # any method of this class; they are kept for external users.
        self.THRESHOLD = 0.005e+13
        self.HAN_WINDOW = 0.093
        self.HOP_SIZE = 0.00928
        # Only FFT bins strictly between these two frequencies (in Hz)
        # are considered by calculateFFT().
        self.minFreqConsidered = 27.0
        self.maxFreqConsidered = 2093

    def detect_MIDI_notes(self):
        """
        The algorithm for calculating midi notes from a given wav file.

        Returns a list of MIDI note numbers, ordered from the strongest
        detected peak to the weakest. The list is empty when the file
        is too short to be analysed.
        """
        (framerate, sample) = wav.read(self.wav_file)
        # The FFT works on a single channel, so multi-channel audio is
        # averaged into one. A mono file is already one-dimensional and
        # must be used as is (it has no axis 1 to average over).
        if sample.ndim > 1:
            monoChannel = sample.mean(axis=1)
        else:
            monoChannel = sample
        duration = getDuration(self.wav_file)
        midi_notes = []
        # Consider only files with a duration longer than 0.18 seconds.
        if duration > 0.18:
            frequency_power = self.calculateFFT(duration, framerate, monoChannel)
            # For debugging, the spectrum can be inspected here with
            # self.plot_power_spectrum(frequency_power) or
            # self.plot_power_spectrum_dB(frequency_power).
            f0_candidates = self.get_pitch_candidates_remove_highest_peak(frequency_power)
            midi_notes = self.matchWithMIDINotes(f0_candidates)
        return midi_notes

    def get_pitch_candidates_remove_highest_peak(self, frequency_power):
        """
        Picks pitch candidates from (frequency, power) pairs.

        Repeatedly takes the frequency with the highest power as a
        candidate and removes it together with its harmonics, until no
        pairs are left. Returns the candidate frequencies in the order
        in which they were picked.
        """
        peak_frequencies = []
        while len(frequency_power) > 0:
            # max() returns the first pair with the highest power, which
            # is the same pair a stable descending sort would put first.
            peak_frequency = max(frequency_power, key=lambda pair: pair[1])[0]
            peak_frequencies.append(peak_frequency)
            # The peak is a multiple (x1) of itself, so it is always
            # removed and the loop terminates.
            frequency_power = self.filterOutHarmonics(frequency_power, peak_frequency)
        return peak_frequencies

    def _plot_spectrum(self, frequencies, values, ylabel):
        """
        Shows a plot of the first 600 (frequency, value) points.
        """
        points = 600
        pylab.figure('Power spectrum')
        pylab.plot(frequencies[:points], values[:points])
        pylab.xlabel('Frequency [Hz]')
        pylab.ylabel(ylabel)
        pylab.show()

    def plot_power_spectrum(self, frequency_power):
        """
        Plots the power of the given (frequency, power) pairs.
        """
        frequencies = [pair[0] for pair in frequency_power]
        powers = [pair[1] for pair in frequency_power]
        self._plot_spectrum(frequencies, powers, 'Power spectrum []')

    def plot_power_spectrum_dB(self, frequency_power):
        """
        Plots the power of the given (frequency, power) pairs in dB.

        Powers that are not positive are plotted as 0 dB, since their
        logarithm is undefined.
        """
        frequencies = [pair[0] for pair in frequency_power]
        powers = [pair[1] for pair in frequency_power]
        dBs = [10 * math.log10(power) if power > 0 else 0 for power in powers]
        self._plot_spectrum(frequencies, dBs, 'Power spectrum [dB]')

    def calculateFFT(self, duration, framerate, sample):
        """
        Calculates FFT for a given sound wave.
        Considers only frequencies with the magnitudes higher than
        a given threshold.

        duration  -- length of the sound in seconds.
        framerate -- number of samples per second.
        sample    -- one-dimensional sequence of samples.

        Returns a list of (frequency, power) pairs, in increasing
        frequency order, for the bins that lie strictly between
        minFreqConsidered and maxFreqConsidered and whose power exceeds
        10% of the highest power found in the first half of the FFT.
        """
        fft_length = int(duration * framerate)
        # numpy needs an integer length; int() guards against a helper
        # that returns the power of 2 as a float.
        fft_length = int(get_next_power_2(fft_length))
        FFT = numpy.fft.fft(sample, n=fft_length)
        # Only the first half of the bins is inspected.
        half = len(FFT) // 2
        power_spectra = numpy.abs(FFT[:half]) ** 2
        # The threshold adapts to the signal: 10% of the strongest bin.
        threshold = 0.1 * (power_spectra.max() if half > 0 else 0)
        binResolution = float(framerate) / float(fft_length)
        frequency_power = []
        # For each bin calculate the corresponding frequency.
        for k in range(half):
            binFreq = k * binResolution
            if self.minFreqConsidered < binFreq < self.maxFreqConsidered:
                power_spectrum = power_spectra[k]
                if power_spectrum > threshold:
                    frequency_power.append((binFreq, power_spectrum))
        return frequency_power

    def filterOutHarmonics(self, frequency_power, f0_candidate):
        """
        Given frequency_power pairs and an f0 candidate remove
        all possible harmonics of this f0 candidate.

        The candidate itself (the first harmonic) is removed as well.
        Frequencies well below the candidate are kept: they are not
        harmonics of it, even though their ratio rounds to 0.
        """
        # If a frequency is (nearly) an integer multiple of another
        # frequency then it is its harmonic. This constant was found
        # empirically.
        # TODO: This constant may change for inharmonic frequencies!!!
        REMAINDER_THRESHOLD = 0.2

        def is_multiple(f, f0):
            # float() avoids integer division on Python 2.
            ratio = float(f) / f0
            nearest = round(ratio)
            # A harmonic is at least 1 x f0; without this check every
            # frequency below 0.2 x f0 would match the "multiple" 0.
            return nearest >= 1 and abs(nearest - ratio) < REMAINDER_THRESHOLD

        return [(f, p) for (f, p) in frequency_power if not is_multiple(f, f0_candidate)]

    def matchWithMIDINotes(self, f0_candidates):
        """
        Converts frequencies (in Hz) to the nearest MIDI note numbers.

        Uses note = 69 + 12 * log2(freq / 440), i.e. 440 Hz maps to
        MIDI note 69. Returns the notes in the order of the input.
        """
        # 440.0 forces float division, so integer frequencies also work
        # on Python 2.
        return [int(round(69 + 12 * math.log(freq / 440.0) / math.log(2)))
                for freq in f0_candidates]
if __name__ == '__main__':
    # Usage: the only argument is the path of the .wav file to analyse.
    if len(sys.argv) < 2:
        sys.exit('Usage: %s <wav_file>' % sys.argv[0])
    MIDI_detector = Highest_Peaks_MIDI_Detector(sys.argv[1])
    midi_notes = MIDI_detector.detect_MIDI_notes()
    # A parenthesised single argument prints the same on Python 2 and 3.
    print(midi_notes)