From 72ab684024c875036627a82f2b5c543ce555ff93 Mon Sep 17 00:00:00 2001 From: Hauck Date: Mon, 8 Jul 2024 00:27:47 -0300 Subject: [PATCH] First version of the remote transcription using whisper --- .../resources/scripts/tasks/WhisperProcess.py | 48 +++---- .../RemoteTranscriptionService.java | 118 +++++++++++++++--- .../transcript/Wav2Vec2TranscriptTask.java | 12 +- .../transcript/WhisperTranscriptTask.java | 57 +++++++++ 4 files changed, 192 insertions(+), 43 deletions(-) diff --git a/iped-app/resources/scripts/tasks/WhisperProcess.py b/iped-app/resources/scripts/tasks/WhisperProcess.py index a53203bc47..66f92d84e0 100644 --- a/iped-app/resources/scripts/tasks/WhisperProcess.py +++ b/iped-app/resources/scripts/tasks/WhisperProcess.py @@ -1,4 +1,4 @@ -import sys +import sys import numpy stdout = sys.stdout sys.stdout = sys.stderr @@ -10,7 +10,6 @@ ping = 'ping' def main(): - modelName = sys.argv[1] deviceNum = int(sys.argv[2]) threads = int(sys.argv[3]) @@ -74,38 +73,45 @@ def main(): if line == ping: print(ping, file=stdout, flush=True) continue - - transcription = '' + + files=line.split(",") + transcription = [] logprobs = [] + for file in files: + transcription.append("") + logprobs.append([]) try: if whisperx_found: - audio = whisperx.load_audio(line) - result = model.transcribe(audio, batch_size=batch_size, language=language) + #audio = whisperx.load_audio(line) + result = model.transcribe(files, batch_size=batch_size, language=language) for segment in result['segments']: - transcription += segment['text'] + idx=segment["audio"] + transcription[idx] += segment['text'] if 'avg_logprob' in segment: - logprobs.append(segment['avg_logprob']) + logprobs[idx].append(segment['avg_logprob']) else: - segments, info = model.transcribe(audio=line, language=language, beam_size=5, vad_filter=True) - for segment in segments: - transcription += segment.text - logprobs.append(segment.avg_logprob) + for idx in range(len(files)): + segments, info = model.transcribe(audio=files[idx], language=language, beam_size=5, vad_filter=True) + for segment in segments: + transcription[idx] += segment.text + logprobs[idx].append(segment.avg_logprob) except Exception as e: msg = repr(e).replace('\n', ' ').replace('\r', ' ') print(msg, file=stdout, flush=True) continue - text = transcription.replace('\n', ' ').replace('\r', ' ') - - if len(logprobs) == 0: - finalScore = 0 - else: - finalScore = numpy.mean(numpy.exp(logprobs)) - print(finished, file=stdout, flush=True) - print(str(finalScore), file=stdout, flush=True) - print(text, file=stdout, flush=True) + + for idx in range(len(files)): + text = transcription[idx].replace('\n', ' ').replace('\r', ' ') + + if len(logprobs[idx]) == 0: + finalScore = 0 + else: + finalScore = numpy.mean(numpy.exp(logprobs[idx])) + print(str(finalScore), file=stdout, flush=True) + print(text, file=stdout, flush=True) return diff --git a/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java b/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java index 39f974ddf1..9b3df80047 100644 --- a/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java +++ b/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java @@ -16,6 +16,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.Deque; import java.util.LinkedList; import java.util.List; import java.util.concurrent.ExecutorService; @@ -51,12 +53,23 @@ static enum MESSAGES { VERSION_1_0, PING } - + static class TranscribeRequest { + File wavAudio; + TextAndScore result=null; + + + public TranscribeRequest(File wavAudio) { + this.wavAudio=wavAudio; + } + } static class OpenConnectons { Socket conn; BufferedInputStream bis; PrintWriter writer; Thread t; + File wavAudio; + TextAndScore result=null; + public OpenConnectons(Socket conn, BufferedInputStream bis, PrintWriter writer, Thread t) { this.conn = conn; @@ -92,6 +105,8 @@ public void sendBeacon() { * Control number of simultaneous audio conversions to WAV. */ private static Semaphore wavConvSemaphore; + + private static int BATCH_SIZE=1; private static final AtomicLong audiosTranscripted = new AtomicLong(); private static final AtomicLong audiosDuration = new AtomicLong(); @@ -99,7 +114,9 @@ public void sendBeacon() { private static final AtomicLong transcriptionTime = new AtomicLong(); private static final AtomicLong requestsReceived = new AtomicLong(); private static final AtomicLong requestsAccepted = new AtomicLong(); - private static List beaconQueq = new LinkedList<>(); + private static final List beaconQueq = new LinkedList<>(); + private static final Deque toTranscribe = new LinkedList<>(); + private static Logger logger; @@ -149,7 +166,7 @@ public static void main(String[] args) throws Exception { AbstractTranscriptTask task = (AbstractTranscriptTask) Class.forName(audioConfig.getClassName()).getDeclaredConstructor().newInstance(); audioConfig.setEnabled(true); task.init(cm); - + BATCH_SIZE=audioConfig.getBatchSize(); int numConcurrentTranscriptions = Wav2Vec2TranscriptTask.getNumConcurrentTranscriptions(); int numLogicalCores = Runtime.getRuntime().availableProcessors(); @@ -175,6 +192,10 @@ public static void main(String[] args) throws Exception { startSendStatsThread(discoveryIp, discoveryPort, localPort, numConcurrentTranscriptions, numLogicalCores); startBeaconThread(); + for(int i=0;i transcribeRequests=new ArrayList<>(); + ArrayList files=new ArrayList(); + + if (executor.isShutdown()) { + throw new Exception("Shutting down service instance..."); + } + + synchronized (toTranscribe) { + if(toTranscribe.size()==0) + return; + while(toTranscribe.size()>0 && transcribeRequests.size() results = ((WhisperTranscriptTask)task).transcribeAudios(files); + for(int i=0;i deque = new LinkedBlockingDeque<>(); + protected static LinkedBlockingDeque deque = new LinkedBlockingDeque<>(); protected static volatile Level logLevel = Level.forName("MSG", 250); @@ -199,7 +199,7 @@ public void finish() throws Exception { deque.clear(); } - private void terminateServer(Server server) throws InterruptedException { + protected void terminateServer(Server server) throws InterruptedException { Process process = server.process; try { process.getOutputStream().write(TERMINATE.getBytes(Charsets.UTF8_CHARSET)); @@ -216,7 +216,7 @@ private void terminateServer(Server server) throws InterruptedException { } } - private boolean ping(Server server) { + protected boolean ping(Server server) { try { server.process.getOutputStream().write(PING.getBytes(Charsets.UTF8_CHARSET)); server.process.getOutputStream().write(NEW_LINE); diff --git a/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java b/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java index 697273efb2..f34893a203 100644 --- a/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java +++ b/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -17,6 +18,8 @@ import iped.engine.config.AudioTranscriptConfig; import iped.engine.config.Configuration; import iped.engine.config.ConfigurationManager; +import iped.engine.task.transcript.AbstractTranscriptTask.TextAndScore; +import iped.engine.task.transcript.Wav2Vec2TranscriptTask.Server; import iped.exception.IPEDException; public class WhisperTranscriptTask extends Wav2Vec2TranscriptTask { @@ -127,6 +130,60 @@ protected Server startServer0(int device) throws IOException { protected TextAndScore transcribeAudio(File tmpFile) throws Exception { return transcribeWavPart(tmpFile); } + + protected List transcribeAudios(ArrayList tmpFiles) throws Exception { + + ArrayList textAndScores = new ArrayList<>(); + for(int i=0;i= MAX_TRANSCRIPTIONS) { + terminateServer(server); + server = startServer(server.device); + } + + StringBuilder filePaths = new StringBuilder(); + for(int i=0;i0) { + filePaths.append(","); + } + filePaths.append(tmpFiles.get(i).getAbsolutePath().replace('\\', '/')); + + } + server.process.getOutputStream().write(filePaths.toString().getBytes("UTF-8")); + server.process.getOutputStream().write(NEW_LINE); + server.process.getOutputStream().flush(); + + String line; + while (!TRANSCRIPTION_FINISHED.equals(line = server.reader.readLine())) { + if (line == null) { + throw new ProcessCrashedException(); + } else { + throw new RuntimeException("Transcription failed, returned: " + line); + } + } + for(int i=0;i