Skip to content

Commit

Permalink
rpi 4 kws
Browse files Browse the repository at this point in the history
  • Loading branch information
roatienza committed Apr 4, 2022
1 parent f0f9116 commit 15d6285
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions versions/2022/supervised/python/kws-infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,19 @@
To use microphone input, run:
python3 kws-infer.py --gui
On a Raspberry Pi 4 (RPi 4), run:
python3 kws-infer.py --rpi --gui
Dependencies:
pip3 install pysimplegui
pip3 install sounddevice
pip3 install librosa
sudo apt-get install libasound2-dev libportaudio2
Inference time:
0.03 sec on Quad Core Intel i7 2.3GHz
0.09 sec on RPi 4
'''


Expand All @@ -38,6 +45,7 @@ def get_args():
parser.add_argument("--wav-file", type=str, default=None)
parser.add_argument("--model-path", type=str, default="resnet18-kws-best-acc.pt")
parser.add_argument("--gui", default=False, action="store_true")
parser.add_argument("--rpi", default=False, action="store_true")
args = parser.parse_args()
return args

Expand Down Expand Up @@ -82,18 +90,6 @@ def get_args():

if not args.gui:
waveform, sample_rate = torchaudio.load(wav_file)

#wav = np.expand_dims(waveform.numpy().squeeze(), axis=1)

#print("Shape:", wav.shape)
#sd.default.samplerate = sample_rate
#sd.default.channels = 1
#sd.play(wav, blocking=True)

#print("Shape:", waveform.shape)
#sd.fs = sample_rate
#sd.channel_count = waveform.shape[1]
#sd.play(waveform.numpy(), sample_rate)

transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
n_fft=args.n_fft,
Expand Down Expand Up @@ -127,8 +123,17 @@ def get_args():
if waveform.max() > 1.0:
continue
start_time = time.time()
waveform = torch.from_numpy(waveform).unsqueeze(0)
mel = ToTensor()(librosa.power_to_db(transform(waveform).squeeze().numpy(), ref=np.max))
if args.rpi:
waveform = torch.FloatTensor(waveform.tolist())
mel = np.array(transform(waveform).squeeze().tolist())
mel = librosa.power_to_db(mel, ref=np.max).tolist()

mel = torch.FloatTensor(mel)
mel = mel.unsqueeze(0)

else:
waveform = torch.from_numpy(waveform).unsqueeze(0)
mel = ToTensor()(librosa.power_to_db(transform(waveform).squeeze().numpy(), ref=np.max))
mel = mel.unsqueeze(0)
pred = scripted_module(mel)
max_prob = pred.max()
Expand Down

0 comments on commit 15d6285

Please sign in to comment.