Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisperx Optimization #2258

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions iped-app/resources/scripts/tasks/WhisperProcess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys
import sys
import numpy
stdout = sys.stdout
sys.stdout = sys.stderr
Expand All @@ -10,7 +10,6 @@
ping = 'ping'

def main():

modelName = sys.argv[1]
deviceNum = int(sys.argv[2])
threads = int(sys.argv[3])
Expand Down Expand Up @@ -74,38 +73,44 @@ def main():
if line == ping:
print(ping, file=stdout, flush=True)
continue

transcription = ''

files=line.split(",")
transcription = []
logprobs = []
for file in files:
transcription.append("")
logprobs.append([])
try:
if whisperx_found:
audio = whisperx.load_audio(line)
result = model.transcribe(audio, batch_size=batch_size, language=language)
result = model.transcribe(files, batch_size=batch_size, language=language,wav=True)
for segment in result['segments']:
transcription += segment['text']
idx=segment["audio"]
transcription[idx] += segment['text']
if 'avg_logprob' in segment:
logprobs.append(segment['avg_logprob'])
logprobs[idx].append(segment['avg_logprob'])
else:
segments, info = model.transcribe(audio=line, language=language, beam_size=5, vad_filter=True)
for segment in segments:
transcription += segment.text
logprobs.append(segment.avg_logprob)
for idx in range(len(files)):
segments, info = model.transcribe(audio=files[idx], language=language, beam_size=5, vad_filter=True)
for segment in segments:
transcription[idx] += segment.text
logprobs[idx].append(segment.avg_logprob)

except Exception as e:
msg = repr(e).replace('\n', ' ').replace('\r', ' ')
print(msg, file=stdout, flush=True)
continue

text = transcription.replace('\n', ' ').replace('\r', ' ')

if len(logprobs) == 0:
finalScore = 0
else:
finalScore = numpy.mean(numpy.exp(logprobs))

print(finished, file=stdout, flush=True)
print(str(finalScore), file=stdout, flush=True)
print(text, file=stdout, flush=True)

for idx in range(len(files)):
text = transcription[idx].replace('\n', ' ').replace('\r', ' ')

if len(logprobs[idx]) == 0:
finalScore = 0
else:
finalScore = numpy.mean(numpy.exp(logprobs[idx]))
print(str(finalScore), file=stdout, flush=True)
print(text, file=stdout, flush=True)

return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -51,12 +53,24 @@ static enum MESSAGES {
VERSION_1_0,
PING
}

/**
 * One WAV file queued for transcription, together with the slots that receive
 * its outcome.
 *
 * The connection handler that creates a request blocks on the request object
 * itself ({@code wait()}); the transcription code fills {@link #result} or
 * {@link #error} and then calls {@code notifyAll()} on the same object.
 */
static class TranscribeRequest {
    /** WAV file to be transcribed. */
    File wavAudio;

    /** Transcription outcome; remains {@code null} until one is assigned. */
    TextAndScore result;

    /** Failure recorded while transcribing; remains {@code null} if none was recorded. */
    Exception error;

