diff --git a/iped-app/resources/scripts/tasks/WhisperProcess.py b/iped-app/resources/scripts/tasks/WhisperProcess.py index a53203bc47..6c669920b2 100644 --- a/iped-app/resources/scripts/tasks/WhisperProcess.py +++ b/iped-app/resources/scripts/tasks/WhisperProcess.py @@ -1,4 +1,4 @@ -import sys +import sys import numpy stdout = sys.stdout sys.stdout = sys.stderr @@ -10,7 +10,6 @@ ping = 'ping' def main(): - modelName = sys.argv[1] deviceNum = int(sys.argv[2]) threads = int(sys.argv[3]) @@ -74,38 +73,44 @@ def main(): if line == ping: print(ping, file=stdout, flush=True) continue - - transcription = '' + + files=line.split(",") + transcription = [] logprobs = [] + for file in files: + transcription.append("") + logprobs.append([]) try: if whisperx_found: - audio = whisperx.load_audio(line) - result = model.transcribe(audio, batch_size=batch_size, language=language) + result = model.transcribe(files, batch_size=batch_size, language=language,wav=True) for segment in result['segments']: - transcription += segment['text'] + idx=segment["audio"] + transcription[idx] += segment['text'] if 'avg_logprob' in segment: - logprobs.append(segment['avg_logprob']) + logprobs[idx].append(segment['avg_logprob']) else: - segments, info = model.transcribe(audio=line, language=language, beam_size=5, vad_filter=True) - for segment in segments: - transcription += segment.text - logprobs.append(segment.avg_logprob) + for idx in range(len(files)): + segments, info = model.transcribe(audio=files[idx], language=language, beam_size=5, vad_filter=True) + for segment in segments: + transcription[idx] += segment.text + logprobs[idx].append(segment.avg_logprob) except Exception as e: msg = repr(e).replace('\n', ' ').replace('\r', ' ') print(msg, file=stdout, flush=True) continue - text = transcription.replace('\n', ' ').replace('\r', ' ') - - if len(logprobs) == 0: - finalScore = 0 - else: - finalScore = numpy.mean(numpy.exp(logprobs)) - print(finished, file=stdout, flush=True) - print(str(finalScore), file=stdout, flush=True) - print(text, file=stdout, flush=True) + + for idx in range(len(files)): + text = transcription[idx].replace('\n', ' ').replace('\r', ' ') + + if len(logprobs[idx]) == 0: + finalScore = 0 + else: + finalScore = numpy.mean(numpy.exp(logprobs[idx])) + print(str(finalScore), file=stdout, flush=True) + print(text, file=stdout, flush=True) return diff --git a/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java b/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java index 39f974ddf1..b1873d51d0 100644 --- a/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java +++ b/iped-engine/src/main/java/iped/engine/task/transcript/RemoteTranscriptionService.java @@ -16,6 +16,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.Deque; import java.util.LinkedList; import java.util.List; import java.util.concurrent.ExecutorService; @@ -51,12 +53,24 @@ static enum MESSAGES { VERSION_1_0, PING } - + static class TranscribeRequest { + File wavAudio; + TextAndScore result=null; + Exception error=null; + + + public TranscribeRequest(File wavAudio) { + this.wavAudio=wavAudio; + } + } static class OpenConnectons { Socket conn; BufferedInputStream bis; PrintWriter writer; Thread t; + File wavAudio; + TextAndScore result=null; + public OpenConnectons(Socket conn, BufferedInputStream bis, PrintWriter writer, Thread t) { this.conn = conn; @@ -92,6 +106,8 @@ public void sendBeacon() { * Control number of simultaneous audio conversions to WAV. */ private static Semaphore wavConvSemaphore; + + private static int BATCH_SIZE=1; private static final AtomicLong audiosTranscripted = new AtomicLong(); private static final AtomicLong audiosDuration = new AtomicLong(); @@ -99,7 +115,9 @@ public void sendBeacon() { private static final AtomicLong transcriptionTime = new AtomicLong(); private static final AtomicLong requestsReceived = new AtomicLong(); private static final AtomicLong requestsAccepted = new AtomicLong(); - private static List beaconQueq = new LinkedList<>(); + private static final List beaconQueq = new LinkedList<>(); + private static final Deque toTranscribe = new LinkedList<>(); + private static Logger logger; @@ -149,7 +167,7 @@ public static void main(String[] args) throws Exception { AbstractTranscriptTask task = (AbstractTranscriptTask) Class.forName(audioConfig.getClassName()).getDeclaredConstructor().newInstance(); audioConfig.setEnabled(true); task.init(cm); - + BATCH_SIZE=audioConfig.getBatchSize(); int numConcurrentTranscriptions = Wav2Vec2TranscriptTask.getNumConcurrentTranscriptions(); int numLogicalCores = Runtime.getRuntime().availableProcessors(); @@ -175,6 +193,10 @@ public static void main(String[] args) throws Exception { startSendStatsThread(discoveryIp, discoveryPort, localPort, numConcurrentTranscriptions, numLogicalCores); startBeaconThread(); + for(int i=0;i transcribeRequests = new ArrayList<>(); + ArrayList files = new ArrayList(); + + if (executor.isShutdown()) { + throw new Exception("Shutting down service instance..."); + } + + synchronized (toTranscribe) { + if (toTranscribe.size() == 0) + return; + while (toTranscribe.size() > 0 && transcribeRequests.size() < BATCH_SIZE) { + TranscribeRequest req = toTranscribe.poll(); + transcribeRequests.add(req); + files.add(req.wavAudio); + } + } + logger.info("inicio da transcricao de " + files.size() + " audios"); + + long t2 = System.currentTimeMillis(); + + boolean batchTrancribe = (task instanceof WhisperTranscriptTask); + if (batchTrancribe) { + try { + List results = ((WhisperTranscriptTask) task).transcribeAudios(files); + for (int i = 0; i < results.size(); i++) { + transcribeRequests.get(i).result = results.get(i); + } + } catch (Exception e) {// case fail, try each audio individually + batchTrancribe = false; + logger.error("Error while doing batch transcribe " + e.toString()); + } + } + if (!batchTrancribe) {// try each audio individually + for (int i = 0; i < files.size(); i++) { + try { + transcribeRequests.get(i).result = task.transcribeAudio(files.get(i)); + } catch (Exception e2) { + transcribeRequests.get(i).result = null; + transcribeRequests.get(i).error = e2; + logger.error("Error while transcribing"); + } + } + } + + for (int i = 0; i < transcribeRequests.size(); i++) { + synchronized (transcribeRequests.get(i)) { + transcribeRequests.get(i).notifyAll(); + } + } + long t3 = System.currentTimeMillis(); + transcriptionTime.addAndGet(t3 - t2); + } private static void registerThis(String discoveryIp, int discoveryPort, int localPort, int concurrentJobs, int concurrentWavConvs) throws Exception { try (Socket client = new Socket(discoveryIp, discoveryPort); @@ -312,29 +387,24 @@ public void run() { logger.info(prefix + "Accepted connection."); - int min = Math.min(MESSAGES.AUDIO_SIZE.toString().length(), - MESSAGES.VERSION_1_1.toString().length()); - bis.mark(min + 1); - byte[] bytes = bis.readNBytes(min); - String cmd = new String(bytes); - if (!MESSAGES.AUDIO_SIZE.toString().startsWith(cmd)) { - bis.reset(); - bytes = bis.readNBytes(MESSAGES.VERSION_1_1.toString().length()); - protocol = new String(bytes); - bis.mark(min + 1); - synchronized (beaconQueq) { - opc = new OpenConnectons(client, bis, writer, this); - beaconQueq.add(opc); - } + byte[] bytes = bis.readNBytes(MESSAGES.VERSION_1_2.toString().length()); + protocol = new String(bytes); + synchronized (beaconQueq) { + opc = new OpenConnectons(client, bis, writer, this); + beaconQueq.add(opc); } + + logger.info("Protocol Version {}", protocol); + if (protocol.compareTo(MESSAGES.VERSION_1_2.toString()) < 0) { + throw new Exception("Procol version " + protocol + " not supported"); + } // read the audio_size message - bis.reset(); bytes = bis.readNBytes(MESSAGES.AUDIO_SIZE.toString().length()); - cmd = new String(bytes); + String cmd = new String(bytes); if (!MESSAGES.AUDIO_SIZE.toString().equals(cmd)) { error = true; @@ -343,11 +413,8 @@ public void run() { DataInputStream dis = new DataInputStream(bis); long size; - if (protocol.compareTo(MESSAGES.VERSION_1_2.toString()) >= 0) { - size = dis.readLong(); - } else { - size = dis.readInt(); - } + size = dis.readLong(); + if (size < 0) { error = true; try { @@ -408,17 +475,22 @@ public void run() { } long durationMillis = 1000 * wavFile.length() / (16000 * 2); - TextAndScore result; + TextAndScore result=null; long t2, t3; try { - transcriptSemaphore.acquire(); - if (executor.isShutdown()) { - error = true; - throw new Exception("Shutting down service instance..."); + TranscribeRequest req=new TranscribeRequest(wavFile); + synchronized (toTranscribe) { + toTranscribe.add(req); + } + synchronized(req) { + req.wait(); } - t2 = System.currentTimeMillis(); - result = task.transcribeAudio(wavFile); - t3 = System.currentTimeMillis(); + result=req.result; + if(result==null) { + error = false; + throw new Exception("Error processing the audio", req.error); + } + } catch (ProcessCrashedException e) { // retry audio error = true; @@ -429,14 +501,12 @@ public void run() { executor.shutdown(); server.close(); throw e; - } finally { - transcriptSemaphore.release(); - } + } audiosTranscripted.incrementAndGet(); audiosDuration.addAndGet(durationMillis); conversionTime.addAndGet(t1 - t0); - transcriptionTime.addAndGet(t3 - t2); + logger.info(prefix + "Transcritpion done."); // removes from the beacon queue to prevent beacons in the middle of the @@ -453,17 +523,8 @@ public void run() { String errorMsg = "Exception while transcribing"; logger.warn(errorMsg, e); if (writer != null) { - if (e.getMessage() != null && e.getMessage().startsWith("Invalid file size:") - && protocol.compareTo(MESSAGES.VERSION_1_2.toString()) < 0) { - writer.println("0"); - writer.println( - "Audios longer than 2GB are not supported by old clients, please update your client version!"); - writer.println(MESSAGES.DONE); - } else { - writer.println(error ? MESSAGES.ERROR : MESSAGES.WARN); - writer.println( - errorMsg + ": " + e.toString().replace('\n', ' ').replace('\r', ' ')); - } + writer.println(error ? MESSAGES.ERROR : MESSAGES.WARN); + writer.println(errorMsg + ": " + e.toString().replace('\n', ' ').replace('\r', ' ')); writer.flush(); } } finally { @@ -508,5 +569,39 @@ public void run() { } }); } + + private static void startTrancribeThreads(AbstractTranscriptTask task) { + executor.execute(new Runnable() { + @Override + public void run() { + while (true) { + Boolean empty = true; + synchronized (toTranscribe) { + empty = toTranscribe.isEmpty(); + } + if (empty) { + try { + Thread.sleep(100); + + } catch (Exception e) { + // TODO: handle exception + } + continue; + } + try { + transcriptSemaphore.acquire(); + transcribeAudios(task); + + } catch (Exception e) { + e.printStackTrace(); + } finally { + transcriptSemaphore.release(); + } + + } + } + }); + } + } diff --git a/iped-engine/src/main/java/iped/engine/task/transcript/Wav2Vec2TranscriptTask.java b/iped-engine/src/main/java/iped/engine/task/transcript/Wav2Vec2TranscriptTask.java index 84a92dca8a..803fee7b12 100644 --- a/iped-engine/src/main/java/iped/engine/task/transcript/Wav2Vec2TranscriptTask.java +++ b/iped-engine/src/main/java/iped/engine/task/transcript/Wav2Vec2TranscriptTask.java @@ -27,18 +27,18 @@ public class Wav2Vec2TranscriptTask extends AbstractTranscriptTask { private static Logger logger = LogManager.getLogger(Wav2Vec2TranscriptTask.class); private static final String SCRIPT_PATH = "/scripts/tasks/Wav2Vec2Process.py"; - private static final String TRANSCRIPTION_FINISHED = "transcription_finished"; + protected static final String TRANSCRIPTION_FINISHED = "transcription_finished"; private static final String MODEL_LOADED = "wav2vec2_model_loaded"; private static final String HUGGINGSOUND_LOADED = "huggingsound_loaded"; private static final String TERMINATE = "terminate_process"; private static final String PING = "ping"; - private static final int MAX_TRANSCRIPTIONS = 100000; - private static final byte[] NEW_LINE = "\n".getBytes(); + protected static final int MAX_TRANSCRIPTIONS = 100000; + protected static final byte[] NEW_LINE = "\n".getBytes(); protected static volatile Integer numProcesses; - private static LinkedBlockingDeque deque = new LinkedBlockingDeque<>(); + protected static LinkedBlockingDeque deque = new LinkedBlockingDeque<>(); protected static volatile Level logLevel = Level.forName("MSG", 250); @@ -199,7 +199,7 @@ public void finish() throws Exception { deque.clear(); } - private void terminateServer(Server server) throws InterruptedException { + protected void terminateServer(Server server) throws InterruptedException { Process process = server.process; try { process.getOutputStream().write(TERMINATE.getBytes(Charsets.UTF8_CHARSET)); @@ -216,7 +216,7 @@ private void terminateServer(Server server) throws InterruptedException { } } - private boolean ping(Server server) { + protected boolean ping(Server server) { try { server.process.getOutputStream().write(PING.getBytes(Charsets.UTF8_CHARSET)); server.process.getOutputStream().write(NEW_LINE); diff --git a/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java b/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java index 697273efb2..f34893a203 100644 --- a/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java +++ b/iped-engine/src/main/java/iped/engine/task/transcript/WhisperTranscriptTask.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -17,6 +18,8 @@ import iped.engine.config.AudioTranscriptConfig; import iped.engine.config.Configuration; import iped.engine.config.ConfigurationManager; +import iped.engine.task.transcript.AbstractTranscriptTask.TextAndScore; +import iped.engine.task.transcript.Wav2Vec2TranscriptTask.Server; import iped.exception.IPEDException; public class WhisperTranscriptTask extends Wav2Vec2TranscriptTask { @@ -127,6 +130,60 @@ protected Server startServer0(int device) throws IOException { protected TextAndScore transcribeAudio(File tmpFile) throws Exception { return transcribeWavPart(tmpFile); } + + protected List transcribeAudios(ArrayList tmpFiles) throws Exception { + + ArrayList textAndScores = new ArrayList<>(); + for(int i=0;i= MAX_TRANSCRIPTIONS) { + terminateServer(server); + server = startServer(server.device); + } + + StringBuilder filePaths = new StringBuilder(); + for(int i=0;i0) { + filePaths.append(","); + } + filePaths.append(tmpFiles.get(i).getAbsolutePath().replace('\\', '/')); + + } + server.process.getOutputStream().write(filePaths.toString().getBytes("UTF-8")); + server.process.getOutputStream().write(NEW_LINE); + server.process.getOutputStream().flush(); + + String line; + while (!TRANSCRIPTION_FINISHED.equals(line = server.reader.readLine())) { + if (line == null) { + throw new ProcessCrashedException(); + } else { + throw new RuntimeException("Transcription failed, returned: " + line); + } + } + for(int i=0;i