Skip to content

Commit

Permalink
'#1539: code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
lfcnassif committed Oct 8, 2024
1 parent 97e8d38 commit c51db60
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 87 deletions.
4 changes: 2 additions & 2 deletions iped-app/resources/scripts/tasks/WhisperProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main():
print(ping, file=stdout, flush=True)
continue

files=line.split(",")
files = line.split(",")
transcription = []
logprobs = []
for file in files:
Expand All @@ -84,7 +84,7 @@ def main():
if whisperx_found:
result = model.transcribe(files, batch_size=batch_size, language=language,wav=True)
for segment in result['segments']:
idx=segment["audio"]
idx = segment["audio"]
transcription[idx] += segment['text']
if 'avg_logprob' in segment:
logprobs[idx].append(segment['avg_logprob'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public abstract class AbstractTranscriptTask extends AbstractTask {
private static final int MAX_WAV_SIZE = 16000 * 2 * MAX_WAV_TIME;

protected AudioTranscriptConfig transcriptConfig;

// Variables to store some statistics
private static final AtomicLong wavTime = new AtomicLong();
private static final AtomicLong transcriptionTime = new AtomicLong();
Expand All @@ -91,8 +91,7 @@ public boolean isEnabled() {

protected boolean isToProcess(IItem evidence) {

if (evidence.getLength() == null || evidence.getLength() == 0 || !evidence.isToAddToCase()
|| evidence.getMetadata().get(ExtraProperties.TRANSCRIPT_ATTR) != null) {
if (evidence.getLength() == null || evidence.getLength() == 0 || !evidence.isToAddToCase() || evidence.getMetadata().get(ExtraProperties.TRANSCRIPT_ATTR) != null) {
return false;
}
if (transcriptConfig.getSkipKnownFiles() && evidence.getExtraAttribute(HashDBLookupTask.STATUS_ATTRIBUTE) != null) {
Expand Down Expand Up @@ -192,8 +191,7 @@ public void init(ConfigurationManager configurationManager) throws Exception {

}

public static TextAndScore transcribeWavBreaking(File tmpFile, String itemPath,
Function<File, TextAndScore> transcribeWavPart) throws Exception {
public static TextAndScore transcribeWavBreaking(File tmpFile, String itemPath, Function<File, TextAndScore> transcribeWavPart) throws Exception {
if (tmpFile.length() <= MAX_WAV_SIZE) {
return transcribeWavPart.apply(tmpFile);
} else {
Expand Down Expand Up @@ -316,7 +314,7 @@ public void finish() throws Exception {
conn.close();
conn = null;
}

long totWavConversions = wavSuccess.longValue() + wavFail.longValue();
if (totWavConversions != 0) {
LOGGER.info("Total conversions to WAV: " + totWavConversions);
Expand All @@ -340,8 +338,7 @@ public void finish() throws Exception {
}
}

protected File getTempFileToTranscript(IItem evidence, TemporaryResources tmp)
throws IOException, InterruptedException {
protected File getTempFileToTranscript(IItem evidence, TemporaryResources tmp) throws IOException, InterruptedException {
long t = System.currentTimeMillis();
File tempWav = null;
try {
Expand Down Expand Up @@ -373,8 +370,7 @@ protected void process(IItem evidence) throws Exception {
return;
}

if (evidence.getMetadata().get(ExtraProperties.TRANSCRIPT_ATTR) != null
&& evidence.getMetadata().get(ExtraProperties.CONFIDENCE_ATTR) != null)
if (evidence.getMetadata().get(ExtraProperties.TRANSCRIPT_ATTR) != null && evidence.getMetadata().get(ExtraProperties.CONFIDENCE_ATTR) != null)
return;

TextAndScore prevResult = getTextFromDb(evidence.getHash());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,38 +41,28 @@ public class RemoteTranscriptionService {
// 30 minutes
private static final int MAX_WAV_TIME = 30 * 60;
private static final int MAX_WAV_SIZE = 16000 * 2 * MAX_WAV_TIME;

static enum MESSAGES {
ACCEPTED,
AUDIO_SIZE,
BUSY,
DISCOVER,
DONE,
ERROR,
REGISTER,
STATS,
WARN, VERSION_1_1,
VERSION_1_2,
VERSION_1_0,
PING
ACCEPTED, AUDIO_SIZE, BUSY, DISCOVER, DONE, ERROR, REGISTER, STATS, WARN, VERSION_1_1, VERSION_1_2, VERSION_1_0, PING
}

static class TranscribeRequest {
File wavAudio;
TextAndScore result=null;
Exception error=null;

TextAndScore result = null;
Exception error = null;

public TranscribeRequest(File wavAudio) {
this.wavAudio=wavAudio;
this.wavAudio = wavAudio;
}
}

static class OpenConnectons {
Socket conn;
BufferedInputStream bis;
PrintWriter writer;
Thread t;
File wavAudio;
TextAndScore result=null;

TextAndScore result = null;

public OpenConnectons(Socket conn, BufferedInputStream bis, PrintWriter writer, Thread t) {
this.conn = conn;
Expand Down Expand Up @@ -108,8 +98,8 @@ public void sendBeacon() {
* Control number of simultaneous audio conversions to WAV.
*/
private static Semaphore wavConvSemaphore;
private static int BATCH_SIZE=1;

private static int BATCH_SIZE = 1;

private static final AtomicLong audiosTranscripted = new AtomicLong();
private static final AtomicLong audiosDuration = new AtomicLong();
Expand All @@ -119,16 +109,11 @@ public void sendBeacon() {
private static final AtomicLong requestsAccepted = new AtomicLong();
private static final List<OpenConnectons> beaconQueq = new LinkedList<>();
private static final Deque<TranscribeRequest> toTranscribe = new LinkedList<>();


private static Logger logger;

private static void printHelpAndExit() {
System.out.println(
"Params: IP:Port [LocalPort]\n"
+ "IP:Port IP and port of the naming node.\n"
+ "LocalPort [optional] local port to listen for connections.\n"
+ " If not provided, a random port will be used.");
System.out.println("Params: IP:Port [LocalPort]\n" + "IP:Port IP and port of the naming node.\n" + "LocalPort [optional] local port to listen for connections.\n" + " If not provided, a random port will be used.");
System.exit(1);
}

Expand Down Expand Up @@ -169,7 +154,7 @@ public static void main(String[] args) throws Exception {
AbstractTranscriptTask task = (AbstractTranscriptTask) Class.forName(audioConfig.getClassName()).getDeclaredConstructor().newInstance();
audioConfig.setEnabled(true);
task.init(cm);
BATCH_SIZE=audioConfig.getBatchSize();
BATCH_SIZE = audioConfig.getBatchSize();
int numConcurrentTranscriptions = Wav2Vec2TranscriptTask.getNumConcurrentTranscriptions();
int numLogicalCores = Runtime.getRuntime().availableProcessors();

Expand All @@ -195,10 +180,9 @@ public static void main(String[] args) throws Exception {
startSendStatsThread(discoveryIp, discoveryPort, localPort, numConcurrentTranscriptions, numLogicalCores);

startBeaconThread();
for(int i=0;i<numConcurrentTranscriptions;i++) {
for (int i = 0; i < numConcurrentTranscriptions; i++) {
startTrancribeThreads(task);
}


waitRequests(server, task, discoveryIp);

Expand All @@ -219,7 +203,7 @@ public void run() {
Thread.sleep(60000);
logger.info("Send beacons to {} clients", beaconQueq.size());
synchronized (beaconQueq) {
for( var cliente:beaconQueq) {
for (var cliente : beaconQueq) {
cliente.sendBeacon();
}
}
Expand Down Expand Up @@ -291,8 +275,7 @@ private static void registerThis(String discoveryIp, int discoveryPort, int loca
try (Socket client = new Socket(discoveryIp, discoveryPort);
InputStream is = client.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8));
PrintWriter writer = new PrintWriter(
new OutputStreamWriter(client.getOutputStream(), StandardCharsets.UTF_8), true)) {
PrintWriter writer = new PrintWriter(new OutputStreamWriter(client.getOutputStream(), StandardCharsets.UTF_8), true)) {

client.setSoTimeout(10000);
writer.println(MESSAGES.REGISTER);
Expand Down Expand Up @@ -371,11 +354,9 @@ public void run() {
try {
client.setSoTimeout(CLIENT_TIMEOUT_MILLIS);
bis = new BufferedInputStream(client.getInputStream());
writer = new PrintWriter(
new OutputStreamWriter(client.getOutputStream(), StandardCharsets.UTF_8), true);
writer = new PrintWriter(new OutputStreamWriter(client.getOutputStream(), StandardCharsets.UTF_8), true);

String clientName = "Client " + client.getInetAddress().getHostAddress() + ":"
+ client.getPort();
String clientName = "Client " + client.getInetAddress().getHostAddress() + ":" + client.getPort();
String prefix = clientName + " - ";
writer.println(MESSAGES.ACCEPTED);

Expand All @@ -391,16 +372,13 @@ public void run() {

logger.info(prefix + "Accepted connection.");



byte[] bytes = bis.readNBytes(MESSAGES.VERSION_1_2.toString().length());
protocol = new String(bytes);
synchronized (beaconQueq) {
opc = new OpenConnectons(client, bis, writer, this);
beaconQueq.add(opc);
}


logger.info("Protocol Version {}", protocol);
if (protocol.compareTo(MESSAGES.VERSION_1_2.toString()) < 0) {
throw new Exception("Procol version " + protocol + " not supported");
Expand Down Expand Up @@ -480,19 +458,18 @@ public void run() {
long durationMillis = 1000 * wavFile.length() / (16000 * 2);

TextAndScore result = new TextAndScore();
result.text="";
result.score=0;
result.text = "";
result.score = 0;
try {
reqs = new ArrayList<TranscribeRequest>();
TranscribeRequest last=null;
TranscribeRequest last = null;
if (wavFile.length() <= MAX_WAV_SIZE) {
TranscribeRequest req = new TranscribeRequest(wavFile);
reqs.add(req);

} else {

for (File wavPart : AbstractTranscriptTask.getAudioSplits(wavFile,
wavFile.getPath(), MAX_WAV_TIME)) {
for (File wavPart : AbstractTranscriptTask.getAudioSplits(wavFile, wavFile.getPath(), MAX_WAV_TIME)) {
TranscribeRequest req = new TranscribeRequest(wavPart);
reqs.add(req);
}
Expand All @@ -501,15 +478,15 @@ public void run() {

}
wavFile = null;

// dispatch all parts to be executed
for (TranscribeRequest req : reqs) {
synchronized (toTranscribe) {
toTranscribe.add(req);
}
last=req;
last = req;
}

// wait until the last wav part is transcribed
synchronized (last) {
last.wait();
Expand All @@ -521,7 +498,7 @@ public void run() {
error = false;
throw new Exception("Error processing the audio", req.error);
}

if (result.score > 0)
result.text += " ";
result.text += partResult.text;
Expand All @@ -530,7 +507,6 @@ public void run() {

}
result.score /= reqs.size();


} catch (ProcessCrashedException e) {
// retry audio
Expand All @@ -542,12 +518,12 @@ public void run() {
executor.shutdown();
server.close();
throw e;
}
}

audiosTranscripted.incrementAndGet();
audiosDuration.addAndGet(durationMillis);
conversionTime.addAndGet(t1 - t0);

logger.info(prefix + "Transcritpion done.");

// removes from the beacon queue to prevent beacons in the middle of the
Expand Down Expand Up @@ -617,7 +593,7 @@ public void run() {
}
});
}

private static void startTrancribeThreads(AbstractTranscriptTask task) {
executor.execute(new Runnable() {
@Override
Expand All @@ -628,12 +604,12 @@ public void run() {
empty = toTranscribe.isEmpty();
}
if (empty) {
try {
try {
Thread.sleep(100);

} catch (Exception e) {
// TODO: handle exception
}
}
continue;
}
try {
Expand All @@ -646,10 +622,9 @@ public void run() {
transcriptSemaphore.release();
}

}
}
}
});
}


}
Loading

0 comments on commit c51db60

Please sign in to comment.