Skip to content

Commit

Permalink
Merge branch '#1823_whisper_transcription'
Browse files Browse the repository at this point in the history
  • Loading branch information
lfcnassif committed May 25, 2024
2 parents 2d3bcf8 + f8b3f5f commit 982622c
Show file tree
Hide file tree
Showing 11 changed files with 697 additions and 349 deletions.
2 changes: 1 addition & 1 deletion iped-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
<artifactItem>
<groupId>org.python</groupId>
<artifactId>python-jep-dlib</artifactId>
<version>3.9.12-4.0.3-19.23.1</version>
<version>3.9.12-4.0.3-19.23.1-2</version>
<type>zip</type>
<overWrite>false</overWrite>
<outputDirectory>${release.dir}</outputDirectory>
Expand Down
10 changes: 6 additions & 4 deletions iped-app/resources/config/IPEDConfig.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ enableMinIO = false
enableOCR = false

# Enable audio transcription.
# Default implementation uses VOSK transcription on local CPU (slow and not good accuracy).
# You can change it to a local Facebook Wav2Vec2 implementation (slower on CPU, faster on GPU and good accuracy)
# or remote Microsoft Azure or Google Cloud services (faster and good accuracy).
# Configure it in conf/AudioTranscriptConfig.txt
# Default implementation uses VOSK transcription on local CPU (faster, but with worse accuracy).
# You can change the algorithm in conf/AudioTranscriptConfig.txt:
# - Wav2Vec2 algorithm (slower and good accuracy)
# - Whisper algorithm (much slower but better accuracy)
# - Google Cloud (about $1.00 per hour cost)
# - Microsoft Azure (about $1.00 per hour cost)
enableAudioTranscription = false

# Enables carving. "addUnallocated" must be enabled to scan unallocated space.
Expand Down
48 changes: 39 additions & 9 deletions iped-app/resources/config/conf/AudioTranscriptConfig.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,23 @@
# you should download it from https://alphacephei.com/vosk/models and put in 'models/vosk/[lang]' folder.
implementationClass = iped.engine.task.transcript.VoskTranscriptTask

# Uses a local/remote wav2vec2 implementation for transcription. Accuracy is much better than most Vosk models.
# The local impl is AT LEAST 1 order of magnitude slower than Vosk on high end CPUs. Using a good GPU is highly recommended!
# The remote impl is useful if you have a central server/cluster with many GPUs to be shared among processing nodes.
# For both the local or remote options, please check the installation steps: https://github.com/sepinf-inc/IPED/wiki/User-Manual#wav2vec2
# If you use the local implementation, you must set 'huggingFaceModel' param below.
# If you use the remote implementation, you must set 'wav2vec2Service' param below.
# Uses a local wav2vec2 implementation for transcription. Accuracy is much better than most Vosk models.
# This is up to 10x slower than Vosk on high end CPUs. Using a good GPU is highly recommended!
# Please check the installation steps: https://github.com/sepinf-inc/IPED/wiki/User-Manual#wav2vec2
# If you enable this, you must set 'huggingFaceModel' param below.
#implementationClass = iped.engine.task.transcript.Wav2Vec2TranscriptTask
#implementationClass = iped.engine.task.transcript.RemoteWav2Vec2TranscriptTask

# Uses a local Whisper implementation for transcription. Accuracy is better than wav2vec2 depending on the model.
# This is up to 4x slower than wav2vec2 depending on compared models. Using a high end GPU is strongly recommended!
# Please check the installation steps: https://github.com/sepinf-inc/IPED/wiki/User-Manual#whisper
# If you enable this, you must set 'whisperModel' param below.
#implementationClass = iped.engine.task.transcript.WhisperTranscriptTask

# Uses a remote service for transcription.
# The remote service is useful if you have a central server/cluster with many GPUs to be shared among processing nodes.
# Please check steps on https://github.com/sepinf-inc/IPED/wiki/User-Manual#remote-transcription
# If you enable this, you must set 'remoteServiceAddress' param below.
#implementationClass = iped.engine.task.transcript.RemoteTranscriptionTask

# If you want to use the Microsoft Azure service implementation, comment above and uncomment below.
# You MUST include Microsoft client-sdk.jar into plugins folder.
Expand Down Expand Up @@ -91,11 +100,32 @@ minWordScore = 0.5
# huggingFaceModel = jonatasgrosman/wav2vec2-xls-r-1b-french

#########################################
# RemoteWav2Vec2TranscriptTask options
# Local WhisperTranscriptTask options
#########################################

# Possible values: tiny, base, small, medium, large-v3, dwhoelz/whisper-large-pt-cv11-ct2
# large-v3 is much better than medium, but 2x slower and uses 2x more memory.
# If you know the language you want to transcribe, please set the 'language' option above.
# 'language = auto' uses the 'locale' set in LocalConfig.txt
# 'language = detect' uses automatic language detection, which may cause mistakes
whisperModel = medium

