Skip to content

Commit

Permalink
First version of the remote transcription using whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
hauck-jvsh committed Jul 8, 2024
1 parent bb92776 commit 72ab684
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 43 deletions.
48 changes: 27 additions & 21 deletions iped-app/resources/scripts/tasks/WhisperProcess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys
import sys
import numpy
stdout = sys.stdout
sys.stdout = sys.stderr
Expand All @@ -10,7 +10,6 @@
ping = 'ping'

def main():

modelName = sys.argv[1]
deviceNum = int(sys.argv[2])
threads = int(sys.argv[3])
Expand Down Expand Up @@ -74,38 +73,45 @@ def main():
if line == ping:
print(ping, file=stdout, flush=True)
continue

transcription = ''

files=line.split(",")
transcription = []
logprobs = []
for file in files:
transcription.append("")
logprobs.append([])
try:
if whisperx_found:
audio = whisperx.load_audio(line)
result = model.transcribe(audio, batch_size=batch_size, language=language)
#audio = whisperx.load_audio(line)
result = model.transcribe(files, batch_size=batch_size, language=language)
for segment in result['segments']:
transcription += segment['text']
idx=segment["audio"]
transcription[idx] += segment['text']
if 'avg_logprob' in segment:
logprobs.append(segment['avg_logprob'])
logprobs[idx].append(segment['avg_logprob'])
else:
segments, info = model.transcribe(audio=line, language=language, beam_size=5, vad_filter=True)
for segment in segments:
transcription += segment.text
logprobs.append(segment.avg_logprob)
for idx in range(len(files)):
segments, info = model.transcribe(audio=files[idx], language=language, beam_size=5, vad_filter=True)
for segment in segments:
transcription[idx] += segment.text
logprobs[idx].append(segment.avg_logprob)

except Exception as e:
msg = repr(e).replace('\n', ' ').replace('\r', ' ')
print(msg, file=stdout, flush=True)
continue

text = transcription.replace('\n', ' ').replace('\r', ' ')

if len(logprobs) == 0:
finalScore = 0
else:
finalScore = numpy.mean(numpy.exp(logprobs))

print(finished, file=stdout, flush=True)
print(str(finalScore), file=stdout, flush=True)
print(text, file=stdout, flush=True)

for idx in range(len(files)):
text = transcription[idx].replace('\n', ' ').replace('\r', ' ')

if len(logprobs[idx]) == 0:
finalScore = 0
else:
finalScore = numpy.mean(numpy.exp(logprobs[idx]))
print(str(finalScore), file=stdout, flush=True)
print(text, file=stdout, flush=True)

return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -51,12 +53,23 @@ static enum MESSAGES {
VERSION_1_0,
PING
}

/**
 * One queued transcription job: the WAV file to transcribe plus a slot for its
 * outcome. The thread that enqueues the job waits on this object's monitor and a
 * transcription worker calls {@code notifyAll()} on it after storing the result.
 */
static class TranscribeRequest {
    /** WAV audio file to be transcribed. */
    File wavAudio;

    /** Transcription outcome; stays {@code null} until a worker assigns it. */
    TextAndScore result;

