From 9bd7383f59cb15420f698b0bbce33b123cc0946d Mon Sep 17 00:00:00 2001
From: nicolas-f <1382241+nicolas-f@users.noreply.github.com>
Date: Thu, 8 Feb 2024 15:15:46 +0100
Subject: [PATCH] Before keeping audio data, scan the whole recording with
 yamnet when banned sound sources are configured, and do not keep the audio
 if one of them is detected
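
The ban check previously ran only on the short yamnet trigger window,
so a banned source appearing later in the recording could still be
stored. The whole recording is now resampled to 16 kHz and classified
again before the audio is encrypted and embedded in the JSON document,
roughly:

    samples = np.frombuffer(samples_trigger.getbuffer(), dtype=np.float32)
    if config.sample_rate != 16000:
        samples = resampy.resample(samples, config.sample_rate, 16000)
    document, tags = generate_yamnet_document(samples, add_spectrogram)
    keep_audio = not any(tag in document["scores"] for tag in trigger_ban)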
---
 services/zero_trigger.py | 297 +++++++++++++++++++++++----------------
 1 file changed, 174 insertions(+), 123 deletions(-)

diff --git a/services/zero_trigger.py b/services/zero_trigger.py
index 2cb9a2b..dd9ef1d 100644
--- a/services/zero_trigger.py
+++ b/services/zero_trigger.py
@@ -72,6 +72,7 @@ from importlib.resources import files
 import struct
 
 
+
 class Params:
     """
     Yamnet settings
@@ -117,7 +118,8 @@ def encrypt(audio_data, ssh_file):
     aes_cipher = AES.new(aes_key, AES.MODE_CBC, iv)
     # pad audio data
     if len(audio_data) % AES.block_size > 0:
-        audio_data = audio_data.ljust(len(audio_data) + AES.block_size - len(audio_data) % AES.block_size, b'\0')
+        audio_data = audio_data.ljust(
+            len(audio_data) + AES.block_size - len(audio_data) % AES.block_size, b'\0')
     # Write AES data
     output_encrypted.write(aes_cipher.encrypt(audio_data))
     return output_encrypted.getvalue()
@@ -132,9 +134,11 @@ def __init__(self, trigger_processor, config):
     def run(self):
         while self.config.running:
             record_time = str(datetime.timedelta(seconds=
-                round(self.trigger_processor.total_read / self.config.sample_rate)))
-            print("%s samples read: %ld (%s)" % (datetime.datetime.now().replace(microsecond=0).isoformat(),
-                  self.trigger_processor.total_read, record_time))
+                round(
+                    self.trigger_processor.total_read / self.config.sample_rate)))
+            print("%s samples read: %ld (%s)" % (
+                datetime.datetime.now().replace(microsecond=0).isoformat(),
+                self.trigger_processor.total_read, record_time))
             time.sleep(self.config.delay_print_samples)
@@ -142,7 +146,8 @@ def read_yamnet_class_and_threshold(class_map_csv):
     with open(class_map_csv) as csv_file:
         reader = csv.reader(csv_file)
         next(reader)  # Skip header
-        names, threshold = zip(*[[display_name.strip(), float(threshold)] for _, _, display_name, threshold in reader])
+        names, threshold = zip(
+            *[[display_name.strip(), float(threshold)] for _, _, display_name, threshold in reader])
     return names, np.array(threshold, dtype=float)
@@ -169,6 +174,7 @@ class TriggerProcessor:
 
     def __init__(self, config):
         self.frame_time = 0
+        self.processing_time = 0
         self.config = config
         self.total_read = 0  # Total audio samples read
         self.sample_rate = self.config.sample_rate
@@ -197,7 +203,7 @@ def __init__(self, config):
         self.tensors.spectrogram_output_index = output_details[2]['index']
         print("Init tensors done")
         self.yamnet_samples = np.zeros((int(config.yamnet_window_time *
-                                        self.yamnet_config.sample_rate)),
+                                            self.yamnet_config.sample_rate)),
                                        dtype=np.float32)
         # where to place new samples
         self.yamnet_samples_index = 0
@@ -207,7 +213,7 @@ def __init__(self, config):
         self.yamnet_classes = read_yamnet_class_and_threshold(yamnet_class_map)
         if self.config.yamnet_cutoff_frequency > 0:
             self.sos = butter_highpass(self.config.yamnet_cutoff_frequency,
-                                       self.yamnet_config.sample_rate)
+                                           self.yamnet_config.sample_rate)
         else:
             self.sos = None
@@ -223,24 +229,24 @@ def init_socket(self):
         self.socket_out = context.socket(zmq.PUB)
         self.socket_out.bind(self.config.output_address)
 
-    def process_tags(self):
+    def process_tags(self, samples):
         # check for sound recognition tags
         # filter and normalize signal
         if self.config.yamnet_cutoff_frequency > 0:
-            waveform = self.butter_highpass_filter(self.yamnet_samples)
+            samples = self.butter_highpass_filter(samples)
         if self.config.yamnet_max_gain > 0:
             # apply gain
-            max_value = max(1e-12, float(np.max(np.abs(self.yamnet_samples))))
+            max_value = max(1e-12, float(np.max(np.abs(samples))))
             gain = 10 * math.log10(1 / max_value)
             gain = min(self.config.yamnet_max_gain, gain)
-            self.yamnet_samples *= 10 ** (gain / 10.0)
+            samples *= 10 ** (gain / 10.0)
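+            # gain is 10 * log10(1 / max_value) capped at yamnet_max_gain, so
+            # the factor above raises the signal peak to at most 1.0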
         # Predict YAMNet classes.
         self.yamnet_interpreter.resize_tensor_input(
             self.tensors.waveform_input_index,
-            [len(self.yamnet_samples)], strict=True)
+            [len(samples)], strict=True)
         self.yamnet_interpreter.allocate_tensors()
         self.yamnet_interpreter.set_tensor(self.tensors.waveform_input_index,
-                                           self.yamnet_samples)
+                                           samples)
         self.yamnet_interpreter.invoke()
         scores, embeddings, spectrogram = (
             self.yamnet_interpreter.get_tensor(self.tensors.scores_output_index),
@@ -259,22 +265,77 @@ def fetch_audio_data(self, feed_cache=True):
         if feed_cache:
             self.samples_stack.append(audio_data_samples)
             # will keep keep_only_samples samples, and drop older stack elements
-            keep_only_samples = max(self.config.cached_length, self.yamnet_config.patch_window_seconds) * \
+            keep_only_samples = max(self.config.cached_length,
+                                    self.yamnet_config.patch_window_seconds) * \
                 self.config.sample_rate
-            while sum([len(s) for s in self.samples_stack]) > keep_only_samples + len(audio_data_samples):
+            while sum([len(s) for s in self.samples_stack]) > keep_only_samples + len(
+                    audio_data_samples):
                 self.samples_stack.popleft()
         return audio_data_samples
 
+    def generate_yamnet_document(self, samples, add_spectrogram: bool):
+        """
+        @param samples: Audio samples at the 16 kHz yamnet sample rate
+        @param add_spectrogram: add the spectrogram to the returned document
+        @return: tuple of (document dict, list of recognized tags)
+        """
+        deb = time.time()
+        scores, embeddings, spectrogram = self.process_tags(samples)
+        self.processing_time += time.time() - deb
+        # Take maximum found prediction (was avg in the ref)
+        prediction = np.max(scores, axis=0)
+        # filter out classes that are below threshold values
+        filter_pred = (prediction > self.yamnet_classes[1])
+        filter_pred = filter_pred.nonzero()[0]
+        classes_threshold_index = list(map(int, filter_pred))
+        if len(classes_threshold_index) == 0:
+            # callers unpack two values and read document["scores"]
+            return {"scores": {}}, []
+        classification_tag = [self.yamnet_classes[0][i]
+                              for i in classes_threshold_index]
+        print("%s tags:%s \n processed in %.3f seconds for "
+              "%.1f seconds of audio."
+              % (time.strftime("%Y-%m-%d %H:%M:%S"),
+                 ",".join(classification_tag), self.processing_time,
+                 len(samples) /
+                 self.yamnet_config.sample_rate))
+        self.processing_time = 0
+        # Sort by score
+        classes_threshold_index = [classes_threshold_index[j] for j in
+                                   np.argsort([prediction[i] - self.yamnet_classes[1][i]
+                                               for i in classes_threshold_index])[::-1]]
+        # Compute a score between 0-100% from threshold to 1.0
+        scores_percentage = {self.yamnet_classes[0][i]: round(
+            float(((prediction[i] - self.yamnet_classes[1][i]) / (
+                    1 - self.yamnet_classes[1][i])) * 100)) for
+            i in classes_threshold_index}
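+        # e.g. a class with threshold 0.4 and prediction 0.7 scores
+        # round(((0.7 - 0.4) / (1 - 0.4)) * 100) = 50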
+        document_scores = {self.yamnet_classes[0][i]: round(float(prediction[i]), 2) for i in
+                           classes_threshold_index}
+        # threshold_time is the score over the time
+        # there is 2x more cells because there is 50% overlap
+        threshold_time = {
+            self.yamnet_classes[0][i]: [self.yamnet_classes[1][i],
+                                        np.round(scores[:, i],
+                                                 3).tolist()]
+            for i in classes_threshold_index}
+        document = {"scores": document_scores,
+                    "scores_perc": scores_percentage,
+                    "scores_time": threshold_time}
+        if add_spectrogram:
+            document["spectrogram"] = base64.b64encode(
+                spectrogram.astype(np.float16).
+                tobytes()).decode("UTF-8")
+        return document, classification_tag
+
     def run(self):
         reference_pressure = 1 / 10 ** (
-            (94 - self.config.sensitivity) / 20.0)
+                (94 - self.config.sensitivity) / 20.0)
         status = "wait_trigger"
         last_day_of_year = datetime.datetime.now().timetuple().tm_yday
         self.init_socket()
         document = {}
-        processing_time = 0
+        self.processing_time = 0
         while True:
-            if last_day_of_year != datetime.datetime.now().timetuple().tm_yday\
+            if last_day_of_year != datetime.datetime.now().timetuple().tm_yday \
                     and "trigger_count" in self.config:
                 # reset trigger counter each day
                 print("Reset trigger counter")
@@ -298,7 +359,7 @@ def run(self):
                     waveform = waveform[:len_to_extract]
                     self.yamnet_samples[start_index:end_index] = waveform
                     self.yamnet_samples_index += len_to_extract
-                    processing_time += time.time() - deb
+                    self.processing_time += time.time() - deb
                     if self.yamnet_samples_index < len(self.yamnet_samples):
                         # window is not complete so wait for more samples
                         continue
@@ -309,29 +370,8 @@ def run(self):
                 if leq >= self.config.min_leq:
                     print("Leq: %.2f dB > %.2f dB, so now try to recognize"
                           " sound source " % (leq, self.config.min_leq))
-                    deb = time.time()
-                    scores, embeddings, spectrogram = self.process_tags()
-                    processing_time += time.time() - deb
-                    # Take maximum found prediction (was avg in the ref)
-                    prediction = np.max(scores, axis=0)
-                    # filter out classes that are below threshold values
-                    filter_pred = (prediction > self.yamnet_classes[1])
-                    filter_pred = filter_pred.nonzero()[0]
-                    classes_threshold_index = list(map(int, filter_pred))
-                    if len(classes_threshold_index) == 0:
-                        # classifier rejected all known classes
-                        print("No classes found above yamnet threshold")
-                        status = "wait_trigger"
-                        continue
-                    classification_tag = [self.yamnet_classes[0][i]
-                                          for i in classes_threshold_index]
-                    print("%s tags:%s \n processed in %.3f seconds for "
-                          "%.1f seconds of audio."
-                          % (time.strftime("%Y-%m-%d %H:%M:%S"),
-                             ",".join(classification_tag), processing_time,
-                             len(self.yamnet_samples) /
-                             self.yamnet_config.sample_rate))
-                    processing_time = 0
+                    document, classification_tag = self.generate_yamnet_document(
+                        self.yamnet_samples, self.config.add_spectrogram)
+                    if len(classification_tag) == 0:
+                        # classifier rejected all known classes
+                        print("No classes found above yamnet threshold")
+                        status = "wait_trigger"
+                        continue
                     # If trigger_tag defined, we process only if one of the
                     # tag is specified
                     keep_classification = len(self.config.trigger_tag) == 0
@@ -346,48 +386,12 @@ def run(self):
                               " document")
                         status = "wait_trigger"
                         continue
-                    # Sort by score
-                    classes_threshold_index = [classes_threshold_index[j] for j in
-                                               np.argsort([prediction[i]-self.yamnet_classes[1][i]
-                                                           for i in classes_threshold_index])[::-1]]
-                    # Compute a score between 0-100% from threshold to 1.0
-                    scores_percentage = {self.yamnet_classes[0][i]: round(
-                        float(((prediction[i] - self.yamnet_classes[1][i]) / (
-                            1 - self.yamnet_classes[1][i])) * 100)) for
-                        i in classes_threshold_index}
-                    document_scores = {self.yamnet_classes[0][i]:
-                                       round(float(prediction[i]), 2) for i in
-                                       classes_threshold_index}
-                    # threshold_time is the score over the time
-                    # there is 2x more cells because there is 50% overlap
-                    threshold_time = {
-                        self.yamnet_classes[0][i]: [self.yamnet_classes[1][i],
-                                                    np.round(scores[:, i],
-                                                             3).tolist()]
-                        for i in classes_threshold_index}
-                    document = {"scores": document_scores,
-                                "scores_perc": scores_percentage,
-                                "scores_time": threshold_time,
-                                "leq": round(leq, 2),
-                                "date": epoch_to_elasticsearch_date(
-                                    self.frame_time)}
-                    if self.config.add_spectrogram:
-                        document["spectrogram"] = base64.b64encode(
-                            spectrogram.astype(np.float16).
-                            tobytes()).decode("UTF-8")
+                    document["leq"] = round(leq, 2)
+                    document["date"] = epoch_to_elasticsearch_date(self.frame_time)
                     if self.remaining_triggers >= 0:
                         print(" Remaining triggers for today %d" %
                               self.remaining_triggers)
-                    # check for audio storage tag exception
-                    banned = False
-                    if len(self.config.trigger_ban) > 0:
-                        for banned_tag in self.config.trigger_ban:
-                            if banned_tag in document["scores"].keys():
-                                print("Do not record audio because %s has"
-                                      " been detected" % banned_tag)
-                                banned = True
-                                break
-                    if not banned and self.config.total_length > 0 and (
+                    if self.config.total_length > 0 and (
                             self.remaining_triggers > 0 or
                             self.remaining_triggers == -1):
                         # requesting audio data into the json file, so now record audio
@@ -398,7 +402,7 @@ def run(self):
                         status = "wait_trigger"
                         self.socket_out.send_json(document)
                     continue
-                processing_time = 0  # yamnet window has been rejected
+                self.processing_time = 0  # yamnet window has been rejected
             elif status == "record":
                 if self.remaining_triggers > 0:
                     self.remaining_triggers -= 1
@@ -422,29 +426,55 @@ def run(self):
                     audio_processing_start = time.time()
                     # Compress audio samples
                     output = io.BytesIO()
-                    data, samplerate = sf.read(samples_trigger,
-                                               format='RAW', channels=1,
-                                               samplerate=int(
-                                                   self.config.
-                                                   sample_rate), subtype=
-                                               'FLOAT')
-                    channels = 1
-                    with sf.SoundFile(output, 'w', samplerate,
-                                      channels, format='FLAC',
-                                      subtype='PCM_24') as f:
-                        f.write(data)
-                        f.flush()
-                    audio_data_encrypt = base64.b64encode(encrypt(
-                        output.getvalue(), ssh_file)).decode("UTF-8")
-                    print("raw %d array %d bytes b64 ogg: %d bytes"
-                          " in %.3f seconds" % (samples_trigger.tell(),
-                                                data.shape[0],
-                                                len(audio_data_encrypt),
-                                                time.time() -
-                                                audio_processing_start))
-                    info = sf.info(io.BytesIO(output.getvalue()))
-                    print("Audio duration %.2f s, remaining triggers %d" % (info.duration, self.remaining_triggers))
-                    document["encrypted_audio"] = audio_data_encrypt
+                    keep_audio = True
+                    if len(self.config.trigger_ban) > 0:
+                        # Banned sound sources are configured, so analyze the
+                        # whole recording to check whether one of them is in it
+                        samples = np.frombuffer(samples_trigger.getbuffer(), dtype=np.float32)
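+                        # samples_trigger holds raw mono float32 frames at
+                        # config.sample_rate (see the sf.read call below)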
+                        if self.config.sample_rate != self.yamnet_config.sample_rate:
+                            # resample if necessary
+                            samples = resampy.resample(samples,
+                                                       self.config.sample_rate,
+                                                       self.yamnet_config.sample_rate,
+                                                       filter=self.config.
+                                                       resample_method)
+                        # keep the leq and date fields set by the trigger
+                        # window while replacing the scores with the
+                        # full-recording analysis
+                        full_document, classification_tag = self.generate_yamnet_document(
+                            samples, self.config.add_spectrogram)
+                        document.update(full_document)
+                        del samples
+                        for banned_tag in self.config.trigger_ban:
+                            if banned_tag in document["scores"].keys():
+                                print("Do not keep audio because %s has"
+                                      " been detected" % banned_tag)
+                                keep_audio = False
+                                break
+                    if keep_audio:
+                        data, samplerate = sf.read(samples_trigger,
+                                                   format='RAW', channels=1,
+                                                   samplerate=int(
+                                                       self.config.
+                                                       sample_rate), subtype=
+                                                   'FLOAT')
+                        channels = 1
+                        with sf.SoundFile(output, 'w', samplerate,
+                                          channels, format='FLAC',
+                                          subtype='PCM_24') as f:
+                            f.write(data)
+                            f.flush()
+                        audio_data_encrypt = base64.b64encode(encrypt(
+                            output.getvalue(), ssh_file)).decode("UTF-8")
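+                        # encrypt() zero-pads the FLAC stream to the AES block
+                        # size and AES-CBC encrypts it; base64 keeps the JSON
+                        # document valid UTF-8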
"--configuration_file", + help="Provide json configuration file instead of arguments", default="", + type=str) parser.add_argument("-t", "--trigger_tag", help="Send json only if this tags are detected", action='append', type=str, default=[]) parser.add_argument("--min_leq", help="minimum leq to trigger an event", default=30, type=float) - parser.add_argument("--total_length", help="record length total in seconds to be embedded into the output json", default=10, type=float) - parser.add_argument("--cached_length", help="record length before the trigger", default=5, type=float) + parser.add_argument("--total_length", + help="record length total in seconds to be embedded into the output json", + default=10, type=float) + parser.add_argument("--cached_length", help="record length before the trigger", default=5, + type=float) parser.add_argument("--sample_rate", help="audio sample rate", default=48000, type=int) - parser.add_argument("--resample_method", help="Resampling method as Yamnet is requiring 16 KHz", default='kaiser_fast', type=str) - parser.add_argument("--ssh_file", help="public key file for audio encryption", default="~/.ssh/id_rsa.pub") - parser.add_argument("--input_address", help="Address for zero_record samples", default="tcp://127.0.0.1:10001") + parser.add_argument("--resample_method", help="Resampling method as Yamnet is requiring 16 KHz", + default='kaiser_fast', type=str) + parser.add_argument("--ssh_file", help="public key file for audio encryption", + default="~/.ssh/id_rsa.pub") + parser.add_argument("--input_address", help="Address for zero_record samples", + default="tcp://127.0.0.1:10001") parser.add_argument("--output_address", help="Address for publishing JSON of sound recognition", default="tcp://*:10002") - required_actions.append(parser.add_argument("--yamnet_class_map", help="Yamnet CSV path yamnet_class_threshold_map.csv", type=str)) - required_actions.append(parser.add_argument("--yamnet_weights", help="Yamnet .tflite model download at https://tfhub.dev/google/lite-model/yamnet/tflite/1", type=str)) - parser.add_argument("--yamnet_cutoff_frequency", help="Yamnet highpass filter frequency", default=0, type=float) - parser.add_argument("--yamnet_max_gain", help="Yamnet maximum gain in dB", default=8.0, type=float) - parser.add_argument("--yamnet_window_time", help="Sound source recognition time in seconds", default=5.0, + required_actions.append(parser.add_argument("--yamnet_class_map", + help="Yamnet CSV path yamnet_class_threshold_map.csv", + type=str)) + required_actions.append(parser.add_argument("--yamnet_weights", + help="Yamnet .tflite model download at https://tfhub.dev/google/lite-model/yamnet/tflite/1", + type=str)) + parser.add_argument("--yamnet_cutoff_frequency", help="Yamnet highpass filter frequency", + default=0, type=float) + parser.add_argument("--yamnet_max_gain", help="Yamnet maximum gain in dB", default=8.0, + type=float) + parser.add_argument("--yamnet_window_time", help="Sound source recognition time in seconds", + default=5.0, type=int) - parser.add_argument("--sensitivity", help="Microphone sensitivity in dBFS at 94 dB 1 kHz", default=-28.34, + parser.add_argument("--sensitivity", help="Microphone sensitivity in dBFS at 94 dB 1 kHz", + default=-28.34, type=float) - parser.add_argument("--delay_print_samples", help="Delay in second between each print of number of samples read", + parser.add_argument("--delay_print_samples", + help="Delay in second between each print of number of samples read", default=0, type=float) 
parser.add_argument("--add_spectrogram", help="Add spectrogram float16 array in base 64 in" @@ -499,7 +550,7 @@ def unix_time(self): cfg = json.load(fp) args = types.SimpleNamespace(**cfg) print("Configuration:\n" + json.dumps(vars(args), - sort_keys=False, indent=2)) + sort_keys=False, indent=2)) trigger = TriggerProcessor(args) args.running = True status_thread = StatusThread(trigger, args)