# Compute type precision. This affects accuracy, speed and memory usage.
# Possible values: float32 (better), float16 (recommended for GPU), int8 (faster)
precision = int8

# Batch size (number of parallel transcriptions). If you have a GPU with enough memory,
# increasing this value to e.g. 16 can speed up transcribing long audios up to 10x.
# Test which value works best for your GPU before hitting OOM.
# This only works if you are using the whisperx library instead of faster_whisper.
batchSize = 1

#########################################
# RemoteTranscriptionTask options
#########################################

# IP:PORT of the service/central node used by the RemoteTranscriptionTask implementation.
# wav2vec2Service = 127.0.0.1:11111
# remoteServiceAddress = 127.0.0.1:11111

#########################################
# MicrosoftTranscriptTask options
Expand Down
112 changes: 112 additions & 0 deletions iped-app/resources/scripts/tasks/WhisperProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import sys
import numpy
stdout = sys.stdout
sys.stdout = sys.stderr

terminate = 'terminate_process'
model_loaded = 'model_loaded'
library_loaded = 'library_loaded'
finished = 'transcription_finished'
ping = 'ping'

def main():

modelName = sys.argv[1]
deviceNum = int(sys.argv[2])
threads = int(sys.argv[3])
language = sys.argv[4]
compute_type = sys.argv[5]
batch_size = int(sys.argv[6])

if language == 'detect':
language = None

try:
import whisperx
whisperx_found = True
except:
import faster_whisper
whisperx_found = False

print(library_loaded, file=stdout, flush=True)

import GPUtil
cudaCount = len(GPUtil.getGPUs())

print(str(cudaCount), file=stdout, flush=True)

if cudaCount > 0:
deviceId = 'cuda'
else:
deviceId = 'cpu'
deviceNum = 0

try:
if whisperx_found:
model = whisperx.load_model(modelName, device=deviceId, device_index=deviceNum, threads=threads, compute_type=compute_type, language=language)
else:
model = faster_whisper.WhisperModel(modelName, device=deviceId, device_index=deviceNum, cpu_threads=threads, compute_type=compute_type)

except Exception as e:
if deviceId != 'cpu':
# loading on GPU failed (OOM?), try on CPU
print('FAILED to load model on GPU, OOM? Fallbacking to CPU...', file=sys.stderr)
deviceId = 'cpu'
if compute_type == 'float16': # not supported on CPU
compute_type = 'int8'
if whisperx_found:
model = whisperx.load_model(modelName, device=deviceId, device_index=deviceNum, threads=threads, compute_type=compute_type, language=language)
else:
model = faster_whisper.WhisperModel(modelName, device=deviceId, cpu_threads=threads, compute_type=compute_type)
else:
raise e

print(model_loaded, file=stdout, flush=True)
print(deviceId, file=stdout, flush=True)

while True:

line = input()

if line == terminate:
break
if line == ping:
print(ping, file=stdout, flush=True)
continue

transcription = ''
logprobs = []
try:
if whisperx_found:
audio = whisperx.load_audio(line)
result = model.transcribe(audio, batch_size=batch_size, language=language)
for segment in result['segments']:
transcription += segment['text']
if 'avg_logprob' in segment:
logprobs.append(segment['avg_logprob'])
else:
segments, info = model.transcribe(audio=line, language=language, beam_size=5, vad_filter=True)
for segment in segments:
transcription += segment.text
logprobs.append(segment.avg_logprob)

except Exception as e:
msg = repr(e).replace('\n', ' ').replace('\r', ' ')
print(msg, file=stdout, flush=True)
continue

text = transcription.replace('\n', ' ').replace('\r', ' ')

if len(logprobs) == 0:
finalScore = 0
else:
finalScore = numpy.mean(numpy.exp(logprobs))

print(finished, file=stdout, flush=True)
print(str(finalScore), file=stdout, flush=True)
print(text, file=stdout, flush=True)

return