    /**
     * Creates a pending request for the given audio file.
     *
     * @param wavAudio the WAV file to transcribe
     */
    public TranscribeRequest(File wavAudio) {
        this.wavAudio = wavAudio;
    }
}
static class OpenConnectons {
Socket conn;
BufferedInputStream bis;
PrintWriter writer;
Thread t;
File wavAudio;
TextAndScore result=null;


public OpenConnectons(Socket conn, BufferedInputStream bis, PrintWriter writer, Thread t) {
this.conn = conn;
Expand Down Expand Up @@ -92,14 +105,18 @@ public void sendBeacon() {
* Control number of simultaneous audio conversions to WAV.
*/
private static Semaphore wavConvSemaphore;

private static int BATCH_SIZE=1;

private static final AtomicLong audiosTranscripted = new AtomicLong();
private static final AtomicLong audiosDuration = new AtomicLong();
private static final AtomicLong conversionTime = new AtomicLong();
private static final AtomicLong transcriptionTime = new AtomicLong();
private static final AtomicLong requestsReceived = new AtomicLong();
private static final AtomicLong requestsAccepted = new AtomicLong();
private static List<OpenConnectons> beaconQueq = new LinkedList<>();
private static final List<OpenConnectons> beaconQueq = new LinkedList<>();
private static final Deque<TranscribeRequest> toTranscribe = new LinkedList<>();


private static Logger logger;

Expand Down Expand Up @@ -149,7 +166,7 @@ public static void main(String[] args) throws Exception {
AbstractTranscriptTask task = (AbstractTranscriptTask) Class.forName(audioConfig.getClassName()).getDeclaredConstructor().newInstance();
audioConfig.setEnabled(true);
task.init(cm);

BATCH_SIZE=audioConfig.getBatchSize();
int numConcurrentTranscriptions = Wav2Vec2TranscriptTask.getNumConcurrentTranscriptions();
int numLogicalCores = Runtime.getRuntime().availableProcessors();

Expand All @@ -175,6 +192,10 @@ public static void main(String[] args) throws Exception {
startSendStatsThread(discoveryIp, discoveryPort, localPort, numConcurrentTranscriptions, numLogicalCores);

startBeaconThread();
for(int i=0;i<numConcurrentTranscriptions;i++) {
startTrancribeThreads(task);
}


waitRequests(server, task, discoveryIp);

Expand Down Expand Up @@ -208,7 +229,48 @@ public void run() {
});
}


/**
 * Drains up to {@code BATCH_SIZE} pending requests from {@code toTranscribe},
 * transcribes them and hands each result back to the thread waiting on its request.
 * <p>
 * Whisper tasks receive the whole batch in a single call; any other task is invoked
 * once per file. Every request taken from the queue is woken up before this method
 * returns, even when transcription fails, so that no requester stays blocked forever.
 * A request whose transcription failed is woken with {@code result} still {@code null}.
 * <p>
 * NOTE(review): requesters use a bare {@code wait()} on the request; if a worker
 * finishes before the requester starts waiting, the notification is lost. Fixing that
 * needs a completion flag checked by the waiting side.
 *
 * @param task the transcription task used to process the audios
 * @throws Exception if the service is shutting down or the transcription fails
 */
private static void transcribeAudios(AbstractTranscriptTask task) throws Exception {
    if (executor.isShutdown()) {
        throw new Exception("Shutting down service instance...");
    }

    ArrayList<TranscribeRequest> transcribeRequests = new ArrayList<>();
    // Declared as ArrayList because WhisperTranscriptTask.transcribeAudios takes an ArrayList<File>.
    ArrayList<File> files = new ArrayList<>();

    // A non-positive configured batch size must not prevent requests from being served.
    int batchLimit = Math.max(1, BATCH_SIZE);

    synchronized (toTranscribe) {
        while (!toTranscribe.isEmpty() && transcribeRequests.size() < batchLimit) {
            TranscribeRequest req = toTranscribe.poll();
            transcribeRequests.add(req);
            files.add(req.wavAudio);
        }
    }
    if (transcribeRequests.isEmpty()) {
        return;
    }

    logger.info("inicio da transcricao de " + files.size() + " audios");

    long t2 = System.currentTimeMillis();
    try {
        if (task instanceof WhisperTranscriptTask) {
            List<TextAndScore> results = ((WhisperTranscriptTask) task).transcribeAudios(files);
            // Guard against a result list whose size differs from the batch size.
            int delivered = Math.min(results.size(), transcribeRequests.size());
            for (int i = 0; i < delivered; i++) {
                deliverResult(transcribeRequests.get(i), results.get(i));
            }
        } else {
            // Deliver each result as soon as it is available instead of at the end of the batch.
            for (TranscribeRequest req : transcribeRequests) {
                deliverResult(req, task.transcribeAudio(req.wavAudio));
            }
        }
        long t3 = System.currentTimeMillis();
        transcriptionTime.addAndGet(t3 - t2);
    } finally {
        // Wake every requester of this batch, including those left without a result
        // because of an exception or a short result list; a second notifyAll() on an
        // already served request is harmless.
        for (TranscribeRequest req : transcribeRequests) {
            synchronized (req) {
                req.notifyAll();
            }
        }
    }
}

/** Stores the result on the request and wakes the thread(s) waiting on it. */
private static void deliverResult(TranscribeRequest req, TextAndScore result) {
    synchronized (req) {
        req.result = result;
        req.notifyAll();
    }
}

private static void registerThis(String discoveryIp, int discoveryPort, int localPort, int concurrentJobs, int concurrentWavConvs) throws Exception {
try (Socket client = new Socket(discoveryIp, discoveryPort);
Expand Down Expand Up @@ -408,17 +470,18 @@ public void run() {
}
long durationMillis = 1000 * wavFile.length() / (16000 * 2);

TextAndScore result;
TextAndScore result=null;
long t2, t3;
try {
transcriptSemaphore.acquire();
if (executor.isShutdown()) {
error = true;
throw new Exception("Shutting down service instance...");
TranscribeRequest req=new TranscribeRequest(wavFile);
synchronized (toTranscribe) {
toTranscribe.add(req);
}
synchronized(req) {
req.wait();
}
t2 = System.currentTimeMillis();
result = task.transcribeAudio(wavFile);
t3 = System.currentTimeMillis();
result=req.result;

} catch (ProcessCrashedException e) {
// retry audio
error = true;
Expand All @@ -429,14 +492,12 @@ public void run() {
executor.shutdown();
server.close();
throw e;
} finally {
transcriptSemaphore.release();
}
}

audiosTranscripted.incrementAndGet();
audiosDuration.addAndGet(durationMillis);
conversionTime.addAndGet(t1 - t0);
transcriptionTime.addAndGet(t3 - t2);

logger.info(prefix + "Transcritpion done.");

// removes from the beacon queue to prevent beacons in the middle of the
Expand Down Expand Up @@ -508,5 +569,30 @@ public void run() {
}
});
}


/**
 * Submits a worker to {@code executor} that repeatedly, every 100 ms, acquires a
 * permit from {@code transcriptSemaphore} and processes one batch of queued requests
 * through {@code transcribeAudios(task)}.
 * <p>
 * The worker stops when the executor has been shut down or when its thread is
 * interrupted. The semaphore permit is released only if it was actually acquired.
 *
 * @param task the transcription task passed on to {@code transcribeAudios}
 */
private static void startTrancribeThreads(AbstractTranscriptTask task) {
    executor.execute(new Runnable() {
        @Override
        public void run() {
            // After shutdown transcribeAudios() only throws, so stop polling instead of spinning.
            while (!executor.isShutdown()) {
                try {
                    Thread.sleep(100);
                    transcriptSemaphore.acquire();
                } catch (InterruptedException e) {
                    // No permit is held here: restore the interrupt flag and end the worker.
                    Thread.currentThread().interrupt();
                    return;
                }
                try {
                    transcribeAudios(task);
                } catch (Exception e) {
                    // Keep the worker alive; a failed batch must not stop later ones.
                    e.printStackTrace();
                } finally {
                    transcriptSemaphore.release();
                }
            }
        }
    });
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ public class Wav2Vec2TranscriptTask extends AbstractTranscriptTask {
private static Logger logger = LogManager.getLogger(Wav2Vec2TranscriptTask.class);

private static final String SCRIPT_PATH = "/scripts/tasks/Wav2Vec2Process.py";
private static final String TRANSCRIPTION_FINISHED = "transcription_finished";
protected static final String TRANSCRIPTION_FINISHED = "transcription_finished";
private static final String MODEL_LOADED = "wav2vec2_model_loaded";
private static final String HUGGINGSOUND_LOADED = "huggingsound_loaded";
private static final String TERMINATE = "terminate_process";
private static final String PING = "ping";

private static final int MAX_TRANSCRIPTIONS = 100000;
private static final byte[] NEW_LINE = "\n".getBytes();
protected static final int MAX_TRANSCRIPTIONS = 100000;
protected static final byte[] NEW_LINE = "\n".getBytes();

protected static volatile Integer numProcesses;

private static LinkedBlockingDeque<Server> deque = new LinkedBlockingDeque<>();
protected static LinkedBlockingDeque<Server> deque = new LinkedBlockingDeque<>();

protected static volatile Level logLevel = Level.forName("MSG", 250);

Expand Down Expand Up @@ -199,7 +199,7 @@ public void finish() throws Exception {
deque.clear();
}

private void terminateServer(Server server) throws InterruptedException {
protected void terminateServer(Server server) throws InterruptedException {
Process process = server.process;
try {
process.getOutputStream().write(TERMINATE.getBytes(Charsets.UTF8_CHARSET));
Expand All @@ -216,7 +216,7 @@ private void terminateServer(Server server) throws InterruptedException {
}
}

private boolean ping(Server server) {
protected boolean ping(Server server) {
try {
server.process.getOutputStream().write(PING.getBytes(Charsets.UTF8_CHARSET));
server.process.getOutputStream().write(NEW_LINE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -17,6 +18,8 @@
import iped.engine.config.AudioTranscriptConfig;
import iped.engine.config.Configuration;
import iped.engine.config.ConfigurationManager;
import iped.engine.task.transcript.AbstractTranscriptTask.TextAndScore;
import iped.engine.task.transcript.Wav2Vec2TranscriptTask.Server;
import iped.exception.IPEDException;

public class WhisperTranscriptTask extends Wav2Vec2TranscriptTask {
Expand Down Expand Up @@ -127,6 +130,60 @@ protected Server startServer0(int device) throws IOException {
protected TextAndScore transcribeAudio(File tmpFile) throws Exception {
    // Single-file transcription is handed over unchanged to transcribeWavPart.
    TextAndScore transcribed = transcribeWavPart(tmpFile);
    return transcribed;
}

/**
 * Transcribes a batch of audio files with a single request to one transcription
 * process taken from {@code deque}.
 * <p>
 * The absolute paths (with {@code '\\'} replaced by {@code '/'}) are sent as one
 * comma-separated line. The process is expected to answer with the
 * {@code TRANSCRIPTION_FINISHED} marker followed by one score line and one text line
 * per file, in the same order as {@code tmpFiles}.
 *
 * @param tmpFiles the audio files to transcribe; an empty list yields an empty result
 *                 without contacting the process
 * @return one {@code TextAndScore} per input file, in input order
 * @throws IllegalArgumentException if a path contains a comma, which would corrupt
 *                                  the comma-separated request line
 * @throws ProcessCrashedException  if the process output ends before the full answer
 *                                  was read
 * @throws RuntimeException         if the process answers with anything other than
 *                                  the finished marker
 * @throws Exception                on I/O or process management failures
 */
protected List<TextAndScore> transcribeAudios(ArrayList<File> tmpFiles) throws Exception {

    List<TextAndScore> textAndScores = new ArrayList<>(tmpFiles.size());
    if (tmpFiles.isEmpty()) {
        // An empty request line would be read by the process as one empty file name.
        return textAndScores;
    }

    // Build and validate the request before borrowing a server from the pool.
    StringBuilder filePaths = new StringBuilder();
    for (File tmpFile : tmpFiles) {
        String path = tmpFile.getAbsolutePath().replace('\\', '/');
        if (path.indexOf(',') >= 0) {
            throw new IllegalArgumentException("Audio path must not contain ',': " + path);
        }
        if (filePaths.length() > 0) {
            filePaths.append(",");
        }
        filePaths.append(path);
    }

    Server server = deque.take();
    try {
        if (!ping(server) || server.transcriptionsDone >= MAX_TRANSCRIPTIONS) {
            terminateServer(server);
            server = startServer(server.device);
        }

        server.process.getOutputStream().write(filePaths.toString().getBytes("UTF-8"));
        server.process.getOutputStream().write(NEW_LINE);
        server.process.getOutputStream().flush();

        String line = server.reader.readLine();
        if (line == null) {
            throw new ProcessCrashedException();
        }
        if (!TRANSCRIPTION_FINISHED.equals(line)) {
            throw new RuntimeException("Transcription failed, returned: " + line);
        }

        for (int i = 0; i < tmpFiles.size(); i++) {
            String scoreLine = server.reader.readLine();
            String text = server.reader.readLine();
            if (scoreLine == null || text == null) {
                // End of stream in the middle of the answer: the process died.
                throw new ProcessCrashedException();
            }
            Double score = Double.valueOf(scoreLine);

            TextAndScore textAndScore = new TextAndScore();
            textAndScore.text = text;
            textAndScore.score = score;
            textAndScores.add(textAndScore);
            server.transcriptionsDone++;
        }

    } finally {
        // Always hand the (possibly restarted) server back to the pool.
        deque.add(server);
    }

    return textAndScores;
}



@Override
protected void logInputStream(InputStream is) {
Expand Down

0 comments on commit 72ab684

Please sign in to comment.