    public TranscribeRequest(File wavAudio) {
        this.wavAudio = wavAudio;
    }
}
static class OpenConnectons {
Socket conn;
BufferedInputStream bis;
PrintWriter writer;
Thread t;
File wavAudio;
TextAndScore result=null;


public OpenConnectons(Socket conn, BufferedInputStream bis, PrintWriter writer, Thread t) {
this.conn = conn;
Expand Down Expand Up @@ -92,14 +106,18 @@ public void sendBeacon() {
* Control number of simultaneous audio conversions to WAV.
*/
private static Semaphore wavConvSemaphore;

private static int BATCH_SIZE=1;

private static final AtomicLong audiosTranscripted = new AtomicLong();
private static final AtomicLong audiosDuration = new AtomicLong();
private static final AtomicLong conversionTime = new AtomicLong();
private static final AtomicLong transcriptionTime = new AtomicLong();
private static final AtomicLong requestsReceived = new AtomicLong();
private static final AtomicLong requestsAccepted = new AtomicLong();
private static List<OpenConnectons> beaconQueq = new LinkedList<>();
private static final List<OpenConnectons> beaconQueq = new LinkedList<>();
private static final Deque<TranscribeRequest> toTranscribe = new LinkedList<>();


private static Logger logger;

Expand Down Expand Up @@ -149,7 +167,7 @@ public static void main(String[] args) throws Exception {
AbstractTranscriptTask task = (AbstractTranscriptTask) Class.forName(audioConfig.getClassName()).getDeclaredConstructor().newInstance();
audioConfig.setEnabled(true);
task.init(cm);

BATCH_SIZE=audioConfig.getBatchSize();
int numConcurrentTranscriptions = Wav2Vec2TranscriptTask.getNumConcurrentTranscriptions();
int numLogicalCores = Runtime.getRuntime().availableProcessors();

Expand All @@ -175,6 +193,10 @@ public static void main(String[] args) throws Exception {
startSendStatsThread(discoveryIp, discoveryPort, localPort, numConcurrentTranscriptions, numLogicalCores);

startBeaconThread();
for(int i=0;i<numConcurrentTranscriptions;i++) {
startTrancribeThreads(task);
}


waitRequests(server, task, discoveryIp);

Expand Down Expand Up @@ -208,7 +230,60 @@ public void run() {
});
}

/**
 * Takes up to {@code BATCH_SIZE} pending requests from {@code toTranscribe},
 * transcribes them and wakes the threads waiting on each request.
 *
 * When {@code task} is a {@link WhisperTranscriptTask} the files are first sent
 * as a single batch. If the batch call fails, or returns a number of results
 * different from the number of files, every file is transcribed individually
 * and per-file failures are stored in {@link TranscribeRequest#error}.
 *
 * Waiters are always notified, including when the executor has been shut down:
 * in that case every queued request is failed with the shutdown exception
 * before it is thrown, so no handler stays blocked on a request that will never
 * be processed.
 *
 * NOTE(review): the waiting side calls {@code wait()} without re-checking a
 * completion condition, so a notification sent before the waiter reaches
 * {@code wait()} is lost. That must be fixed on the waiting side (e.g. a
 * completion flag or a latch); it cannot be solved from here alone.
 *
 * @param task transcription backend used for the batch and per-file calls
 * @throws Exception if the executor has been shut down
 */
private static void transcribeAudios(AbstractTranscriptTask task) throws Exception {
    List<TranscribeRequest> transcribeRequests = new ArrayList<>();
    // Kept as ArrayList because it is handed to WhisperTranscriptTask.transcribeAudios,
    // whose parameter type is not visible here.
    ArrayList<File> files = new ArrayList<File>();

    if (executor.isShutdown()) {
        Exception shutdown = new Exception("Shutting down service instance...");
        // Fail everything still queued; otherwise the handler threads waiting on
        // these requests would block forever.
        synchronized (toTranscribe) {
            TranscribeRequest pending;
            while ((pending = toTranscribe.poll()) != null) {
                pending.error = shutdown;
                transcribeRequests.add(pending);
            }
        }
        notifyTranscribeRequests(transcribeRequests);
        throw shutdown;
    }

    synchronized (toTranscribe) {
        if (toTranscribe.size() == 0)
            return;
        while (toTranscribe.size() > 0 && transcribeRequests.size() < BATCH_SIZE) {
            TranscribeRequest req = toTranscribe.poll();
            transcribeRequests.add(req);
            files.add(req.wavAudio);
        }
    }
    logger.info("inicio da transcricao de " + files.size() + " audios");

    long t2 = System.currentTimeMillis();

    try {
        boolean batchTrancribe = (task instanceof WhisperTranscriptTask);
        if (batchTrancribe) {
            try {
                List<TextAndScore> results = ((WhisperTranscriptTask) task).transcribeAudios(files);
                // A short or long result list cannot be matched to the requests by
                // position, so treat it as a failed batch instead of assigning part of it.
                if (results == null || results.size() != files.size()) {
                    throw new IllegalStateException("Batch transcription returned "
                            + (results == null ? "no" : String.valueOf(results.size()))
                            + " results for " + files.size() + " audios");
                }
                for (int i = 0; i < results.size(); i++) {
                    transcribeRequests.get(i).result = results.get(i);
                }
            } catch (Exception e) {// case fail, try each audio individually
                batchTrancribe = false;
                logger.error("Error while doing batch transcribe " + e.toString(), e);
            }
        }
        if (!batchTrancribe) {// try each audio individually
            for (int i = 0; i < files.size(); i++) {
                TranscribeRequest req = transcribeRequests.get(i);
                try {
                    req.result = task.transcribeAudio(files.get(i));
                } catch (Exception e2) {
                    req.result = null;
                    req.error = e2;
                    // Pass the exception so the cause is not lost from the log.
                    logger.error("Error while transcribing", e2);
                }
            }
        }
    } finally {
        // Wake the waiters even if something unexpected escapes the block above.
        notifyTranscribeRequests(transcribeRequests);
    }

    long t3 = System.currentTimeMillis();
    transcriptionTime.addAndGet(t3 - t2);
}

/** Wakes every thread waiting on the monitor of each of the given requests. */
private static void notifyTranscribeRequests(List<TranscribeRequest> requests) {
    for (TranscribeRequest req : requests) {
        synchronized (req) {
            req.notifyAll();
        }
    }
}

private static void registerThis(String discoveryIp, int discoveryPort, int localPort, int concurrentJobs, int concurrentWavConvs) throws Exception {
try (Socket client = new Socket(discoveryIp, discoveryPort);
Expand Down Expand Up @@ -312,29 +387,24 @@ public void run() {

logger.info(prefix + "Accepted connection.");

int min = Math.min(MESSAGES.AUDIO_SIZE.toString().length(),
MESSAGES.VERSION_1_1.toString().length());

bis.mark(min + 1);
byte[] bytes = bis.readNBytes(min);
String cmd = new String(bytes);
if (!MESSAGES.AUDIO_SIZE.toString().startsWith(cmd)) {
bis.reset();
bytes = bis.readNBytes(MESSAGES.VERSION_1_1.toString().length());
protocol = new String(bytes);
bis.mark(min + 1);
synchronized (beaconQueq) {
opc = new OpenConnectons(client, bis, writer, this);
beaconQueq.add(opc);
}

byte[] bytes = bis.readNBytes(MESSAGES.VERSION_1_2.toString().length());
protocol = new String(bytes);
synchronized (beaconQueq) {
opc = new OpenConnectons(client, bis, writer, this);
beaconQueq.add(opc);
}


logger.info("Protocol Version {}", protocol);
if (protocol.compareTo(MESSAGES.VERSION_1_2.toString()) < 0) {
throw new Exception("Procol version " + protocol + " not supported");
}

// read the audio_size message
bis.reset();
bytes = bis.readNBytes(MESSAGES.AUDIO_SIZE.toString().length());
cmd = new String(bytes);
String cmd = new String(bytes);

if (!MESSAGES.AUDIO_SIZE.toString().equals(cmd)) {
error = true;
Expand All @@ -343,11 +413,8 @@ public void run() {

DataInputStream dis = new DataInputStream(bis);
long size;
if (protocol.compareTo(MESSAGES.VERSION_1_2.toString()) >= 0) {
size = dis.readLong();
} else {
size = dis.readInt();
}
size = dis.readLong();

if (size < 0) {
error = true;
try {
Expand Down Expand Up @@ -408,17 +475,22 @@ public void run() {
}
long durationMillis = 1000 * wavFile.length() / (16000 * 2);

TextAndScore result;
TextAndScore result=null;
long t2, t3;
try {
transcriptSemaphore.acquire();
if (executor.isShutdown()) {
error = true;
throw new Exception("Shutting down service instance...");
TranscribeRequest req=new TranscribeRequest(wavFile);
synchronized (toTranscribe) {
toTranscribe.add(req);
}
synchronized(req) {
req.wait();
}
t2 = System.currentTimeMillis();
result = task.transcribeAudio(wavFile);
t3 = System.currentTimeMillis();
result=req.result;
if(result==null) {
error = false;
throw new Exception("Error processing the audio", req.error);
}

} catch (ProcessCrashedException e) {
// retry audio
error = true;
Expand All @@ -429,14 +501,12 @@ public void run() {
executor.shutdown();
server.close();
throw e;
} finally {
transcriptSemaphore.release();
}
}

audiosTranscripted.incrementAndGet();
audiosDuration.addAndGet(durationMillis);
conversionTime.addAndGet(t1 - t0);
transcriptionTime.addAndGet(t3 - t2);

logger.info(prefix + "Transcritpion done.");

// removes from the beacon queue to prevent beacons in the middle of the
Expand All @@ -453,17 +523,8 @@ public void run() {
String errorMsg = "Exception while transcribing";
logger.warn(errorMsg, e);
if (writer != null) {
if (e.getMessage() != null && e.getMessage().startsWith("Invalid file size:")
&& protocol.compareTo(MESSAGES.VERSION_1_2.toString()) < 0) {
writer.println("0");
writer.println(
"Audios longer than 2GB are not supported by old clients, please update your client version!");
writer.println(MESSAGES.DONE);
} else {
writer.println(error ? MESSAGES.ERROR : MESSAGES.WARN);
writer.println(
errorMsg + ": " + e.toString().replace('\n', ' ').replace('\r', ' '));
}
writer.println(error ? MESSAGES.ERROR : MESSAGES.WARN);
writer.println(errorMsg + ": " + e.toString().replace('\n', ' ').replace('\r', ' '));
writer.flush();
}
} finally {
Expand Down Expand Up @@ -508,5 +569,39 @@ public void run() {
}
});
}

/**
 * Submits to {@code executor} a worker that keeps polling {@code toTranscribe}
 * and, whenever it is not empty, runs {@code transcribeAudios(task)} while
 * holding one permit of {@code transcriptSemaphore}.
 *
 * The worker sleeps 100 ms between polls of an empty queue and after a failed
 * {@code transcribeAudios} call, so it does not spin while the call keeps
 * failing. It stops when its thread is interrupted.
 *
 * @param task transcription backend passed on to {@code transcribeAudios}
 */
private static void startTrancribeThreads(AbstractTranscriptTask task) {
    executor.execute(new Runnable() {
        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                boolean empty;
                synchronized (toTranscribe) {
                    empty = toTranscribe.isEmpty();
                }
                if (empty) {
                    if (!pause()) {
                        return;
                    }
                    continue;
                }

                boolean acquired = false;
                boolean failed = false;
                try {
                    transcriptSemaphore.acquire();
                    acquired = true;
                    transcribeAudios(task);
                } catch (InterruptedException e) {
                    // Restore the flag and stop; the finally block still releases the permit.
                    Thread.currentThread().interrupt();
                    return;
                } catch (Exception e) {
                    // NOTE(review): consider routing this through the class logger.
                    e.printStackTrace();
                    failed = true;
                } finally {
                    // Release only a permit that was really acquired; releasing after
                    // a failed acquire() would add an extra permit to the semaphore.
                    if (acquired) {
                        transcriptSemaphore.release();
                    }
                }

                if (failed && !pause()) {
                    return;
                }
            }
        }

        /**
         * Sleeps for the polling interval.
         *
         * @return {@code false} if the thread was interrupted while sleeping
         */
        private boolean pause() {
            try {
                Thread.sleep(100);
                return true;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
    });
}


}
Loading
Loading