if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ public class AudioTranscriptConfig extends AbstractTaskPropertiesConfig {
private static final String MAX_REQUESTS_KEY = "maxConcurrentRequests";
private static final String MIN_WORD_SCORE = "minWordScore";
public static final String HUGGING_FACE_MODEL = "huggingFaceModel";
public static final String WHISPER_MODEL = "whisperModel";
public static final String WAV2VEC2_SERVICE = "wav2vec2Service";
public static final String REMOTE_SERVICE = "remoteServiceAddress";
private static final String GOOGLE_MODEL = "googleModel";
private static final String LANG_AUTO_VAL = "auto";
private static final String SKIP_KNOWN_FILES = "skipKnownFiles";
private static final String PRECISION = "precision";
private static final String BATCH_SIZE = "batchSize";

private List<String> languages = new ArrayList<>();
private List<String> mimesToProcess = new ArrayList<>();
Expand All @@ -43,9 +47,20 @@ public class AudioTranscriptConfig extends AbstractTaskPropertiesConfig {
private int maxConcurrentRequests;
private float minWordScore = 0.7f;
private String huggingFaceModel;
private String wav2vec2Service;
private String whisperModel;
private String remoteService;
private String googleModel;
private boolean skipKnownFiles = true;
private String precision = "int8";
private int batchSize = 1;

public String getPrecision() {
return precision;
}

public int getBatchSize() {
return batchSize;
}

public boolean getSkipKnownFiles() {
return this.skipKnownFiles;
Expand Down Expand Up @@ -109,8 +124,12 @@ public String getHuggingFaceModel() {
return huggingFaceModel;
}

public String getWav2vec2Service() {
return wav2vec2Service;
public String getWhisperModel() {
return whisperModel;
}

public String getRemoteService() {
return remoteService;
}

public String getGoogleModel() {
Expand Down Expand Up @@ -144,9 +163,17 @@ public void processProperties(UTF8Properties properties) {
if (huggingFaceModel != null) {
huggingFaceModel = huggingFaceModel.trim();
}
wav2vec2Service = properties.getProperty(WAV2VEC2_SERVICE);
if (wav2vec2Service != null) {
wav2vec2Service = wav2vec2Service.trim();
whisperModel = properties.getProperty(WHISPER_MODEL);
if (whisperModel != null) {
whisperModel = whisperModel.strip();
}

remoteService = properties.getProperty(REMOTE_SERVICE);
if (remoteService == null) {
remoteService = properties.getProperty(WAV2VEC2_SERVICE);
}
if (remoteService != null) {
remoteService = remoteService.trim();
}
googleModel = properties.getProperty(GOOGLE_MODEL);
if (googleModel != null) {
Expand All @@ -165,6 +192,14 @@ public void processProperties(UTF8Properties properties) {
if (value != null) {
timeoutPerSec = Integer.valueOf(value.trim());
}
value = properties.getProperty(PRECISION);
if (value != null) {
precision = value.trim();
}
value = properties.getProperty(BATCH_SIZE);
if (value != null) {
batchSize = Integer.parseInt(value.trim());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

import iped.engine.task.transcript.RemoteWav2Vec2Service.MESSAGES;
import iped.engine.task.transcript.RemoteTranscriptionService.MESSAGES;

public class RemoteWav2Vec2Discovery {
public class RemoteTranscriptionDiscovery {

private static final File statsFile = new File(System.getProperty("user.home"), "transcription.stats");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import iped.io.URLUtil;
import iped.utils.IOUtil;

public class RemoteWav2Vec2Service {
public class RemoteTranscriptionService {

static enum MESSAGES {
ACCEPTED,
Expand Down Expand Up @@ -131,22 +131,22 @@ public static void main(String[] args) throws Exception {
printHelpAndExit();
}

File jar = new File(URLUtil.getURL(RemoteWav2Vec2Service.class).toURI());
File jar = new File(URLUtil.getURL(RemoteTranscriptionService.class).toURI());
File root = jar.getParentFile().getParentFile();

System.setProperty("org.apache.logging.log4j.level", "INFO");
logger = LoggerFactory.getLogger(RemoteWav2Vec2Service.class);
logger = LoggerFactory.getLogger(RemoteTranscriptionService.class);

Configuration.getInstance().loadConfigurables(root.getAbsolutePath());
ConfigurationManager cm = ConfigurationManager.get();
AudioTranscriptConfig audioConfig = new AudioTranscriptConfig();
LocalConfig localConfig = new LocalConfig();
cm.addObject(audioConfig);
cm.addObject(localConfig);
cm.loadConfig(audioConfig);
cm.loadConfig(localConfig);
cm.loadConfig(audioConfig);

Wav2Vec2TranscriptTask task = new Wav2Vec2TranscriptTask();
AbstractTranscriptTask task = (AbstractTranscriptTask) Class.forName(audioConfig.getClassName()).getDeclaredConstructor().newInstance();
audioConfig.setEnabled(true);
task.init(cm);

Expand Down Expand Up @@ -261,7 +261,7 @@ private static void removeFrombeaconQueq(OpenConnectons opc) {
}
}

private static void waitRequests(ServerSocket server, Wav2Vec2TranscriptTask task, String discoveryIp) {
private static void waitRequests(ServerSocket server, AbstractTranscriptTask task, String discoveryIp) {
AtomicInteger jobs = new AtomicInteger();
while (true) {
try {
Expand Down
Loading

0 comments on commit 982622c

Please sign in to comment.