diff --git a/runtime/server/diarization_gpu/Dockerfile/dockerfile.client b/runtime/server/diarization_gpu/Dockerfile/dockerfile.client
new file mode 100644
index 00000000..a7f8219d
--- /dev/null
+++ b/runtime/server/diarization_gpu/Dockerfile/dockerfile.client
@@ -0,0 +1,33 @@
+###################################################################################################
+#
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice, this list of
+# conditions and the following disclaimer in the documentation and/or other materials
+# provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+# to endorse or promote products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+###################################################################################################
+
+FROM nvcr.io/nvidia/tritonserver:22.07-py3-sdk
+LABEL maintainer="NVIDIA"
+LABEL repository="tritonserver"
+
+RUN apt-get update && apt-get install -y libsndfile1
+RUN pip3 install soundfile kaldiio
+WORKDIR /workspace
diff --git a/runtime/server/diarization_gpu/Dockerfile/dockerfile.server b/runtime/server/diarization_gpu/Dockerfile/dockerfile.server
new file mode 100644
index 00000000..510593c6
--- /dev/null
+++ b/runtime/server/diarization_gpu/Dockerfile/dockerfile.server
@@ -0,0 +1,38 @@
+###################################################################################################
+#
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice, this list of
+# conditions and the following disclaimer in the documentation and/or other materials
+# provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+# to endorse or promote products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+###################################################################################################
+FROM nvcr.io/nvidia/tritonserver:22.07-py3
+LABEL maintainer="NVIDIA"
+LABEL repository="tritonserver"
+
+RUN apt-get update && apt-get -y install swig && apt-get -y install python3-dev && apt-get install -y cmake
+RUN pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+RUN pip3 install -v kaldifeat
+RUN python3 -m pip install cupy
+RUN python3 -m pip install soundfile
+RUN pip install cudf-cu11 dask-cudf-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
+RUN pip install cuml-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
+RUN pip install cugraph-cu11 --extra-index-url=https://pypi.ngc.nvidia.com
+WORKDIR /workspace
diff --git a/runtime/server/diarization_gpu/README.md b/runtime/server/diarization_gpu/README.md
new file mode 100644
index 00000000..fefd14ff
--- /dev/null
+++ b/runtime/server/diarization_gpu/README.md
@@ -0,0 +1,179 @@
+# Best Practice for Deploying a WeSpeaker Diarization Service Using Triton
+
+In this best practice, we'll go through how to deploy a WeSpeaker diarization pipeline on GPU using the NVIDIA [Triton Inference Server](https://github.com/triton-inference-server/server). The pipeline contains several modules, including SAD, speaker embedding extraction, and clustering.
+
+We will use [Triton Business Logic Scripting](https://github.com/triton-inference-server/python_backend#business-logic-scripting) (BLS) to implement this pipeline.
+
+## Table of Contents
+
+- [Preparation](#preparation)
+  - [Prepare Environment](#prepare-environment)
+  - [Prepare Models](#prepare-models)
+  - [Prepare Test Data](#prepare-test-data)
+- [Triton Inference Server](#triton-inference-server)
+  - [Quick Start](#quick-start)
+  - [Business Logic Scripting](#business-logic-scripting)
+- [Inference Client](#inference-client)
+  - [Quick Start](#quick-start-1)
+  - [Compute Metrics](#compute-metrics)
+- [Benchmark](#benchmark-todo)
+
+
+## Preparation
+
+Let's prepare the environment, models, and data first.
+
+### Prepare Environment
+
+Clone the repository:
+
+```bash
+# Clone the WeSpeaker repo
+git clone https://github.com/wenet-e2e/wespeaker.git
+export WeSpeaker=$PWD/wespeaker/
+cd wespeaker/runtime/server/diarization_gpu
+export PROJECT_DIR=$PWD
+
+```
+
+### Prepare Models
+
+To deploy this pipeline, we first need to obtain the SAD and speaker models.
+
+#### Speaker Models
+
+You can refer to the [voxceleb sv recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxceleb/v2) to train a WeSpeaker model, or use a pre-trained model:
+
+```bash
+export SPK_MODEL_DIR=/workspace/pretrained_models
+mkdir -p ${SPK_MODEL_DIR}
+wget -c https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx -O ${SPK_MODEL_DIR}/voxceleb_resnet34_LM.onnx
+```
+
+Then you can follow the best practice of [GPU deployment](https://github.com/wenet-e2e/wespeaker/tree/master/runtime/server/x86_gpu) to deploy the WeSpeaker model in Triton.
+After that, the speaker models will be available in the `wespeaker/runtime/server/x86_gpu/model_repo/` directory.
+
+```bash
+export SPK_MODEL_REPO="wespeaker/runtime/server/x86_gpu/model_repo/"
+```
+
+#### SAD Models
+
+Speech activity detection model: system SAD (a VAD model pretrained by [silero](https://github.com/snakers4/silero-vad)).
+
+```bash
+export SAD_DIR=/workspace/SAD
+mkdir -p $SAD_DIR external_tools
+wget -c https://github.com/snakers4/silero-vad/archive/refs/tags/v3.1.zip -O external_tools/silero-vad-v3.1.zip
+unzip -o external_tools/silero-vad-v3.1.zip -d external_tools
+cp external_tools/silero-vad-3.1/files/silero_vad.jit $SAD_DIR/
+```
+
+### Prepare Test Data
+
+You can use the following command to get the evaluation data from VoxConverse:
+
+```bash
+bash $WeSpeaker/examples/voxconverse/v1/run.sh --stage 2 --stop_stage 2
+```
+
+If you are using your own data, you can either evaluate the audio files one by one or prepare a `wav.scp` file that lists them, for example:
+
+```
+abjxc abjxc.wav
+afjiv afjiv.wav
+```
+
+## Triton Inference Server
+
+[Triton Inference Server](https://github.com/triton-inference-server/server) handles most of the serving work for us: sending and receiving requests/results, request scheduling, load balancing, and inference execution. In this section, we will use Triton to deploy the diarization pipeline.
+
+![Pipeline](./bls.png)
+
+Build the server docker image:
+```
+docker build . -f Dockerfile/dockerfile.server -t wespeaker_server:latest --network host
+```
+
+Build the client docker image:
+```
+docker build . -f Dockerfile/dockerfile.client -t wespeaker_client:latest --network host
+```
+
+Run the following commands to put the pretrained speaker models into the current `model_repo` directory (the SAD model is mounted into the server container separately below):
+
+```bash
+cd ${PROJECT_DIR}
+mkdir -p model_repo/run/1
+cp -r $SPK_MODEL_REPO/* model_repo/
+
+```
+
+### Quick Start
+
+Now start the server:
+
+```bash
+# Start the docker container
+docker run --gpus all -v $PWD/model_repo:/workspace/model_repo -v $SAD_DIR:/workspace/triton/ --net host --shm-size=1g --ulimit memlock=-1 -p 8000:8000 -p 8001:8001 -p 8002:8002 --ulimit stack=67108864 -it wespeaker_server:latest

+# Inside the docker container
+tritonserver --model-repository=/workspace/model_repo
+
+```
+
+### Business Logic Scripting
+
+Business Logic Scripting (BLS) can execute inference requests on other models being served by Triton as part of executing one Python model. In this pipeline, the `run` Python model performs SAD and sub-segmentation, issues BLS requests to the `speaker` and `clusterer` models (see `model_repo/run/1/model.py`), and merges the returned labels into per-speaker segments.
+
+
+## Inference Client
+
+In this section, we will show how to send requests to the deployed diarization service and receive the RTTM results.
+
+
+### Quick Start
+
+Run:
+
+```bash
+# Set AUDIO_DATA to the directory that contains your test audio;
+# it will be mounted as /ws/test_data inside the container.
+AUDIO_DATA=
+docker run -ti --net host -v $PWD/client:/ws/client -v $AUDIO_DATA:/ws/test_data wespeaker_client:latest
+cd /ws/client
+```
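+
+If you prefer to call the service from your own code rather than using the provided
+`client.py`, the request/response contract is simple: send the whole (16 kHz) waveform
+to the `run` model as a `1 x num_samples` FP32 tensor named `input`, and read back
+`LABELS`, whose first batch entry is an `N x 3` array of
+`(segment_start_sec, segment_end_sec, speaker_label)` rows. Below is a minimal sketch
+mirroring what `client.py` does; the audio path and the printout are only illustrative.
+
+```python
+import numpy as np
+import soundfile
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import np_to_triton_dtype
+
+# Read a test file mounted under /ws/test_data and shape it as (1, num_samples)
+waveform, sample_rate = soundfile.read("/ws/test_data/abjxc.wav")
+audio = np.asarray(waveform, dtype=np.float32)[np.newaxis, :]
+
+with grpcclient.InferenceServerClient(url="localhost:8001") as client:
+    inputs = [grpcclient.InferInput("input", audio.shape,
+                                    np_to_triton_dtype(audio.dtype))]
+    inputs[0].set_data_from_numpy(audio)
+    outputs = [grpcclient.InferRequestedOutput("LABELS")]
+    response = client.infer("run", inputs, outputs=outputs)
+    # Each row of LABELS is (segment_start_sec, segment_end_sec, speaker_label)
+    for start, end, label in response.as_numpy("LABELS")[0]:
+        print(f"{start:.3f} {end:.3f} speaker_{int(label)}")
+```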
+
+Inside the docker container, run the client script to run the whole pipeline inference:
+
+```bash
+# Test one audio
+export output_directory="output"
+mkdir -p $output_directory
+python client.py --url=localhost:8001 --audio_file=/ws/test_data/abjxc.wav --output_directory=$output_directory
+cat $output_directory/rttm* > $output_directory/rttm
+```
+
+The above command sends the single audio file `abjxc.wav` to the server and gets the result back. The `--url` option specifies the IP and port of the server. In our example the server and client run on the same machine, so the IP is `localhost`, and we use port `8001` because it is Triton's default gRPC port. If your client is not on the same machine as the server, change this option accordingly.
+
+You can also pass a `wav.scp` with the `--wavscp` option; the client will then process every audio file listed in it.
+
+```bash
+# Test a bunch of audios
+export wav_scp_dir=/ws/test_data
+python client.py --url=localhost:8001 --wavscp=$wav_scp_dir/wav.scp --output_directory=$output_directory
+cat $output_directory/rttm* > $output_directory/rttm
+```
+
+Finally, you can find the RTTM results in `$output_directory/rttm`.
+
+### Compute Metrics
+
+If you want to evaluate the performance of the diarization pipeline, you can run the following command (`external_tools` and `data` are prepared by the [voxconverse recipe](https://github.com/wenet-e2e/wespeaker/tree/master/examples/voxconverse/v1)):
+
+```bash
+perl external_tools/SCTK-2.4.12/src/md-eval/md-eval.pl \
+    -c 0.25 \
+    -r <(cat data/voxconverse-master/dev/*.rttm) \
+    -s $output_directory/rttm
+```
+
+## Benchmark (TODO)
+
diff --git a/runtime/server/diarization_gpu/bls.png b/runtime/server/diarization_gpu/bls.png
new file mode 100644
index 00000000..3dd15d49
Binary files /dev/null and b/runtime/server/diarization_gpu/bls.png differ
diff --git a/runtime/server/diarization_gpu/client/client.py b/runtime/server/diarization_gpu/client/client.py
new file mode 100644
index 00000000..24723e32
--- /dev/null
+++ b/runtime/server/diarization_gpu/client/client.py
@@ -0,0 +1,154 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+from multiprocessing import Pool
+
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import np_to_triton_dtype
+import numpy as np
+import soundfile
+import argparse
+import os
+
+
+class SpeakerClient(object):
+    def __init__(self, triton_client, model_name, protocol_client):
+        self.triton_client = triton_client
+        self.protocol_client = protocol_client
+        self.model_name = model_name
+
+    def recognize(self, wav_path, client_index):
+        # We send batch_size=1 requests to the server. batch_size > 1 also
+        # works, but then you need to take care of padding.
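+        # The deployed "run" model expects a (1, num_samples) FP32 tensor named
+        # "input" and returns "LABELS" with shape (1, num_segments, 3), where
+        # each row is (segment_start_sec, segment_end_sec, speaker_label).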
+        waveform, sample_rate = soundfile.read(wav_path)
+        cur_length = len(waveform)
+        input = np.zeros((1, cur_length), dtype=np.float32)
+        input[0][0:cur_length] = waveform[0:cur_length]
+        inputs = [self.protocol_client.InferInput("input", input.shape,
+                                                  np_to_triton_dtype(input.dtype))]
+        inputs[0].set_data_from_numpy(input)
+        outputs = [grpcclient.InferRequestedOutput("LABELS")]
+        response = self.triton_client.infer(self.model_name,
+                                            inputs,
+                                            request_id=str(client_index),
+                                            outputs=outputs)
+        result = response.as_numpy("LABELS")[0]
+        return [result]
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-v',
+                        '--verbose',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Enable verbose output')
+    parser.add_argument('-u',
+                        '--url',
+                        type=str,
+                        required=False,
+                        default='localhost:8001',
+                        help='Inference server URL. Default is '
+                        'localhost:8001.')
+    parser.add_argument('--model_name',
+                        required=False,
+                        default='run',
+                        help='the model to send request to')
+    parser.add_argument('--wavscp',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='audio_id \t absolute_wav_path')
+    parser.add_argument('--output_directory',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='the path to save the segment files')
+    parser.add_argument('--data_dir',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='data dir will be appended to audio file if given')
+    parser.add_argument('--audio_file',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='single wav file')
+    FLAGS = parser.parse_args()
+
+    # load data
+    audio_wavpath = []
+    if FLAGS.audio_file is not None:
+        path = FLAGS.audio_file
+        if FLAGS.data_dir:
+            path = os.path.join(FLAGS.data_dir, path)
+        if os.path.exists(path):
+            audio_wavpath = [(FLAGS.audio_file, path)]
+    elif FLAGS.wavscp is not None:
+        with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
+            for line in f:
+                aid, path = line.strip().split(' ')
+                audio_wavpath.append((aid, path))
+
+    num_workers = multiprocessing.cpu_count() // 2
+
+    def single_job(li):
+        idx, audio_files = li
+        # Make sure the output directory exists before writing into it
+        # (several workers may create it at the same time, hence exist_ok).
+        if FLAGS.output_directory:
+            os.makedirs(FLAGS.output_directory, exist_ok=True)
+        seg_writer = open(os.path.join(FLAGS.output_directory,
+                                       'rttm' + str(idx)), 'w', encoding="utf-8")
+
+        with grpcclient.InferenceServerClient(url=FLAGS.url,
+                                              verbose=FLAGS.verbose) as triton_client:
+            protocol_client = grpcclient
+            speech_client = SpeakerClient(triton_client, FLAGS.model_name,
+                                          protocol_client)
+
+            predictions = {}
+
+            for li in audio_files:
+                utt, wavpath = li
+                rttms = speech_client.recognize(wavpath, idx)[0]
+                # Standard RTTM line: type, file id, channel, onset, duration,
+                # then placeholder fields and the speaker name.
+                spec = "SPEAKER {} {} {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>"
+                for i in range(0, rttms.shape[0]):
+                    begin = rttms[i][0]
+                    end = rttms[i][1]
+                    label = int(rttms[i][2])
+                    channel = 1
+                    seg_writer.write(spec.format(utt,
+                                                 channel,
+                                                 begin,
+                                                 end - begin,
+                                                 label) + '\n')
+                    seg_writer.flush()
+        return predictions
+
+    # start to do inference
+    # Group requests in batches
+    predictions = []
+    tasks = []
+    splits = np.array_split(audio_wavpath, num_workers)
+
+    for idx, per_split in enumerate(splits):
+        cur_files = per_split.tolist()
+        tasks.append((idx, cur_files))
+
+    with Pool(processes=num_workers) as pool:
+        prediction = pool.map(single_job, tasks)
diff --git a/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py b/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py
new file mode 100644
index 00000000..0b879bd3
--- /dev/null
+++
b/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py @@ -0,0 +1,163 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import triton_python_backend_utils as pb_utils +from torch.utils.dlpack import from_dlpack +import json +import cupy as cp +import numpy as np +from cuml.cluster import KMeans as cuKM + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance + * device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + self.model_config = model_config = json.loads(args['model_config']) + self.max_batch_size = max(model_config["max_batch_size"], 1) + + if "GPU" in model_config["instance_group"][0]["kind"]: + self.device = "cuda" + else: + self.device = "cpu" + + # Get OUTPUT0 configuration + output0_config = pb_utils.get_output_config_by_name( + model_config, "LABELS") + # Convert Triton types to numpy types + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config['data_type']) + + def cluster_gpu(self, embeddings, p=.01, num_spks=None, + min_num_spks=1, max_num_spks=20): + # Define utility functions + def cosine_similarity(M): + M = M / cp.linalg.norm(M, axis=1, keepdims=True) + return 0.5 * (1.0 + cp.dot(M, M.T)) + + def prune(M, p): + m = M.shape[0] + if m < 1000: + n = max(m - 10, 2) + else: + n = int((1.0 - p) * m) + for i in range(m): + indexes = cp.argsort(M[i, :]) + low_indexes, high_indexes = indexes[0:n], indexes[n:m] + M[i, low_indexes] = 0.0 + M[i, high_indexes] = 1.0 + return 0.5 * (M + M.T) + + def laplacian(M): + M[cp.diag_indices(M.shape[0])] = 0.0 + D = cp.diag(cp.sum(cp.abs(M), axis=1)) + return D - M + + def spectral(M, num_spks, min_num_spks, max_num_spks): + eig_values, eig_vectors = cp.linalg.eigh(M) + num_spks = num_spks if num_spks is not None \ + else cp.argmax(cp.diff(eig_values[:max_num_spks + 1])) + 1 + num_spks = max(num_spks, min_num_spks) + return eig_vectors[:, :num_spks] + + def kmeans(data): + k = data.shape[1] + kmeans_float = cuKM(n_clusters=k, n_init=10, random_state=10) + kmeans_float.fit(cp.asarray(data)) + return kmeans_float.labels_ + + # Fallback for trivial cases + if len(embeddings) <= 2: + return [0] * len(embeddings) + + # Compute similarity matrix + similarity_matrix = cosine_similarity(embeddings) + # Prune matrix with p interval + 
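+        # (prune() keeps only the largest similarities in each row, binarizes
+        #  them to 0/1 and then symmetrizes the matrix; see the helper above.)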
pruned_similarity_matrix = prune(similarity_matrix, p) + # Compute Laplacian + laplacian_matrix = laplacian(pruned_similarity_matrix) + # Compute spectral embeddings + spectral_embeddings = spectral(laplacian_matrix, num_spks, + min_num_spks, max_num_spks) + # Assign class labels + labels = kmeans(spectral_embeddings) + + return labels + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. + + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + + Returns + ------- + list + A list of pb_utils.InferenceResponse. + The length of this list must be the same as `requests` + """ + batch_count = [] + total_embd = [] + + responses = [] + for request in requests: + # the requests will all have the same shape + # different shape request will be + # separated by triton inference server + input0 = pb_utils.get_input_tensor_by_name(request, "EMBEDDINGS") + cur_b_embd = from_dlpack(input0.to_dlpack()) + cur_batch = cur_b_embd.shape[0] + batch_count.append(cur_batch) + + for embds in cur_b_embd: + total_embd.append(embds.to(self.device)) + + labels_list = [] + for embds in total_embd: + res = self.cluster_gpu(cp.asarray(embds)) + labels_list.append(cp.asnumpy(res)) + + idx = 0 + for b in batch_count: + batch_labels = np.array(labels_list[idx:idx + b]) + idx += b + out0 = pb_utils.Tensor("LABELS", + batch_labels.astype(self.output0_dtype)) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0]) + responses.append(inference_response) + return responses diff --git a/runtime/server/diarization_gpu/model_repo/clusterer/config.pbtxt b/runtime/server/diarization_gpu/model_repo/clusterer/config.pbtxt new file mode 100644 index 00000000..87f310cd --- /dev/null +++ b/runtime/server/diarization_gpu/model_repo/clusterer/config.pbtxt @@ -0,0 +1,43 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "clusterer" +backend: "python" +max_batch_size: 256 + +input [ + { + name: "EMBEDDINGS" + data_type: TYPE_FP32 + dims: [ -1, 256 ] # embedding dim + } +] + +output [ + { + name: "LABELS" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +dynamic_batching { + preferred_batch_size: [ 16, 32 ] + } +instance_group [ + { + count: 2 + kind: KIND_GPU + } +] diff --git a/runtime/server/diarization_gpu/model_repo/run/1/model.py b/runtime/server/diarization_gpu/model_repo/run/1/model.py new file mode 100644 index 00000000..805cc591 --- /dev/null +++ b/runtime/server/diarization_gpu/model_repo/run/1/model.py @@ -0,0 +1,374 @@ +import triton_python_backend_utils as pb_utils +from torch.utils.dlpack import to_dlpack, from_dlpack +import torch +import numpy as np +import json +import asyncio + + +class TritonPythonModel: + """Your Python model must use the same class name. 
Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance + device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + self.model_config = model_config = json.loads(args['model_config']) + self.max_batch_size = max(model_config["max_batch_size"], 1) + + if "GPU" in model_config["instance_group"][0]["kind"]: + self.device = "cuda" + else: + self.device = "cpu" + + # Get OUTPUT0 configuration + output0_config = pb_utils.get_output_config_by_name(model_config, + "LABELS") + # Convert Triton types to numpy types + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config['data_type']) + + self.init_jit_model("/workspace/triton/silero_vad.jit") + + def init_jit_model(self, model_path): + torch.set_grad_enabled(False) + self.sad_model = torch.jit.load(model_path, map_location=self.device) + self.sad_model.eval() + + def prepare_chunks(self, + wav, + audio_length_samples, + sr: int = 16000, + window_size_samples: int = 1536): + chunks = [] + self.sad_model.reset_states() + + for current_start_sample in range(0, audio_length_samples, + window_size_samples): + chunk = wav[current_start_sample: + current_start_sample + window_size_samples] + if len(chunk) < window_size_samples: + chunk = torch.nn.functional.pad( + chunk, (0, int(window_size_samples - len(chunk)))) + speech_prob = self.sad_model(chunk, 16000) + chunks.append(speech_prob) + return chunks + + def get_timestamps(self, speech_probs, audio_length_samples, + sr: int = 16000, + threshold: float = 0.5, + min_duration: float = 0.255, + min_speech_duration_ms: int = 250, + min_silence_duration_ms: int = 100, + window_size_samples: int = 1536, + speech_pad_ms: int = 30): + triggered = False + speeches = [] + current_speech = {} + neg_threshold = threshold - 0.15 + temp_end = 0 + + min_speech_samples = sr * min_speech_duration_ms / 1000 + min_silence_samples = sr * min_silence_duration_ms / 1000 + speech_pad_samples = sr * speech_pad_ms / 1000 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 + + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech['start'] = window_size_samples * i + continue + + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech['end'] = temp_end + if (current_speech['end'] - + current_speech['start']) > min_speech_samples: + speeches.append(current_speech) + temp_end = 0 + current_speech = {} + triggered = False + continue + if current_speech: + current_speech['end'] = audio_length_samples + speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech['start'] = int(max(0, + speech['start'] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]['start'] - speech['end'] + 
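+                # If the silence gap to the next segment is short, split it
+                # between the two segments; otherwise extend this segment's
+                # end by the full speech_pad_samples.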
+                if silence_duration < 2 * speech_pad_samples:
+                    speech['end'] += int(silence_duration // 2)
+                    speeches[i + 1]['start'] = int(
+                        max(0, speeches[i + 1]['start'] - silence_duration // 2))
+                else:
+                    speech['end'] += int(speech_pad_samples)
+            else:
+                speech['end'] = int(min(audio_length_samples, speech['end'] +
+                                        speech_pad_samples))
+        vad_result = []
+        for item in speeches:
+            begin = item['start'] / sr
+            end = item['end'] / sr
+            if end - begin >= min_duration:
+                item['start'] = begin
+                item['end'] = end
+                vad_result.append(item)
+        return vad_result
+
+    def subsegment(self, wav, segments, wav_idx,
+                   window_fs: float = 1.50,
+                   period_fs: float = 0.75,
+                   sr: int = 16000,
+                   frame_shift: int = 10):
+        # Split every speech segment into fixed-length sliding sub-segments
+        # (window_fs seconds long, shifted by period_fs seconds).
+        def repeat_to_fill(x, window_fs):
+            length = x.size(0)
+            num = (window_fs + length - 1) // length
+
+            x = x.repeat(1, num)[0][:window_fs]
+            input = torch.zeros((1, window_fs), device=self.device)
+            input[0] = x
+            return input
+
+        subsegs = []
+        subseg_signals = []
+
+        seg_idx = 0
+
+        window_fs = int(window_fs * sr)
+        period_fs = int(period_fs * sr)
+        for segment in segments:
+            seg_begin = int(segment['start'] * sr)
+            seg_end = int(segment['end'] * sr)
+            seg_signal = wav[seg_begin: seg_end + 1]
+            seg_length = seg_end - seg_begin
+
+            # Segment boundaries are stored in milliseconds so that read_labels
+            # can combine them with the frame offsets computed below.
+            seg_begin_ms = int(segment['start'] * 1000)
+            seg_end_ms = int(segment['end'] * 1000)
+
+            if seg_length <= window_fs:
+                subseg = [wav_idx, seg_idx,
+                          seg_begin_ms, seg_end_ms, 0,
+                          int(seg_length / sr * 1000 // frame_shift)]
+                subseg_signal = repeat_to_fill(seg_signal, window_fs)
+
+                subsegs.append(subseg)
+                subseg_signals.append(subseg_signal)
+                seg_idx += 1
+            else:
+                max_subseg_begin = seg_length - window_fs + period_fs
+                for subseg_begin in range(0, max_subseg_begin, period_fs):
+                    subseg_end = min(subseg_begin + window_fs, seg_length)
+                    subseg = [wav_idx, seg_idx,
+                              seg_begin_ms, seg_end_ms,
+                              int(subseg_begin / sr * 1000 / frame_shift),
+                              int(subseg_end / sr * 1000 / frame_shift)]
+                    subseg_signal = repeat_to_fill(
+                        seg_signal[subseg_begin: subseg_end + 1], window_fs)
+
+                    subsegs.append(subseg)
+                    subseg_signals.append(subseg_signal)
+                    seg_idx += 1
+
+        return subsegs, subseg_signals
+
+    def read_labels(self, subseg_ids, label, frame_shift=10):
+        utt_to_subseg_labels = []
+        new_sort = {}
+        for i, subseg in enumerate(subseg_ids):
+            (utt, seg_idx, begin_ms, end_ms, begin_frames, end_frames) = subseg
+            begin = (int(begin_ms) + int(begin_frames) * frame_shift) / 1000.0
+            end = (int(begin_ms) + int(end_frames) * frame_shift) / 1000.0
+            new_sort[seg_idx] = (begin, end, label[i])
+        utt_to_subseg_labels = list(dict(sorted(new_sort.items())).values())
+        return utt_to_subseg_labels
+
+    def merge_segments(self, subseg_to_labels):
+        merged_segment_to_labels = []
+
+        if len(subseg_to_labels) == 0:
+            return merged_segment_to_labels
+
+        (begin, end, label) = subseg_to_labels[0]
+        for (b, e, la) in subseg_to_labels[1:]:
+            if b <= end and la == label:
+                end = e
+            elif b > end:
+                merged_segment_to_labels.append((begin, end, label))
+                begin, end, label = b, e, la
+            elif b <= end and la != label:
+                pivot = (b + end) / 2.0
+                merged_segment_to_labels.append((begin, pivot, label))
+                begin, end, label = pivot, e, la
+            else:
+                raise ValueError
+        merged_segment_to_labels.append((begin, end, label))
+
+        return merged_segment_to_labels
+
+    async def execute(self, requests):
+        """`execute` must be implemented in every Python model. `execute`
+        function receives a list of pb_utils.InferenceRequest as the only
+        argument. This function is called when an inference is requested
+        for this model.
+ Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + batch_count = [] + batch_len = [] + + total_wavs = [] + total_lens = [] + responses = [] + + for request in requests: + input0 = pb_utils.get_input_tensor_by_name(request, "input") + + cur_b_wav = from_dlpack(input0.to_dlpack()) + cur_batch = cur_b_wav.shape[0] + cur_len = cur_b_wav.shape[1] + batch_count.append(cur_batch) + batch_len.append(cur_len) + + for wav in cur_b_wav: + total_lens.append(len(wav)) + total_wavs.append(wav.to(self.device)) + + speech_shapes = [] + all_probs = [] + + for wav, lens in zip(total_wavs, total_lens): + chunks = self.prepare_chunks(wav, lens) + speech_shapes.append(len(chunks)) + all_probs.append(chunks) + reshape_probs = [] + idx = 0 + for i in range(0, len(speech_shapes)): + cur_speech = [] + for j in range(0, speech_shapes[i]): + cur_speech.append(all_probs[i][j]) + idx += 1 + reshape_probs.append(cur_speech) + + out_segs = [] + for speech_prob, speech_len in zip(reshape_probs, total_lens): + segments = self.get_timestamps(speech_prob, + speech_len, threshold=0.36) + out_segs.append(segments) + + total_subsegments = [] + total_subsegment_ids = [] + total_embds = [] + + wav_idx = 0 + for waveform, segments in zip(total_wavs, out_segs): + subsegs, subseg_signals = self.subsegment(waveform, + segments, + wav_idx) + total_subsegments.extend(subseg_signals) + total_subsegment_ids.extend(subsegs) + wav_idx += 1 + + inference_response_awaits = [] + for wavs in total_subsegments: + input_tensor_spk0 = pb_utils.Tensor.from_dlpack("WAV", + to_dlpack(wavs)) + + input_tensors_spk = [input_tensor_spk0] + inference_request = pb_utils.InferenceRequest(model_name='speaker', + requested_output_names=['EMBEDDINGS'], + inputs=input_tensors_spk) + inference_response_awaits.append(inference_request.async_exec()) + + inference_responses = await asyncio.gather( + *inference_response_awaits) + + for inference_response in inference_responses: + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response. + error().message()) + else: + batched_result = pb_utils.get_output_tensor_by_name(inference_response, + 'EMBEDDINGS') + total_embds.extend(from_dlpack(batched_result.to_dlpack())) + + out_embds = list() + out_time_info = list() + for i in range(0, len(total_wavs)): + out_embds.append(list()) + out_time_info.append(list()) + + for subseg_idx, embds in zip(total_subsegment_ids, total_embds): + wav_idx = subseg_idx[0] + out_embds[wav_idx].append(embds) + out_time_info[wav_idx].append(subseg_idx) + + # Begin clustering + inference_response_awaits = [] + for i, embd in enumerate(out_embds): + embd = torch.stack(embd) + input_tensor_embds0 = pb_utils.Tensor.from_dlpack( + "EMBEDDINGS", to_dlpack(torch.unsqueeze(embd, 0))) + + input_tensors_spk = [input_tensor_embds0] + inference_request = pb_utils.InferenceRequest(model_name='clusterer', + requested_output_names=['LABELS'], + request_id=str(i), + inputs=input_tensors_spk) + inference_response_awaits.append(inference_request.async_exec()) + + inference_responses = await asyncio.gather( + *inference_response_awaits) + + i = 0 + results = [] + for inference_response in inference_responses: + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response. 
+                                                    error().message())
+            else:
+                result = pb_utils.get_output_tensor_by_name(inference_response,
+                                                            'LABELS').as_numpy()[0]
+                utt_to_subseg_labels = self.read_labels(out_time_info[i],
+                                                        result)
+                i += 1
+                rttm = self.merge_segments(utt_to_subseg_labels)
+                if len(rttm) > 0:
+                    results.append(rttm)
+
+        # Return the batched response
+        st = 0
+        for b in batch_count:
+            sents = np.array(results[st:st + b])
+            out0 = pb_utils.Tensor("LABELS", sents.astype(self.output0_dtype))
+            inference_response = pb_utils.InferenceResponse(output_tensors=[out0])
+            responses.append(inference_response)
+            st += b
+        return responses
diff --git a/runtime/server/diarization_gpu/model_repo/run/config.pbtxt b/runtime/server/diarization_gpu/model_repo/run/config.pbtxt
new file mode 100644
index 00000000..5a51c6ed
--- /dev/null
+++ b/runtime/server/diarization_gpu/model_repo/run/config.pbtxt
@@ -0,0 +1,43 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: "run"
+backend: "python"
+max_batch_size: 128
+
+input [
+  {
+    name: "input"
+    data_type: TYPE_FP32
+    dims: [ -1 ]
+  }
+]
+
+output [
+  {
+    name: "LABELS"
+    data_type: TYPE_FP32
+    dims: [ -1, 3 ]
+  }
+]
+
+dynamic_batching {
+  preferred_batch_size: [ 16, 32 ]
+}
+instance_group [
+  {
+    count: 2
+    kind: KIND_GPU
+  }
+]