diff --git a/src/cudadecoder/Makefile b/src/cudadecoder/Makefile index 362eb703a3d..278faa5ce67 100644 --- a/src/cudadecoder/Makefile +++ b/src/cudadecoder/Makefile @@ -13,8 +13,12 @@ endif TESTFILES = -OBJFILES = batched-threaded-nnet3-cuda-pipeline.o decodable-cumatrix.o \ - cuda-decoder.o cuda-decoder-kernels.o cuda-fst.o +OBJFILES = cuda-decoder.o cuda-decoder-kernels.o cuda-fst.o \ + batched-threaded-nnet3-cuda-online-pipeline.o \ + batched-threaded-nnet3-cuda-pipeline.o \ + batched-threaded-nnet3-cuda-pipeline2.o \ + batched-static-nnet3.o batched-static-nnet3-kernels.o \ + decodable-cumatrix.o LIBNAME = kaldi-cudadecoder diff --git a/src/cudadecoder/batched-static-nnet3-kernels.cu b/src/cudadecoder/batched-static-nnet3-kernels.cu new file mode 100644 index 00000000000..f02a78ed1af --- /dev/null +++ b/src/cudadecoder/batched-static-nnet3-kernels.cu @@ -0,0 +1,208 @@ +// cudadecoder/batched-static-nnet3-kernels.cu +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cudadecoder/batched-static-nnet3-kernels.h" + +#include +namespace kaldi { +namespace cuda_decoder { + +__global__ void build_batch_with_context_kernel( + BatchedStaticNnet3KernelParams params) { + for (int batch_slot = blockIdx.z; batch_slot < params.batch_size; + batch_slot += gridDim.z) { + BatchSlotAssignment batch_assign = + params.d_batch_slot_assignement[batch_slot]; + const BaseFloat *d_batch_slot_features = batch_assign.d_features; + BaseFloat *d_channel_context = + ¶ms + .d_all_context_frames[batch_assign.ichannel * + params.d_all_context_frames_channel_stride]; + BaseFloat *d_batch_slot_with_context = + ¶ms.d_batch_with_context[params.d_batch_with_context_batch_stride * + batch_slot]; + + int n_frames_available = + batch_assign.n_frames_already_in_context + batch_assign.n_new_frames; + int n_frames_to_set = n_frames_available; + int n_left_context_frames_from_frame0 = 0; + if (batch_assign.n_frames_already_in_context == 0) { + // First chunk for that utterance. 
Generating left context by duplicating + // frame0 + n_frames_to_set += params.total_nnet_left_context; + n_left_context_frames_from_frame0 = params.total_nnet_left_context; + } + + for (int iframe = blockIdx.y; iframe < n_frames_to_set; + iframe += gridDim.y) { + for (int idim = threadIdx.x; idim < params.input_dim; + idim += blockDim.x) { + if (iframe < n_left_context_frames_from_frame0) { + d_batch_slot_with_context + [iframe * params.d_batch_with_context_frame_stride + idim] = + d_batch_slot_features[0 + idim]; // frame 0 + } else if (iframe < (n_left_context_frames_from_frame0 + + batch_assign.n_frames_already_in_context)) { + // Those are the frames coming from context + int src_iframe_in_saved_context = + iframe - n_left_context_frames_from_frame0; + d_batch_slot_with_context[iframe * + params + .d_batch_with_context_frame_stride + + idim] = + d_channel_context[src_iframe_in_saved_context * + params.d_all_context_frames_frame_stride + + idim]; + } else { + // Now we are moving the frames coming from the new chunk + int src_iframe_in_new_chunk = + iframe - n_left_context_frames_from_frame0 - + batch_assign.n_frames_already_in_context; + d_batch_slot_with_context + [iframe * params.d_batch_with_context_frame_stride + idim] = + d_batch_slot_features[src_iframe_in_new_chunk * + params.d_features_frame_stride + + idim]; + } + } + + if (iframe == 0 && + params.d_batch_ivectors) { // one CTA moves the ivectors + for (int idim = threadIdx.x; idim < params.ivector_dim; + idim += blockDim.x) { + params.d_batch_ivectors[batch_slot * params.d_batch_ivectors_stride + + idim] = batch_assign.d_ivectors[idim]; + } + } + } + } +} + +void BuildBatchWithContextKernel(const dim3 &grid, const dim3 &block, + const cudaStream_t &stream, + const BatchedStaticNnet3KernelParams ¶ms) { + build_batch_with_context_kernel<<>>(params); +} + +__global__ void build_batch_with_context_context_flush_kernel( + BatchedStaticNnet3KernelParams params) { + for (int batch_slot = blockIdx.z; batch_slot < params.batch_size; + batch_slot += gridDim.z) { + BatchSlotAssignment batch_assign = + params.d_batch_slot_assignement[batch_slot]; + BaseFloat *d_channel_context = + ¶ms + .d_all_context_frames[batch_assign.ichannel * + params.d_all_context_frames_channel_stride]; + BaseFloat *d_batch_slot_with_context = + ¶ms.d_batch_with_context[params.d_batch_with_context_batch_stride * + batch_slot]; + + int n_frames_in_context = batch_assign.n_frames_already_in_context; + int n_frames_to_set = n_frames_in_context + params.total_nnet_right_context; + + for (int iframe = blockIdx.y; iframe < n_frames_to_set; + iframe += gridDim.y) { + for (int idim = threadIdx.x; idim < params.input_dim; + idim += blockDim.x) { + if (iframe < n_frames_in_context) { + d_batch_slot_with_context + [iframe * params.d_batch_with_context_frame_stride + + idim] = d_channel_context + [iframe * params.d_all_context_frames_frame_stride + idim]; + } else if (iframe < n_frames_to_set) { + // Generating right context from last frame + int src_iframe_in_saved_context = n_frames_in_context - 1; + d_batch_slot_with_context[iframe * + params + .d_batch_with_context_frame_stride + + idim] = + d_channel_context[src_iframe_in_saved_context * + params.d_all_context_frames_frame_stride + + idim]; + } + } + + if (iframe == 0 && + params.d_batch_ivectors) { // one CTA moves the ivectors + for (int idim = threadIdx.x; idim < params.ivector_dim; + idim += blockDim.x) { + params.d_batch_ivectors[batch_slot * params.d_batch_ivectors_stride + + idim] = 
batch_assign.d_ivectors[idim]; + } + } + } + } +} + +void BuildBatchWithContextKernelContextFlush( + const dim3 &grid, const dim3 &block, const cudaStream_t &stream, + const BatchedStaticNnet3KernelParams ¶ms) { + build_batch_with_context_context_flush_kernel<<>>( + params); +} + +__global__ void save_context_from_batch_kernel( + BatchedStaticNnet3KernelParams params) { + for (int batch_slot = blockIdx.z; batch_slot < params.batch_size; + batch_slot += gridDim.z) { + BatchSlotAssignment batch_assign = + params.d_batch_slot_assignement[batch_slot]; + + // Real frames : does not include frame0 copies for left context + int n_real_frames_available = + batch_assign.n_frames_already_in_context + batch_assign.n_new_frames; + // total frames : includes frame0 copies + int total_frames_in_batch_slot = n_real_frames_available; + if (batch_assign.n_frames_already_in_context == 0) { + // First chunk for that utterance. We generated left context by + // duplicating frame0 + total_frames_in_batch_slot += params.total_nnet_left_context; + } + // total frames : includes frame0 copies + int n_to_copy = min(total_frames_in_batch_slot, params.total_nnet_context); + int copy_from_frame = total_frames_in_batch_slot - n_to_copy; + BaseFloat *d_batch_slot_with_context = + ¶ms.d_batch_with_context[params.d_batch_with_context_batch_stride * + batch_slot]; + BaseFloat *d_channel_context = + ¶ms + .d_all_context_frames[batch_assign.ichannel * + params.d_all_context_frames_channel_stride]; + + for (int dst_iframe = blockIdx.y; dst_iframe < n_to_copy; + dst_iframe += gridDim.y) { + int src_iframe = copy_from_frame + dst_iframe; + for (int idim = threadIdx.x; idim < params.input_dim; + idim += blockDim.x) { + d_channel_context[dst_iframe * + params.d_all_context_frames_frame_stride + + idim] = d_batch_slot_with_context + [src_iframe * params.d_batch_with_context_frame_stride + idim]; + } + } + } +} + +void SaveContextFromBatchKernel(const dim3 &grid, const dim3 &block, + const cudaStream_t &stream, + const BatchedStaticNnet3KernelParams ¶ms) { + save_context_from_batch_kernel<<>>(params); +} + +} // namespace cuda_decoder +} // namespace kaldi diff --git a/src/cudadecoder/batched-static-nnet3-kernels.h b/src/cudadecoder/batched-static-nnet3-kernels.h new file mode 100644 index 00000000000..45064e15071 --- /dev/null +++ b/src/cudadecoder/batched-static-nnet3-kernels.h @@ -0,0 +1,87 @@ +// cudadecoder/batched-static-nnet3-kernels.h +// +// Copyright (c) 2019; NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License; Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing; software +// distributed under the License is distributed on an "AS IS" BASIS; +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND; either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if HAVE_CUDA == 1 + +#include +#include "base/kaldi-types.h" + +#ifndef KALDI_CUDA_DECODER_BATCHED_STATIC_NNET3_KERNELS_H_ +#define KALDI_CUDA_DECODER_BATCHED_STATIC_NNET3_KERNELS_H_ + +namespace kaldi { +namespace cuda_decoder { + +// Describe what each batch slot is made of. 
Used by the context switch kernels +struct BatchSlotAssignment { + BaseFloat *d_features; + BaseFloat *d_ivectors; + int ichannel; + int n_frames_already_in_context; + int n_new_frames; +}; + +struct BatchedStaticNnet3KernelParams { + const BaseFloat *d_all_new_features; + const BatchSlotAssignment *d_batch_slot_assignement; + BaseFloat *d_all_context_frames; + BaseFloat *d_batch_with_context; + BaseFloat *d_batch_ivectors; + int d_batch_ivectors_stride; + int batch_size; + int d_features_frame_stride; + int d_ivectors_frame_stride; + int d_all_context_frames_frame_stride; + int d_batch_with_context_frame_stride; + int d_all_context_frames_channel_stride; + int d_batch_with_context_batch_stride; + int input_dim; + int ivector_dim; + int total_nnet_context; + int total_nnet_left_context; + int total_nnet_right_context; + int input_frames_per_chunk_with_context; +}; + +// Takes as a input strided new chunks ptrs [chk0, chk1, chk2..] +// associated to channels [ch0, ch1, ch2...] +// And build a continuous batch such as: +// Batch with context: +// row0: [left_context(ch0), chk0] +// row0: [left_context(ch1), chk1] +// row0: [left_context(ch2), chk2] +// With left context being either part of a previous chunk for that channel, or +// just duplications of frame0 if this is the first chunk for that channel The +// end of each chunk for each row will then be used as a right context +void BuildBatchWithContextKernel(const dim3 &grid, const dim3 &block, + const cudaStream_t &stream, + const BatchedStaticNnet3KernelParams ¶ms); + +// Same thing than BuildBatchWithContextKernelContextFlush, except that the +// final frame is replicated to create the right context +void BuildBatchWithContextKernelContextFlush( + const dim3 &grid, const dim3 &block, const cudaStream_t &stream, + const BatchedStaticNnet3KernelParams ¶ms); +void SaveContextFromBatchKernel(const dim3 &grid, const dim3 &block, + const cudaStream_t &stream, + const BatchedStaticNnet3KernelParams ¶ms); + +} // namespace cuda_decoder +} // namespace kaldi + +#endif // KALDI_CUDA_DECODER_BATCHED_STATIC_NNET3_KERNELS_H_ +#endif // HAVE_CUDA diff --git a/src/cudadecoder/batched-static-nnet3.cc b/src/cudadecoder/batched-static-nnet3.cc new file mode 100644 index 00000000000..87736a12bd0 --- /dev/null +++ b/src/cudadecoder/batched-static-nnet3.cc @@ -0,0 +1,393 @@ +// cudadecoder/batched-static-nnet3.cc +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
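+// Memory layout used by the context-switch kernels (illustrative sketch):
+// for a batch of chunks [chk0, chk1, chk2, ...] assigned to channels
+// [ch0, ch1, ch2, ...], d_batch_with_context is built as
+//   batch slot 0: [left_context(ch0), chk0]
+//   batch slot 1: [left_context(ch1), chk1]
+//   batch slot 2: [left_context(ch2), chk2]
+// Frame f of batch slot b is addressed as
+//   d_batch_with_context[b * d_batch_with_context_batch_stride
+//                        + f * d_batch_with_context_frame_stride + dim]
+// and the saved context of channel c as
+//   d_all_context_frames[c * d_all_context_frames_channel_stride
+//                        + f * d_all_context_frames_frame_stride + dim]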
+ +#if HAVE_CUDA == 1 + +#include "cudadecoder/batched-static-nnet3.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { +namespace cuda_decoder { + +void BatchedStaticNnet3::ReadParametersFromModelAndConfig() { + input_frames_per_chunk_ = config_.compute_opts.frames_per_chunk; + int32 nnet_left_context, nnet_right_context; + nnet3::ComputeSimpleNnetContext(am_nnet_.GetNnet(), &nnet_left_context, + &nnet_right_context); + total_nnet_left_context_ = + nnet_left_context + config_.compute_opts.extra_left_context; + total_nnet_right_context_ = + nnet_right_context + config_.compute_opts.extra_right_context; + total_nnet_context_ = total_nnet_left_context_ + total_nnet_right_context_; + subsampling_factor_ = config_.compute_opts.frame_subsampling_factor, + input_frames_per_chunk_ = config_.compute_opts.frames_per_chunk; + input_frames_per_chunk_with_context_ = input_frames_per_chunk_ + + total_nnet_left_context_ + + total_nnet_right_context_; + output_frames_per_chunk_ = + (subsampling_factor_ - 1 + input_frames_per_chunk_) / subsampling_factor_; + KALDI_ASSERT(output_frames_per_chunk_ > 0); + + input_dim_ = am_nnet_.InputDim(); + if (has_ivector_) ivector_dim_ = am_nnet_.IvectorDim(); +} + +void BatchedStaticNnet3::PresetKernelParams() { + // context_switch_kernel_params_.d_all_new_features; <- To be set when + // called + context_switch_kernel_params_.d_batch_slot_assignement = + d_batch_slot_assignement_; + context_switch_kernel_params_.d_all_context_frames = + d_all_context_frames_.Data(); + context_switch_kernel_params_.d_all_context_frames_frame_stride = + d_all_context_frames_.Stride(); + context_switch_kernel_params_.d_all_context_frames_channel_stride = + d_all_context_frames_.Stride() * total_nnet_context_; + // context_switch_kernel_params_.d_batch_with_context = <- To be set + // when called + // d_batch_with_context_.Data(); + // context_switch_kernel_params_.batch_size; <- To be set when called + // context_switch_kernel_params_.d_all_new_features_stride = <- To be + // set when called + + context_switch_kernel_params_.input_dim = input_dim_; + context_switch_kernel_params_.ivector_dim = ivector_dim_; + context_switch_kernel_params_.total_nnet_context = total_nnet_context_; + context_switch_kernel_params_.total_nnet_left_context = + total_nnet_left_context_; + context_switch_kernel_params_.total_nnet_right_context = + total_nnet_right_context_; + context_switch_kernel_params_.input_frames_per_chunk_with_context = + input_frames_per_chunk_with_context_; +} + +void BatchedStaticNnet3::Allocate() { + cudaEventCreate(&batch_slot_assignement_copy_evt_); + d_all_context_frames_.Resize(nchannels_ * total_nnet_context_, input_dim_); + d_batch_with_context_.Resize( + max_batch_size_ * input_frames_per_chunk_with_context_, input_dim_); + if (has_ivector_) d_batch_ivectors_.Resize(max_batch_size_, ivector_dim_); + cudaMalloc(&d_batch_slot_assignement_, + max_batch_size_ * sizeof(*d_batch_slot_assignement_)); + cudaMallocHost(&h_batch_slot_assignement_, + max_batch_size_ * sizeof(*h_batch_slot_assignement_)); + channel_n_frames_in_context_.resize(nchannels_, -1); + st_ = cudaStreamPerThread; + PresetKernelParams(); +} + +void BatchedStaticNnet3::Deallocate() { + cudaFreeHost(h_batch_slot_assignement_); + cudaFreeHost(d_batch_slot_assignement_); + cudaEventDestroy(batch_slot_assignement_copy_evt_); +} + +void BatchedStaticNnet3::CompileNnet3() { + SetComputationRequest(); + config_.compute_opts.compiler_config.cache_capacity += + max_batch_size_ * input_frames_per_chunk_; + 
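+  // Compile the batched nnet3 computation once. The cached computation_ is
+  // reused by every RunNnet3() call instead of being recompiled per chunk.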
compiler_.reset(new nnet3::CachingOptimizingCompiler( + am_nnet_.GetNnet(), config_.compute_opts.compiler_config)); + computation_ = compiler_->Compile(request_); +} + +void BatchedStaticNnet3::SetComputationRequest() { + request_.need_model_derivative = false; + request_.store_component_stats = false; + request_.inputs.reserve(2); + + int32 num_input_frames = input_frames_per_chunk_ + total_nnet_left_context_ + + total_nnet_right_context_; + int32 first_input_t = 0 - total_nnet_left_context_; + int32 num_output_frames = output_frames_per_chunk_; + int32 output_t_stride = subsampling_factor_; + + std::vector input_indexes, ivector_indexes, output_indexes; + input_indexes.reserve(nnet3_batch_size_ * num_input_frames); + output_indexes.reserve(nnet3_batch_size_ * num_output_frames); + if (has_ivector_) ivector_indexes.reserve(nnet3_batch_size_); + for (int32 n = 0; n < nnet3_batch_size_; n++) { + for (int32 t = first_input_t; t < first_input_t + num_input_frames; t++) { + input_indexes.push_back(nnet3::Index(n, t, 0)); + } + if (config_.has_ivector) ivector_indexes.push_back(nnet3::Index(n, 0, 0)); + for (int32 t = 0; t < num_output_frames; t++) + output_indexes.push_back(nnet3::Index(n, t * output_t_stride, 0)); + } + request_.inputs.push_back(nnet3::IoSpecification("input", input_indexes)); + if (has_ivector_) + request_.inputs.push_back( + nnet3::IoSpecification("ivector", ivector_indexes)); + request_.outputs.push_back(nnet3::IoSpecification("output", output_indexes)); +} + +void BatchedStaticNnet3::BatchContextSwitch( + const std::vector &channels, + const std::vector &d_features, const int features_frame_stride, + const std::vector &d_ivectors, + const std::vector &n_input_frames_valid, bool flush_eos_context, + std::vector *n_output_frames_valid) { + int batch_size = channels.size(); + + // AcceptInput destroys input, resizing + d_batch_with_context_.Resize( + max_batch_size_ * input_frames_per_chunk_with_context_, input_dim_); + if (has_ivector_) d_batch_ivectors_.Resize(max_batch_size_, ivector_dim_); + + n_output_frames_valid->resize(batch_size); + + cudaEventSynchronize( + batch_slot_assignement_copy_evt_); // reusing same pinned memory + for (int i = 0; i < channels.size(); ++i) { + int channel = channels[i]; + int nframes_in_context = channel_n_frames_in_context_[channel]; + int ninput_frames = n_input_frames_valid[i]; + + KALDI_ASSERT(ninput_frames <= input_frames_per_chunk_); + h_batch_slot_assignement_[i].d_features = d_features[i]; + h_batch_slot_assignement_[i].d_ivectors = + has_ivector_ ? 
d_ivectors[i] : NULL; + h_batch_slot_assignement_[i].ichannel = channel; + h_batch_slot_assignement_[i].n_frames_already_in_context = + nframes_in_context; + h_batch_slot_assignement_[i].n_new_frames = ninput_frames; + + // Left context will be generated as necessary (copying first + // frame) However we must have a full right context to start + // decoding frames + KALDI_ASSERT(!flush_eos_context || ninput_frames == 0); + int nframes_in_batch = ninput_frames; + if (nframes_in_context == 0) + nframes_in_batch += total_nnet_left_context_; // using frame0 as left + // context + else + nframes_in_batch += nframes_in_context; + if (flush_eos_context) + nframes_in_batch += total_nnet_right_context_; // using last frame as + // right context + KALDI_ASSERT( + "Please set --frames-per-chunk at least as large as the neural net " + "right context" && + input_frames_per_chunk_ >= total_nnet_right_context_); + + channel_n_frames_in_context_[channel] = + std::min(nframes_in_batch, total_nnet_context_); + + // Computing number of output frames + int total_nframes_minus_context = + std::max(0, nframes_in_batch - total_nnet_context_); + int total_output_nframes = + (total_nframes_minus_context + subsampling_factor_ - 1) / + subsampling_factor_; + (*n_output_frames_valid)[i] = total_output_nframes; + } + context_switch_kernel_params_.batch_size = batch_size; + context_switch_kernel_params_.d_features_frame_stride = features_frame_stride; + context_switch_kernel_params_.d_batch_with_context = + d_batch_with_context_.Data(); + context_switch_kernel_params_.d_batch_with_context_frame_stride = + d_batch_with_context_.Stride(); + context_switch_kernel_params_.d_batch_ivectors = + has_ivector_ ? d_batch_ivectors_.Data() : NULL; + context_switch_kernel_params_.d_batch_ivectors_stride = + has_ivector_ ? d_batch_ivectors_.Stride() : 0; + context_switch_kernel_params_.d_batch_with_context_batch_stride = + d_batch_with_context_.Stride() * input_frames_per_chunk_with_context_; + + cudaMemcpyAsync(d_batch_slot_assignement_, h_batch_slot_assignement_, + batch_size * sizeof(*d_batch_slot_assignement_), + cudaMemcpyHostToDevice, st_); + cudaEventRecord(batch_slot_assignement_copy_evt_, st_); + + dim3 grid = {1, + static_cast(input_frames_per_chunk_with_context_), + static_cast(batch_size)}; + dim3 block = { + 64, 1, + 1}; // Expecting chunks in the order of magnitude of 64 frames. It will + // still work with any numbers of frames per chunk, this only impacts + // performance. 
This kernel is not a bottleneck anyway + if (flush_eos_context) { + BuildBatchWithContextKernelContextFlush(grid, block, st_, + context_switch_kernel_params_); + } else { + BuildBatchWithContextKernel(grid, block, st_, + context_switch_kernel_params_); + SaveContextFromBatchKernel(grid, block, st_, context_switch_kernel_params_); + } +} + +void BatchedStaticNnet3::RunNnet3(CuMatrix *d_all_log_posteriors, + int batch_size) { + for (int off = 0; off < batch_size; off += nnet3_batch_size_) { + // Nnet3 destroys input, resizing + d_nnet3_input_.Resize( + nnet3_batch_size_ * input_frames_per_chunk_with_context_, input_dim_); + if (has_ivector_) d_nnet3_ivectors_.Resize(nnet3_batch_size_, ivector_dim_); + + int minibatch_size = std::min(nnet3_batch_size_, batch_size - off); + { + // Copy minibatch from batch : mfcc + int frames_per_minibatch = + minibatch_size * input_frames_per_chunk_with_context_; + CuSubMatrix dst = + d_nnet3_input_.RowRange(0, frames_per_minibatch); + CuSubMatrix src = d_batch_with_context_.RowRange( + off * input_frames_per_chunk_with_context_, frames_per_minibatch); + dst.CopyFromMat(src); + } + + if (has_ivector_) { + // Copy minibatch from batch : ivectors + CuSubMatrix dst = + d_nnet3_ivectors_.RowRange(0, minibatch_size); + CuSubMatrix src = + d_batch_ivectors_.RowRange(off, minibatch_size); + dst.CopyFromMat(src); + } + + // Using pre-compiled computation_ + nnet3::NnetComputer computer(config_.compute_opts.compute_config, + *computation_, am_nnet_.GetNnet(), NULL); + + computer.AcceptInput("input", &d_nnet3_input_); + if (has_ivector_) computer.AcceptInput("ivector", &d_nnet3_ivectors_); + computer.Run(); + + d_nnet3_output_ = computer.GetOutput("output"); + + { + int output_rows_per_minibatch = minibatch_size * output_frames_per_chunk_; + + // Copy nnet3 minibatch output to batch + CuSubMatrix src = + d_nnet3_output_.RowRange(0, output_rows_per_minibatch); + CuSubMatrix dst = d_all_log_posteriors->RowRange( + off * output_frames_per_chunk_, output_rows_per_minibatch); + dst.CopyFromMat(src); + } + } + + // Postprocessing of the loglikehoods + if (log_priors_.Dim() != 0) + d_all_log_posteriors->AddVecToRows(-1.0, log_priors_); + if (config_.compute_opts.acoustic_scale != 1.0f) + d_all_log_posteriors->Scale(config_.compute_opts.acoustic_scale); +} + +void BatchedStaticNnet3::RunBatch( + const std::vector &channels, + const std::vector &d_features, const int features_stride, + const std::vector &d_ivectors, + const std::vector &n_input_frames_valid, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk, + CuMatrix *d_all_log_posteriors, + std::vector>> + *all_frames_log_posteriors_ptrs) { + KALDI_ASSERT(d_features.size() == channels.size()); + KALDI_ASSERT(is_last_chunk.size() == channels.size()); + KALDI_ASSERT(is_first_chunk.size() == channels.size()); + if (has_ivector_) { + KALDI_ASSERT(d_ivectors.size() == channels.size()); + } + // Initializing the new channels + for (size_t i = 0; i < is_first_chunk.size(); ++i) { + if (is_first_chunk[i]) InitChannel(channels[i]); + } + + all_frames_log_posteriors_ptrs + ->clear(); // will start setting output frames now + + // + // Step1: Processing chunks in d_features + // + + // Building a continuous execution batch made of the current assignements, + // while adding left and right context to the chunks + BatchContextSwitch(channels, d_features, features_stride, d_ivectors, + n_input_frames_valid, false, &n_output_frames_valid_); + // Running this batch + RunNnet3(d_all_log_posteriors, channels.size()); + 
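+  // d_all_log_posteriors now contains output_frames_per_chunk_ rows per
+  // batch slot: row (i * output_frames_per_chunk_ + f) holds output frame f
+  // of batch slot i, of which only the first n_output_frames_valid_[i]
+  // frames are valid.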
// Building the list of pointers to output frames. Will be used by the decoder + FormatOutputPtrs(channels, d_all_log_posteriors, + all_frames_log_posteriors_ptrs, n_output_frames_valid_); + + // + // Step2: Flushing context for chunks with is_last_chunk set + // + + eos_channels_.clear(); + d_eos_features_.clear(); + d_eos_ivectors_.clear(); + eos_n_input_frames_valid_.clear(); + eos_n_output_frames_offset_.clear(); + for (int i = 0; i < channels.size(); ++i) { + if (!is_last_chunk[i]) continue; + eos_channels_.push_back(channels[i]); + d_eos_features_.push_back(NULL); // the context will serve as features + if (has_ivector_) d_eos_ivectors_.push_back(d_ivectors[i]); + eos_n_input_frames_valid_.push_back(0); + eos_n_output_frames_offset_.push_back( + n_output_frames_valid_[i]); // append to previously generated frames + // (in step1) + } + + if (!eos_channels_.empty()) { + BatchContextSwitch(eos_channels_, d_eos_features_, 0, d_eos_ivectors_, + eos_n_input_frames_valid_, /* flush context */ true, + &eos_n_output_frames_valid_); + d_all_eos_log_posteriors_.Resize(d_all_log_posteriors->NumRows(), + d_all_log_posteriors->NumCols()); + RunNnet3(&d_all_eos_log_posteriors_, eos_channels_.size()); + FormatOutputPtrs(eos_channels_, &d_all_eos_log_posteriors_, + all_frames_log_posteriors_ptrs, eos_n_output_frames_valid_, + &eos_n_output_frames_offset_); + } +} + +void BatchedStaticNnet3::FormatOutputPtrs( + const std::vector &channels, CuMatrix *d_all_log_posteriors, + std::vector>> + *all_frames_log_posteriors_ptrs, + const std::vector &n_output_frames_valid, + const std::vector *n_output_frames_valid_offset) { + // Build the list of pointers to output frames. Will be used by the decoder + KALDI_ASSERT(channels.size() == n_output_frames_valid.size()); + for (int i = 0; i < channels.size(); ++i) { + int ichannel = channels[i]; + int offset = + (n_output_frames_valid_offset) ? (*n_output_frames_valid_offset)[i] : 0; + int total_output_nframes = offset + n_output_frames_valid[i]; + if (all_frames_log_posteriors_ptrs->size() < total_output_nframes) + all_frames_log_posteriors_ptrs->resize(total_output_nframes); + for (int iframe = offset; iframe < total_output_nframes; ++iframe) { + std::vector> &this_frame = + (*all_frames_log_posteriors_ptrs)[iframe]; + int local_iframe = iframe - offset; + CuSubVector out = d_all_log_posteriors->Row( + i * output_frames_per_chunk_ + local_iframe); + BaseFloat *frame = out.Data(); + this_frame.push_back({ichannel, frame}); + } + } +} + +} // namespace cuda_decoder +} // namespace kaldi + +#endif // HAVE_CUDA diff --git a/src/cudadecoder/batched-static-nnet3.h b/src/cudadecoder/batched-static-nnet3.h new file mode 100644 index 00000000000..df03e924854 --- /dev/null +++ b/src/cudadecoder/batched-static-nnet3.h @@ -0,0 +1,227 @@ +// cudadecoder/batched-static-nnet3.h +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
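+// Typical call pattern (illustrative sketch only; the variable names below
+// are placeholders, not part of this header):
+//   BatchedStaticNnet3Config cfg;
+//   cfg.compute_opts = compute_opts;
+//   cfg.max_batch_size = max_batch_size;
+//   cfg.has_ivector = true;  // if the model consumes ivectors
+//   BatchedStaticNnet3 nnet3(cfg, am_nnet);
+//   // then, once per batch of chunks:
+//   nnet3.RunBatch(channels, d_features, features_frame_stride, d_ivectors,
+//                  n_input_frames_valid, is_first_chunk, is_last_chunk,
+//                  &d_all_log_posteriors, &all_frames_log_posteriors);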
+ +#if HAVE_CUDA == 1 + +#ifndef KALDI_CUDA_DECODER_BATCHED_STATIC_NNET3_H_ +#define KALDI_CUDA_DECODER_BATCHED_STATIC_NNET3_H_ + +// Following define is NOT an upper bound for max_batch_size +// It only concerns the nnet3 compiled computation +// If we use a batch size > MAX_COMPUTE_BATCH_SIZE, we will run nnet3 +// multiple times, each computing minibatches of size MAX_COMPUTE_BATCH_SIZE +// MAX_COMPUTE_BATCH_SIZE is defined to be big enough to hide kernel launch +// latency and increase the arithmetic intensity of the GEMMs +// not not bigger so that running partial batches is faster +// (e.g. running a batch size = 72 with max_batch_size_=512) +#define MAX_COMPUTE_BATCH_SIZE 64 + +#include "cudadecoder/batched-static-nnet3-kernels.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-optimize.h" + +namespace kaldi { +namespace cuda_decoder { + +struct BatchedStaticNnet3Config { + BatchedStaticNnet3Config() + : max_batch_size(200), nchannels(-1), has_ivector(false) {} + nnet3::NnetSimpleComputationOptions compute_opts; + int max_batch_size; + int nchannels; + bool has_ivector; // probably can be deducted from am_nnet? +}; + +// Light driver for Nnet3. Compiles the nnet only once and reuse it. +// It is cheaper to waste some computation by adding partial chunks to a batch +// than recompiling a nnet3 computation just for that chunk (and running smaller +// batches, because each batch would be specialized to a specific chunk +// size/batch size) +// Also takes care of storing/restoring left/right context, generating initial +// context/final context, flushing this context. +// Supports context switch with ivectors +class BatchedStaticNnet3 { + public: + BatchedStaticNnet3(const BatchedStaticNnet3Config &config, + const nnet3::AmNnetSimple &am_nnet) + : config_(config), + am_nnet_(am_nnet), + max_batch_size_(config.max_batch_size), + has_ivector_(config.has_ivector), + log_priors_(am_nnet.Priors()) { + nchannels_ = (config.nchannels != -1) ? config.nchannels : max_batch_size_; + KALDI_ASSERT(max_batch_size_ > 0); + nnet3_batch_size_ = std::min(max_batch_size_, MAX_COMPUTE_BATCH_SIZE); + KALDI_ASSERT(nchannels_ >= max_batch_size_); + ReadParametersFromModelAndConfig(); + CompileNnet3(); + Allocate(); + } + + virtual ~BatchedStaticNnet3() { Deallocate(); } + + // Receives a batch with a set of chunks (at most one chunk per channel). + // Restore contextes, run nnet3, save the context for next RunBatch. + // Pointers to the output frames are set in all_frames_log_posteriors + // + // For each batch slot i: + // - channels[i] is the associated channel. + // - d_features[i] points to a submatrix of features. It is made of + // mfcc_dim*n_input_frames_valid[i] BaseFloats + // - d_ivectors[i] is the ivector to use for this nnet3 run, if ivectors + // are available. + // - n_input_frames_valid[i] how many frames can be read from d_features. + // It can be strictly less than frames_per_chunk, for instance for the last + // chunk + // - is_first_chunk[i] set <=> first chunk for that channel. Will reset + // left context + // - is_last_chunk[i] set <=> last chunk for that channel. Will flush right + // context + // - d_all_log_posteriors where to store the output frames. Could be owned + // by that class (the decoder is supposed to access those frames through + // all_frames_log_posteriors + // - all_frames_log_posteriors. 
For each output frame index (dim1), list + // all the channels which have a valid frame, and the corresponding pointer + // in memory. + // + // E.g.: We called RunBatch with channels = {4,7} Channels 4 has + // 2 valid output frames, channel 7 has 3 valid output frames. + // all_frames_log_posteriors = [ + // [[4,ptr0,4],[7,ptr0,7]], + // [[4,ptr1,4],[7,ptr1,7]], + // [[7,ptr2,7]], + // ] + // with ptri,j the pointer to the output frame i for channel j. + // frame i is a local indexing: the first frame for channel j for this + // RunBatch call will always be 0, even if other output frames have already + // been generated for that channel in previous RunBatch calls. + void RunBatch(const std::vector &channels, + const std::vector &d_features, + const int features_stride, + const std::vector &d_ivectors, + const std::vector &n_input_frames_valid, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk, + CuMatrix *d_all_log_posteriors, + std::vector>> + *all_frames_log_posteriors); + + // Nnet3 puts the output frames in the matrix all_frames_log_posteriors_ptrs + // However, we still have to only consider "valid" output frames. + // See RunBatch comments for a description of the output + // n_output_frames_valid_offset describes how many valid output frames we + // already have in all_frames_log_posteriors_ptrs for each channel + void FormatOutputPtrs( + const std::vector &channels, + CuMatrix *d_all_log_posteriors, + std::vector>> + *all_frames_log_posteriors_ptrs, + const std::vector &n_output_frames_valid, + const std::vector *n_output_frames_valid_offset = NULL); + + int GetNOutputFramesPerChunk() { return output_frames_per_chunk_; } + int GetTotalNnet3RightContext() { return total_nnet_right_context_; } + + private: + // Compiling nnet3 using that computation request + void ReadParametersFromModelAndConfig(); + // Define the computation request for nnet3 based on parameters + void SetComputationRequest(); + void Allocate(); + void PresetKernelParams(); + void Deallocate(); + void CompileNnet3(); + // Run Nnet3 itself. Divides the execution batch into smaller nnet3 batches + // That nnet3 batch size is choosen so that we saturate the GPU, but we still + // keep the smallest batch size possible to have a better granularity with + // partial batches + void RunNnet3(CuMatrix *d_all_log_posteriors, int batch_size); + void BatchContextSwitch(const std::vector &channels, + const std::vector &d_features, + const int features_stride, + const std::vector &d_ivectors, + const std::vector &n_input_frames_valid, + bool flush_eos_context, + std::vector *n_output_frames_valid); + void InitChannel(int32 ichannel) { + KALDI_ASSERT(ichannel < nchannels_); + channel_n_frames_in_context_[ichannel] = 0; + } + + BatchedStaticNnet3Config config_; + cudaStream_t st_; + nnet3::AmNnetSimple am_nnet_; + int max_batch_size_; + int nnet3_batch_size_; // Cf RunNnet3. Batch size for the execution for nnet3 + int nchannels_; // Number of possible channels. Each channel owns a context. 
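+  // True iff the model consumes an ivector input (copied from the config)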
+ bool has_ivector_; + CuVector log_priors_; + + // Extracted from config or models + int input_dim_; // mfcc dim + int ivector_dim_; // ivector dim + int input_frames_per_chunk_; + int input_frames_per_chunk_with_context_; // input_frames_per_chunk_ with + // left and right context + int total_nnet_left_context_; + int total_nnet_right_context_; + int total_nnet_context_; + int output_frames_per_chunk_; + int subsampling_factor_; + + // Storing frames which will be used in future context + // If the channel has just been resetted, those frames are empty. + // Otherwise, it contains at most total_nnet_context_ frames + CuMatrix d_all_context_frames_; + CuMatrix d_batch_with_context_; + CuMatrix d_nnet3_input_; + CuMatrix d_nnet3_ivectors_; + CuMatrix d_nnet3_output_; + CuMatrix d_batch_ivectors_; + CuMatrix d_all_log_posteriors_; + CuMatrix d_all_eos_log_posteriors_; + // batch slot assignement. Size [max_batch_size] + BatchSlotAssignment *d_batch_slot_assignement_; + BatchSlotAssignment *h_batch_slot_assignement_; + BatchedStaticNnet3KernelParams context_switch_kernel_params_; + cudaEvent_t batch_slot_assignement_copy_evt_; + // Number of frames already stored in context + // Size [nchannels] + // If channel not initialized, equals to -1 + std::vector channel_n_frames_in_context_; + std::vector n_output_frames_valid_; + + // Used to flush context at eos (end of sequence) + std::vector eos_channels_; + std::vector d_eos_features_; + std::vector d_eos_ivectors_; + std::vector eos_n_input_frames_valid_; + std::vector eos_n_output_frames_valid_; + std::vector eos_n_output_frames_offset_; + + std::unique_ptr compiler_; + std::shared_ptr + computation_; // shared because returned as shared by compiler + nnet3::ComputationRequest request_; +}; +} // namespace cuda_decoder +} // namespace kaldi + +#endif // KALDI_CUDA_DECODER_BATCHED_STATIC_NNET3_H_ +#endif // HAVE_CUDA diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc new file mode 100644 index 00000000000..6fe87ee3dc7 --- /dev/null +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc @@ -0,0 +1,476 @@ +// cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
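+// Basic usage sketch (illustrative only; see the example binary in
+// cudadecoderbin/ for a complete driver; variable names are placeholders):
+//   BatchedThreadedNnet3CudaOnlinePipelineConfig config;
+//   BatchedThreadedNnet3CudaOnlinePipeline pipeline(config, decode_fst,
+//                                                   am_nnet, trans_model);
+//   pipeline.TryInitCorrID(corr_id);  // returns false if no channel is free
+//   pipeline.SetLatticeCallback(corr_id,
+//                               [](CompactLattice &clat) { /* use it */ });
+//   // for each batch of audio chunks (at most one chunk per corr_id):
+//   pipeline.DecodeBatch(corr_ids, wave_samples, is_first_chunk,
+//                        is_last_chunk);
+//   pipeline.WaitForLatticeCallbacks();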
+ +#if HAVE_CUDA == 1 + +#define KALDI_CUDA_DECODER_WAIT_FOR_CALLBACKS_US 10000 +#define KALDI_CUDA_DECODER_WAIT_FOR_CPU_FEATURES_THREADS_US 1000 +#define KALDI_CUDA_DECODER_WAIT_FOR_AVAILABLE_CHANNEL_US 1000 + +#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h" +#include +#include "feat/feature-window.h" +#include "lat/lattice-functions.h" +#include "nnet3/nnet-utils.h" + +namespace kaldi { +namespace cuda_decoder { +void BatchedThreadedNnet3CudaOnlinePipeline::Initialize( + const fst::Fst &decode_fst) { + ReadParametersFromModel(); + AllocateAndInitializeData(decode_fst); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::AllocateAndInitializeData( + const fst::Fst &decode_fst) { + d_all_features_.Resize(max_batch_size_ * input_frames_per_chunk_, input_dim_, + kUndefined, kStrideEqualNumCols); + + if (config_.use_gpu_feature_extraction) { + h_all_waveform_.Resize(max_batch_size_, samples_per_chunk_, kUndefined, + kStrideEqualNumCols); + cudaHostRegister(h_all_waveform_.Data(), h_all_waveform_.SizeInBytes(), + cudaHostRegisterDefault); + d_all_waveform_.Resize(max_batch_size_, samples_per_chunk_, kUndefined, + kStrideEqualNumCols); + } else { + h_all_features_.Resize(max_batch_size_ * input_frames_per_chunk_, + input_dim_, kUndefined, kStrideEqualNumCols); + } + + if (use_ivectors_) { + d_all_ivectors_.Resize(max_batch_size_ * ivector_dim_, kSetZero); + h_all_ivectors_.Resize(max_batch_size_, ivector_dim_, kSetZero, + kStrideEqualNumCols); + } + + d_all_log_posteriors_.Resize(max_batch_size_ * output_frames_per_chunk_, + trans_model_->NumPdfs(), kUndefined); + available_channels_.resize(config_.num_channels); + channels_callbacks_.resize(config_.num_channels); + std::iota(available_channels_.begin(), available_channels_.end(), + 0); // 0,1,2,3.. 
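+  // available_channels_ is used as a free list: TryInitCorrID() pops a
+  // channel from the back, and FinalizeDecoding() pushes it back once the
+  // raw lattice for that channel has been retrieved.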
+ corr_id2channel_.reserve(config_.num_channels); + channel_frame_offset_.resize(config_.num_channels, 0); + + // Feature extraction + if (config_.use_gpu_feature_extraction) { + gpu_feature_pipeline_.reset(new OnlineBatchedFeaturePipelineCuda( + config_.feature_opts, samples_per_chunk_, config_.max_batch_size, + config_.num_channels)); + } else { + feature_pipelines_.resize(config_.num_channels); + } + + // Decoder + cuda_fst_ = std::make_shared(); + cuda_fst_->Initialize(decode_fst, trans_model_); + cuda_decoder_.reset(new CudaDecoder(*cuda_fst_, config_.decoder_opts, + max_batch_size_, config_.num_channels)); + if (config_.num_decoder_copy_threads > 0) { + cuda_decoder_->SetThreadPoolAndStartCPUWorkers( + thread_pool_.get(), config_.num_decoder_copy_threads); + } + n_samples_valid_.resize(max_batch_size_); + n_input_frames_valid_.resize(max_batch_size_); + n_lattice_callbacks_not_done_.store(0); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::SetLatticeCallback( + CorrelationID corr_id, + const std::function &callback) { + auto it = corr_id2channel_.find(corr_id); + KALDI_ASSERT(it != corr_id2channel_.end()); + ChannelId ichannel = it->second; + channels_callbacks_[ichannel].reset( + new std::function(callback)); +} + +bool BatchedThreadedNnet3CudaOnlinePipeline::TryInitCorrID( + CorrelationID corr_id, int wait_for) { + bool inserted; + decltype(corr_id2channel_.end()) it; + std::tie(it, inserted) = corr_id2channel_.insert({corr_id, -1}); + int32 ichannel; + if (inserted) { + // The corr_id was not in use + std::unique_lock lk(available_channels_m_); + bool channel_available = (available_channels_.size() > 0); + if (!channel_available) { + // We cannot use that corr_id + int waited_for = 0; + while (waited_for < wait_for) { + lk.unlock(); + usleep(KALDI_CUDA_DECODER_WAIT_FOR_AVAILABLE_CHANNEL_US); + waited_for += KALDI_CUDA_DECODER_WAIT_FOR_AVAILABLE_CHANNEL_US; + lk.lock(); + channel_available = (available_channels_.size() > 0); + if (channel_available) break; + } + + // If still not available return + if (!channel_available) { + corr_id2channel_.erase(it); + return false; + } + } + + ichannel = available_channels_.back(); + available_channels_.pop_back(); + it->second = ichannel; + } else { + // This corr id was already in use but not closed + // It can happen if for instance a channel lost connection and + // did not send its last chunk Cleaning up + KALDI_WARN << "This corr_id was already in use"; + ichannel = it->second; + } + channels_callbacks_[ichannel].reset(); + + if (!config_.use_gpu_feature_extraction) { + KALDI_ASSERT(!feature_pipelines_[ichannel]); + feature_pipelines_[ichannel].reset( + new OnlineNnet2FeaturePipeline(*feature_info_)); + } + + channel_frame_offset_[ichannel] = 0; + return true; +} // namespace cuda_decoder + +void BatchedThreadedNnet3CudaOnlinePipeline::ComputeGPUFeatureExtraction( + const std::vector &channels, + const std::vector> &wave_samples, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk) { + for (int i = 0; i < wave_samples.size(); ++i) { + const SubVector &src = wave_samples[i]; + int size = src.Dim(); + n_samples_valid_[i] = size; + const BaseFloat *wave_src = src.Data(); + BaseFloat *wave_dst = h_all_waveform_.RowData(i); + std::memcpy(wave_dst, wave_src, size * sizeof(BaseFloat)); + } + // CopyFromMat syncs, avoiding it + KALDI_ASSERT(d_all_waveform_.SizeInBytes() == h_all_waveform_.SizeInBytes()); + cudaMemcpyAsync(d_all_waveform_.Data(), h_all_waveform_.Data(), + h_all_waveform_.SizeInBytes(), 
cudaMemcpyHostToDevice, + cudaStreamPerThread); + + KALDI_ASSERT(channels.size() == is_last_chunk.size()); + KALDI_ASSERT(channels.size() == is_first_chunk.size()); + + KALDI_ASSERT(gpu_feature_pipeline_); + gpu_feature_pipeline_->ComputeFeaturesBatched( + channels.size(), channels, n_samples_valid_, is_first_chunk, + is_last_chunk, model_frequency_, d_all_waveform_, &d_all_features_, + &d_all_ivectors_, &n_input_frames_valid_); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::ComputeCPUFeatureExtraction( + const std::vector &channels, + const std::vector> &wave_samples, + const std::vector &is_last_chunk) { + // Will be used by worker threads to grab work + fe_threads_channels_ = &channels; + fe_threads_wave_samples_ = &wave_samples; + + n_compute_features_not_done_.store(channels.size()); + + for (size_t i = 0; i < channels.size(); ++i) { + thread_pool_->Push( + {&BatchedThreadedNnet3CudaOnlinePipeline::ComputeOneFeatureWrapper, + this, i, 0}); // second argument "0" is not used + } + + while (n_compute_features_not_done_.load(std::memory_order_acquire)) + usleep(KALDI_CUDA_DECODER_WAIT_FOR_CPU_FEATURES_THREADS_US); + + KALDI_ASSERT(d_all_features_.NumRows() == h_all_features_.NumRows() && + d_all_features_.NumCols() == h_all_features_.NumCols()); + cudaMemcpyAsync(d_all_features_.Data(), h_all_features_.Data(), + h_all_features_.SizeInBytes(), cudaMemcpyHostToDevice, + cudaStreamPerThread); + if (use_ivectors_) { + KALDI_ASSERT(d_all_ivectors_.Dim() >= + (h_all_ivectors_.NumRows() * h_all_ivectors_.NumCols())); + cudaMemcpyAsync(d_all_ivectors_.Data(), h_all_ivectors_.Data(), + h_all_ivectors_.SizeInBytes(), cudaMemcpyHostToDevice, + cudaStreamPerThread); + } +} + +void BatchedThreadedNnet3CudaOnlinePipeline::DecodeBatch( + const std::vector &corr_ids, + const std::vector> &wave_samples, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk) { + nvtxRangePushA("DecodeBatch"); + KALDI_ASSERT(corr_ids.size() > 0); + KALDI_ASSERT(corr_ids.size() == wave_samples.size()); + KALDI_ASSERT(corr_ids.size() == is_last_chunk.size()); + + ListIChannelsInBatch(corr_ids, &channels_); + + if (config_.use_gpu_feature_extraction) + ComputeGPUFeatureExtraction(channels_, wave_samples, is_first_chunk, + is_last_chunk); + else + ComputeCPUFeatureExtraction(channels_, wave_samples, is_last_chunk); + + d_features_ptrs_.clear(); + d_ivectors_ptrs_.clear(); + for (int i = 0; i < channels_.size(); ++i) { + d_features_ptrs_.push_back(d_all_features_.Data() + + i * input_frames_per_chunk_ * + d_all_features_.Stride()); + if (use_ivectors_) { + d_ivectors_ptrs_.push_back(d_all_ivectors_.Data() + i * ivector_dim_); + } + } + int features_frame_stride = d_all_features_.Stride(); + DecodeBatch(corr_ids, d_features_ptrs_, features_frame_stride, + n_input_frames_valid_, d_ivectors_ptrs_, is_first_chunk, + is_last_chunk, &channels_); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::DecodeBatch( + const std::vector &corr_ids, + const std::vector &d_features, const int features_frame_stride, + const std::vector &n_input_frames_valid, + const std::vector &d_ivectors, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk, std::vector *channels) { + nvtxRangePushA("DecodeBatch"); + if (!channels) { + channels = &channels_; + ListIChannelsInBatch(corr_ids, channels); + } + + list_channels_first_chunk_.clear(); + for (size_t i = 0; i < is_first_chunk.size(); ++i) { + if (is_first_chunk[i]) list_channels_first_chunk_.push_back((*channels)[i]); + } + if 
(!list_channels_first_chunk_.empty()) + cuda_decoder_->InitDecoding(list_channels_first_chunk_); + + RunNnet3(*channels, d_features, features_frame_stride, n_input_frames_valid, + is_first_chunk, is_last_chunk, d_ivectors); + RunDecoder(*channels); + + BuildLatticesAndRunCallbacks(corr_ids, *channels, is_last_chunk); + nvtxRangePop(); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::ComputeOneFeature(int element) { + const SubVector &wave_samples = + (*fe_threads_wave_samples_)[element]; + const int ichannel = (*fe_threads_channels_)[element]; + OnlineNnet2FeaturePipeline &feature_pipeline = *feature_pipelines_[ichannel]; + // KALDI_ASSERT("Mismatch sample frequency/model frequency" && + // (model_frequency_ == + // utt_chunk.sample_frequency_)); + KALDI_ASSERT( + "Too many samples for one chunk. Must be <= " + "this.GetNSampsPerChunk()" && + wave_samples.Dim() <= samples_per_chunk_); + int32 start_iframe = feature_pipeline.NumFramesReady(); + feature_pipeline.AcceptWaveform(model_frequency_, wave_samples); + + // All frames should be ready here + int32 end_iframe = feature_pipeline.NumFramesReady(); + int32 nframes = end_iframe - start_iframe; + if (nframes > 0) { + SubMatrix utt_features = + h_all_features_.RowRange(element * input_frames_per_chunk_, nframes); + std::vector frames(nframes); + for (int j = start_iframe; j < end_iframe; ++j) + frames[j - start_iframe] = j; + // + // Copy Features + feature_pipeline.InputFeature()->GetFrames(frames, &utt_features); + + // If available, copy ivectors + if (use_ivectors_) { + SubVector utt_ivector = h_all_ivectors_.Row(element); + feature_pipeline.IvectorFeature()->GetFrame(end_iframe - 1, &utt_ivector); + } + } + n_input_frames_valid_[element] = nframes; + + n_compute_features_not_done_.fetch_sub(1, std::memory_order_release); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::BuildLatticesAndRunCallbacks( + const std::vector &corr_ids, + const std::vector &channels, const std::vector &is_last_chunk) { + list_channels_last_chunk_.clear(); + list_corr_id_last_chunk_.clear(); + for (int i = 0; i < is_last_chunk.size(); ++i) { + if (is_last_chunk[i]) { + list_channels_last_chunk_.push_back(channels[i]); + list_corr_id_last_chunk_.push_back(corr_ids[i]); + } + } + cuda_decoder_->PrepareForGetRawLattice(list_channels_last_chunk_, true); + // Storing number of callbacks not done. 
Used if + // WaitForLatticeCallbacks() is called + n_lattice_callbacks_not_done_.fetch_add(list_channels_last_chunk_.size(), + std::memory_order_acquire); + + // delete data used for decoding that corr_id + for (int32 i = 0; i < list_channels_last_chunk_.size(); ++i) { + uint64_t ichannel = list_channels_last_chunk_[i]; + CorrelationID corr_id = list_corr_id_last_chunk_[i]; + int32 ndeleted = corr_id2channel_.erase(corr_id); + KALDI_ASSERT(ndeleted == 1); + thread_pool_->Push( + {&BatchedThreadedNnet3CudaOnlinePipeline::FinalizeDecodingWrapper, this, + ichannel, corr_id}); + if (!config_.use_gpu_feature_extraction) { + // Done with this CPU FE pipeline + KALDI_ASSERT(feature_pipelines_[ichannel]); + feature_pipelines_[ichannel].reset(); + } + } + list_channels_last_chunk_.clear(); + list_corr_id_last_chunk_.clear(); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::ListIChannelsInBatch( + const std::vector &corr_ids, std::vector *channels) { + channels->clear(); + list_channels_last_chunk_.clear(); + list_corr_id_last_chunk_.clear(); + for (int i = 0; i < corr_ids.size(); ++i) { + int corr_id = corr_ids[i]; + auto it = corr_id2channel_.find(corr_id); + KALDI_ASSERT(it != corr_id2channel_.end()); + int ichannel = it->second; + channels->push_back(ichannel); + } +} + +void BatchedThreadedNnet3CudaOnlinePipeline::RunNnet3( + const std::vector &channels, + const std::vector &d_features, const int features_stride, + const std::vector &n_input_frames_valid, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk, + const std::vector &d_ivectors) { + cuda_nnet3_->RunBatch(channels, d_features, features_stride, d_ivectors, + n_input_frames_valid, is_first_chunk, is_last_chunk, + &d_all_log_posteriors_, &all_frames_log_posteriors_); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::RunDecoder( + const std::vector &channels) { + for (int iframe = 0; iframe < all_frames_log_posteriors_.size(); ++iframe) { + cuda_decoder_->AdvanceDecoding(all_frames_log_posteriors_[iframe]); + } +} + +void BatchedThreadedNnet3CudaOnlinePipeline::ReadParametersFromModel() { + feature_info_.reset(new OnlineNnet2FeaturePipelineInfo(config_.feature_opts)); + feature_info_->ivector_extractor_info.use_most_recent_ivector = true; + feature_info_->ivector_extractor_info.greedy_ivector_extractor = true; + + OnlineNnet2FeaturePipeline feature(*feature_info_); + use_ivectors_ = (feature.IvectorFeature() != NULL); + input_dim_ = feature.InputFeature()->Dim(); + if (use_ivectors_) ivector_dim_ = feature.IvectorFeature()->Dim(); + model_frequency_ = feature_info_->GetSamplingFrequency(); + BaseFloat frame_shift = feature_info_->FrameShiftInSeconds(); + input_frames_per_chunk_ = config_.compute_opts.frames_per_chunk; + seconds_per_chunk_ = input_frames_per_chunk_ * frame_shift; + int32 samp_per_frame = static_cast(model_frequency_ * frame_shift); + samples_per_chunk_ = input_frames_per_chunk_ * samp_per_frame; + BatchedStaticNnet3Config nnet3_config; + nnet3_config.compute_opts = config_.compute_opts; + nnet3_config.max_batch_size = max_batch_size_; + nnet3_config.nchannels = config_.num_channels; + nnet3_config.has_ivector = (feature.IvectorFeature() != NULL); + + cuda_nnet3_.reset(new BatchedStaticNnet3(nnet3_config, *am_nnet_)); + output_frames_per_chunk_ = cuda_nnet3_->GetNOutputFramesPerChunk(); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::FinalizeDecoding( + int32 ichannel, CorrelationID corr_id) { + Lattice lat; + cuda_decoder_->ConcurrentGetRawLatticeSingleChannel(ichannel, &lat); + + // Getting the 
channel callback now, we're going to free that channel + std::unique_ptr> callback; + callback = std::move(channels_callbacks_[ichannel]); + // Done with this channel. Making it available again + { + std::lock_guard lk(available_channels_m_); + available_channels_.push_back(ichannel); + } + + // If necessary, determinize the lattice + CompactLattice dlat; + if (config_.determinize_lattice) { + DeterminizeLatticePhonePrunedWrapper(*trans_model_, &lat, + config_.decoder_opts.lattice_beam, + &dlat, config_.det_opts); + } else { + ConvertLattice(lat, &dlat); + } + + if (dlat.NumStates() > 0) { + if (word_syms_) { + CompactLattice best_path_clat; + CompactLatticeShortestPath(dlat, &best_path_clat); + + Lattice best_path_lat; + ConvertLattice(best_path_clat, &best_path_lat); + + std::vector alignment; + std::vector words; + LatticeWeight weight; + GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); + std::ostringstream oss; + for (size_t i = 0; i < words.size(); i++) { + std::string s = word_syms_->Find(words[i]); + if (s == "") oss << "Word-id " << words[i] << " not in symbol table."; + oss << s << " "; + } + { + std::lock_guard lk(stdout_m_); + KALDI_LOG << "OUTPUT: " << oss.str(); + } + } + } + + // if ptr set and if callback func callable + if (callback && *callback) { + (*callback)(dlat); + } + + n_lattice_callbacks_not_done_.fetch_sub(1, std::memory_order_release); +} + +void BatchedThreadedNnet3CudaOnlinePipeline::WaitForLatticeCallbacks() { + while (n_lattice_callbacks_not_done_.load() != 0) + usleep(KALDI_CUDA_DECODER_WAIT_FOR_CALLBACKS_US); +} + +} // namespace cuda_decoder +} // namespace kaldi + +#endif // HAVE_CUDA diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h new file mode 100644 index 00000000000..ccb91cb2fc9 --- /dev/null +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h @@ -0,0 +1,367 @@ +// cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
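+// Identifiers: a correlation id (corr_id) is the caller-chosen 64-bit id of
+// one audio stream; internally each active corr_id is mapped to a decoder
+// channel in [0, num_channels). The mapping lives for the duration of the
+// utterance, and the channel is recycled after its last chunk has been
+// processed.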
+ +#if HAVE_CUDA == 1 + +#ifndef KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_ONLINE_PIPELINE_H_ +#define KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_ONLINE_PIPELINE_H_ + +#define KALDI_CUDA_DECODER_MIN_NCHANNELS_FACTOR 2 + +#include +#include + +#include "base/kaldi-utils.h" +#include "cudadecoder/batched-static-nnet3.h" +#include "cudadecoder/cuda-decoder.h" +#include "cudadecoder/thread-pool-light.h" +#include "cudafeat/online-batched-feature-pipeline-cuda.h" +#include "feat/wave-reader.h" +#include "lat/determinize-lattice-pruned.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-am-decodable-simple.h" +#include "nnet3/nnet-compute.h" +#include "nnet3/nnet-optimize.h" +#include "online2/online-nnet2-feature-pipeline.h" + +namespace kaldi { +namespace cuda_decoder { + +// +// Online Streaming Batched Pipeline calling feature extraction, CUDA light +// Nnet3 driver and CUDA decoder. Can handle up to num_channels streaming audio +// channels in parallel. Each channel is externally identified by a correlation +// id (corr_id). Receives chunks of audio (up to max_batch_size per DecodeBatch +// call). Will call a callback with the final lattice once the processing of the +// final chunk is done. +// +// For an example on how to use that pipeline, see +// cudadecoderbin/batched-threaded-wav-nnet3-online.cc +// +// Feature extraction can be CUDA or CPU +// (multithreaded). +// Internally reuses the concept of channels and lanes from the CUDA decoder +// + +struct BatchedThreadedNnet3CudaOnlinePipelineConfig { + BatchedThreadedNnet3CudaOnlinePipelineConfig() + : max_batch_size(400), + num_channels(600), + num_worker_threads(-1), + determinize_lattice(true), + num_decoder_copy_threads(2), + use_gpu_feature_extraction(true) {} + void Register(OptionsItf *po) { + po->Register("max-batch-size", &max_batch_size, + "The maximum execution batch size. " + "Larger = Better throughput slower latency."); + po->Register("num-channels", &num_channels, + "The number of parallel audio channels. This is the maximum " + "number of parallel audio channels supported by the pipeline" + ". This should be larger " + "than max_batch_size."); + po->Register("cuda-worker-threads", &num_worker_threads, + "(optional) The total number of CPU threads launched to " + "process CPU tasks. -1 = use std::hardware_concurrency()"); + po->Register("determinize-lattice", &determinize_lattice, + "Determinize the lattice before output."); + po->Register("cuda-decoder-copy-threads", &num_decoder_copy_threads, + "Advanced - Number of worker threads used in the " + "decoder for " + "the host to host copies."); + po->Register("gpu-feature-extract", &use_gpu_feature_extraction, + "Use GPU feature extraction"); + + feature_opts.Register(po); + decoder_opts.Register(po); + det_opts.Register(po); + compute_opts.Register(po); + } + int max_batch_size; + int num_channels; + int num_worker_threads; + bool determinize_lattice; + int num_decoder_copy_threads; + bool use_gpu_feature_extraction; + + OnlineNnet2FeaturePipelineConfig feature_opts; + CudaDecoderConfig decoder_opts; + fst::DeterminizeLatticePhonePrunedOptions det_opts; + nnet3::NnetSimpleComputationOptions compute_opts; + + void CheckAndFixConfigs() { + KALDI_ASSERT(max_batch_size > 0); + // Lower bound on nchannels. + // Using strictly more than max_batch_size because channels are still used + // when the lattice postprocessing is running. 
We still want to run full + // max_batch_size batches in the meantime + int min_nchannels = + max_batch_size * KALDI_CUDA_DECODER_MIN_NCHANNELS_FACTOR; + num_channels = std::max(num_channels, min_nchannels); + + // If not set use number of physical threads + num_worker_threads = (num_worker_threads > 0) + ? num_worker_threads + : std::thread::hardware_concurrency(); + } +}; + +class BatchedThreadedNnet3CudaOnlinePipeline { + public: + using CorrelationID = uint64_t; + BatchedThreadedNnet3CudaOnlinePipeline( + const BatchedThreadedNnet3CudaOnlinePipelineConfig &config, + const fst::Fst &decode_fst, + const nnet3::AmNnetSimple &am_nnet, const TransitionModel &trans_model) + : config_(config), + max_batch_size_(config.max_batch_size), + trans_model_(&trans_model), + am_nnet_(&am_nnet), + word_syms_(NULL) { + config_.compute_opts.CheckAndFixConfigs(am_nnet_->GetNnet().Modulus()); + config_.CheckAndFixConfigs(); + int num_worker_threads = config_.num_worker_threads; + thread_pool_.reset(new ThreadPoolLight(num_worker_threads)); + + Initialize(decode_fst); + } + + // Called when a new utterance will be decoded w/ correlation id corr_id + // When this utterance will be done (when it will receive a chunk with + // last_chunk=true) + // If no channels are available, will wait for "wait_for" microseconds + // Returns true if a channel was available (eventually after waiting for + // up to wait_for seconds) + bool TryInitCorrID(CorrelationID corr_id, int wait_for = 0); + + // Set the callback function to call with the final lattice for a given + // corr_id + void SetLatticeCallback( + CorrelationID corr_id, + const std::function &callback); + + // Chunk of one utterance. We receive batches of those chunks through + // DecodeBatch + // Contains pointers to that chunk, the corresponding correlation ID, + // and whether that chunk is the last one for that utterance + struct UtteranceChunk { + CorrelationID corr_id; + SubVector wave_samples; + bool last_chunk; // sets to true if last chunk for that + // utterance + }; + + // Receive a batch of chunks. Will decode them, then return. + // If it contains some last chunks for given utterances, it will call + // FinalizeDecoding (building the final lattice, determinize it, etc.) + // asynchronously. The callback for that utterance will then be called + void DecodeBatch(const std::vector &corr_ids, + const std::vector> &wave_samples, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk); + + // Version providing directly the features. 
Only runs nnet3 & decoder + // Used when we want to provide the final ivectors (offline case) + // channels can be provided if they are known (internal use) + void DecodeBatch(const std::vector &corr_ids, + const std::vector &d_features, + const int features_frame_stride, + const std::vector &n_input_frames_valid, + const std::vector &d_ivectors, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk, + std::vector *channels = NULL); + + void ComputeGPUFeatureExtraction( + const std::vector &channels, + const std::vector> &wave_samples, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk); + + void ComputeCPUFeatureExtraction( + const std::vector &channels, + const std::vector> &wave_samples, + const std::vector &is_last_chunk); + + // Maximum number of samples per chunk + int32 GetNSampsPerChunk() { return samples_per_chunk_; } + int32 GetNInputFramesPerChunk() { return input_frames_per_chunk_; } + float GetModelFrequency() { return model_frequency_; } + int GetTotalNnet3RightContext() { + return cuda_nnet3_->GetTotalNnet3RightContext(); + } + // Maximum number of seconds per chunk + BaseFloat GetSecondsPerChunk() { return seconds_per_chunk_; } + + // Used when debugging. Used to Print the text when a decoding is done + void SetSymbolTable(fst::SymbolTable *word_syms) { word_syms_ = word_syms; } + + // Wait for all lattice callbacks to complete + // Can be called after DecodeBatch + void WaitForLatticeCallbacks(); + + private: + // Initiliaze this object + void Initialize(const fst::Fst &decode_fst); + + // Allocate and initialize data that will be used for computation + void AllocateAndInitializeData(const fst::Fst &decode_fst); + + // Reads what's needed from models, such as left and right context + void ReadParametersFromModel(); + + // Following functions are DecodeBatch's helpers + + // Filling curr_batch_ichannels_ + void ListIChannelsInBatch(const std::vector &corr_ids, + std::vector *channels); + void CPUFeatureExtraction( + const std::vector &channels, + const std::vector> &wave_samples); + + // Compute features and ivectors for the chunk + // curr_batch[element] + // CPU function + void ComputeOneFeature(int element); + static void ComputeOneFeatureWrapper(void *obj, uint64_t element, + uint64_t ignored) { + static_cast(obj) + ->ComputeOneFeature(element); + } + void RunNnet3(const std::vector &channels, + const std::vector &d_features, + const int feature_stride, + const std::vector &n_input_frames_valid, + const std::vector &is_first_chunk, + const std::vector &is_last_chunk, + const std::vector &d_ivectors); + + void RunDecoder(const std::vector &channels); + + void BuildLatticesAndRunCallbacks(const std::vector &corr_ids, + const std::vector &channels, + const std::vector &is_last_chunk); + + // If an utterance is done, we call FinalizeDecoding async on + // the threadpool + // it will call the utterance's callback when done + void FinalizeDecoding(int32 ichannel, CorrelationID corr_id); + // static wrapper for thread pool + static void FinalizeDecodingWrapper(void *obj, uint64_t ichannel64, + uint64_t corr_id) { + int32 ichannel = static_cast(ichannel64); + static_cast(obj) + ->FinalizeDecoding(ichannel, corr_id); + } + // Data members + + BatchedThreadedNnet3CudaOnlinePipelineConfig config_; + int32 max_batch_size_; // extracted from config_ + // Models + const TransitionModel *trans_model_; + const nnet3::AmNnetSimple *am_nnet_; + std::unique_ptr feature_info_; + + // Decoder channels currently available, w/ mutex + std::vector 
available_channels_; + std::mutex available_channels_m_; + + // corr_id -> decoder channel map + std::unordered_map corr_id2channel_; + + // channels -> callbacks + // the callback is called once the final lattice is ready + std::vector>> + channels_callbacks_; + + // New channels in the current batch. We've just received + // their first batch + std::vector list_channels_first_chunk_; + + std::vector n_samples_valid_, n_input_frames_valid_; + + std::vector>> + all_frames_log_posteriors_; + + // Channels done after current batch. We've just received + // their last chunk + std::vector list_channels_last_chunk_; + std::vector list_corr_id_last_chunk_; + + // Number of frames already computed in channel (before + // curr_batch_) + std::vector channel_frame_offset_; + + // Parameters extracted from the models + int input_frames_per_chunk_; + int output_frames_per_chunk_; + BaseFloat seconds_per_chunk_; + BaseFloat samples_per_chunk_; + BaseFloat model_frequency_; + int32 ivector_dim_, input_dim_; + + // Buffers used during computation + Matrix h_all_features_; + Matrix h_all_waveform_; + CuMatrix d_all_waveform_; + CuMatrix d_all_features_; + Matrix h_all_ivectors_; + CuVector d_all_ivectors_; // gpu pipeline uses a meta vector + CuMatrix d_all_log_posteriors_; + + bool use_ivectors_; + // Used with CPU features extraction. Contains the number of CPU FE tasks + // still running + std::atomic n_compute_features_not_done_; + // Number of CPU lattice postprocessing tasks still running + std::atomic n_lattice_callbacks_not_done_; + + // Current assignement buffers, when DecodeBatch is running + std::vector channels_; + std::vector d_features_ptrs_; + std::vector d_ivectors_ptrs_; + + // Used by CPU FE threads. Could be merged with channels_ + const std::vector *fe_threads_channels_; + const std::vector> *fe_threads_wave_samples_; + + std::unique_ptr gpu_feature_pipeline_; + std::unique_ptr cuda_nnet3_; + + // Feature pipelines, associated to a channel + // Only used if feature extraction is run on the CPU + std::vector> feature_pipelines_; + + // HCLG graph : CudaFst object is a host object, but contains + // data stored in + // GPU memory + std::shared_ptr cuda_fst_; + std::unique_ptr cuda_decoder_; + + std::unique_ptr thread_pool_; + + // Used for debugging + fst::SymbolTable *word_syms_; + // Used when printing to stdout for debugging purposes + std::mutex stdout_m_; +}; + +} // end namespace cuda_decoder +} // end namespace kaldi. + +#endif // KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_ONLINE_PIPELINE_H_ +#endif // HAVE_CUDA diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.cc index 0ca64ccc275..87602f0920c 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.cc +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.cc @@ -23,12 +23,18 @@ #include #include "base/kaldi-utils.h" +// This pipeline is deprecated and will be removed. Please switch to +// batched-threaded-nnet3-cuda-pipeline2 + namespace kaldi { namespace cuda_decoder { void BatchedThreadedNnet3CudaPipeline::Initialize( const fst::Fst &decode_fst, const nnet3::AmNnetSimple &am_nnet, const TransitionModel &trans_model) { + KALDI_LOG << "\n\nIMPORTANT: This pipeline is deprecated. 
Please switch to " + "cudadecoderbin/batch-wav-nnet3-cuda2 (binary) or " + "cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h (class)\n"; KALDI_LOG << "BatchedThreadedNnet3CudaPipeline Initialize with " << config_.num_control_threads << " control threads, " << config_.num_worker_threads << " worker threads" @@ -47,16 +53,19 @@ void BatchedThreadedNnet3CudaPipeline::Initialize( // create work queue, padding so that we can better detect if this // overflows. this should not happen and is just there as a sanity check - pending_task_queue_ = new TaskState *[config_.max_pending_tasks + - config_.pending_queue_padding]; + pending_task_queue_ = new TaskState + *[config_.max_pending_tasks + config_.pending_queue_padding]; tasks_front_ = 0; tasks_back_ = 0; - // ensure all allocations/kernels above are complete before launching threads - // in different streams. + // ensure all allocations/kernels above are complete before launching + // threads in different streams. cudaStreamSynchronize(cudaStreamPerThread); // Create threadpool for CPU work + // Using the thread pool light for decoder + // This pipeline is deprecated. Using two thread pools is not ideal but + // this pipeline will be removed eventually work_pool_ = new ThreadPool(config_.num_worker_threads); exit_ = false; @@ -68,7 +77,8 @@ void BatchedThreadedNnet3CudaPipeline::Initialize( std::thread(&BatchedThreadedNnet3CudaPipeline::ExecuteWorker, this, i); } - // wait for threads to start to ensure allocation time isn't in the timings + // wait for threads to start to ensure allocation time isn't in the + // timings while (numStarted_ < config_.num_control_threads) kaldi::Sleep(SLEEP_BACKOFF_S); } @@ -148,9 +158,8 @@ void BatchedThreadedNnet3CudaPipeline::WaitForGroup(const std::string &group) { std::unique_lock lk(group_tasks_mutex_); group_done_cv_.wait( lk, [this, &group] { return group_tasks_not_done_[group] == 0; }); - // Safe to delete entry from the map now. If the user creates new task in that - // group, - // the entry will be created once more + // Safe to delete entry from the map now. If the user creates new task + // in that group, the entry will be created once more group_tasks_not_done_.erase(group); } @@ -194,7 +203,7 @@ void BatchedThreadedNnet3CudaPipeline::CloseAllDecodeHandlesForGroup( std::lock_guard lk1(tasks_lookup_mutex_); auto p = tasks_group_lookup_.equal_range(group); for (auto it = p.first; it != p.second; ++it) { - KALDI_ASSERT(it->second->finished==true); + KALDI_ASSERT(it->second->finished == true); tasks_lookup_.erase(it->second->key); } tasks_group_lookup_.erase(p.first, p.second); @@ -297,8 +306,8 @@ bool BatchedThreadedNnet3CudaPipeline::GetRawLattice(const std::string &key, // intervention from the master thread. while (task->finished == false) kaldi::Sleep(SLEEP_BACKOFF_S); - // GetRawLattice on a determinized lattice is not supported (Per email from - // DanP) + // GetRawLattice on a determinized lattice is not supported (Per email + // from DanP) KALDI_ASSERT(task->determinized == false); if (task->error) { @@ -331,9 +340,8 @@ bool BatchedThreadedNnet3CudaPipeline::GetLattice(const std::string &key, return false; } - // if user has not requested a determinized lattice from the decoder then we - // must - // determinize it here since it was done done already. + // if user has not requested a determinized lattice from the decoder + // then we must determinize it here since it was done done already. 
if (!config_.determinize_lattice && !task->determinized) { // Determinzation was not done by worker threads so do it here DeterminizeOneLattice(task); @@ -349,8 +357,8 @@ void BatchedThreadedNnet3CudaPipeline::AddTaskToPendingTaskQueue( TaskState *task) { std::lock_guard lk(tasks_add_mutex_); if (NumPendingTasks() == config_.max_pending_tasks) { - // task queue is full launch a new thread to add this task and exit to make - // room for other work + // task queue is full launch a new thread to add this task and + // exit to make room for other work work_pool_->enqueue( THREAD_POOL_LOW_PRIORITY, &BatchedThreadedNnet3CudaPipeline::AddTaskToPendingTaskQueue, this, @@ -360,8 +368,8 @@ void BatchedThreadedNnet3CudaPipeline::AddTaskToPendingTaskQueue( // insert into pending task queue pending_task_queue_[tasks_back_] = task; // (int)tasks_back_); - tasks_back_ = (tasks_back_ + 1) % (config_.max_pending_tasks + - config_.pending_queue_padding); + tasks_back_ = (tasks_back_ + 1) % + (config_.max_pending_tasks + config_.pending_queue_padding); KALDI_ASSERT(NumPendingTasks() <= config_.max_pending_tasks); } } @@ -393,8 +401,8 @@ void BatchedThreadedNnet3CudaPipeline::AquireAdditionalTasks( // pending_task_queue_[tasks_front_]); KALDI_ASSERT(NumPendingTasks() > 0); tasks.push_back(pending_task_queue_[tasks_front_]); - tasks_front_ = (tasks_front_ + 1) % (config_.max_pending_tasks - + config_.pending_queue_padding); + tasks_front_ = (tasks_front_ + 1) % (config_.max_pending_tasks + + config_.pending_queue_padding); } } } @@ -410,8 +418,8 @@ void BatchedThreadedNnet3CudaPipeline::AquireAdditionalTasks( ChannelId channel; { std::lock_guard lk(channel_state.free_channels_mutex); - KALDI_ASSERT(free_channels.size() > - 0); // it should always be true (cf std::min above) + KALDI_ASSERT(free_channels.size() > 0); // it should always be true + // (cf std::min above) channel = free_channels.back(); free_channels.pop_back(); } @@ -478,8 +486,8 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchNnet( } } - // process all minibatches, we allow partial minibatches but this should only - // occur on the last iteration + // process all minibatches, we allow partial minibatches but this should + // only occur on the last iteration bool allow_partial_minibatch = true; while (computer.Compute(allow_partial_minibatch)) ; @@ -491,12 +499,13 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchNnet( CuMatrix &posteriors = task_data->posteriors; MergeTaskOutput(nnet_tasks[i], &posteriors); - // nnet output is no longer necessary as we have copied the output out + // nnet output is no longer necessary as we have copied the + // output out nnet_tasks[i].resize(0); // features are no longer needed so free memory here task_data->ivector_features.Resize(0); - task_data->input_features.Resize(0,0); + task_data->input_features.Resize(0, 0); } nvtxRangePop(); @@ -555,10 +564,11 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchFeatures( OnlineCudaFeaturePipeline &feature_pipeline) { KALDI_ASSERT(config_.gpu_feature_extract == true); nvtxRangePushA("CopyBatchWaves"); - // below we will pack waves into a single buffer for efficient transfer across - // device + // below we will pack waves into a single buffer for efficient transfer + // across device - // first count the total number of elements and create a single large vector + // first count the total number of elements and create a single large + // vector int count = 0; for (int i = first; i < tasks.size(); i++) { count += tasks[i]->task_data->wave_samples->Dim(); @@ 
-570,9 +580,9 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchFeatures( thread_local Vector pinned_vector; if (pinned_vector.Dim() < count) { - // WAR: Not pinning memory because it seems to impact correctness - // we are continuing to look into a fix but want to commit this workaround - // as a temporary measure. + // WAR: Not pinning memory because it seems to impact + // correctness we are continuing to look into a fix but want to + // commit this workaround as a temporary measure. if (pinned_vector.Dim() != 0) { cudaHostUnregister(pinned_vector.Data()); } @@ -580,11 +590,11 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchFeatures( // allocated array 2x size pinned_vector.Resize(count * 2, kUndefined); cudaHostRegister(pinned_vector.Data(), - pinned_vector.Dim() * sizeof(BaseFloat), 0); + pinned_vector.Dim() * sizeof(BaseFloat), 0); } - // We will launch a thread for each task in order to get better host memory - // bandwidth + // We will launch a thread for each task in order to get better host + // memory bandwidth std::vector> futures; // for syncing // vector copy function for threading below. @@ -611,11 +621,10 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchFeatures( } CuVector cu_waves(count, kUndefined); - // copy memory down asynchronously. Vector copy functions are synchronous so - // we do it manually. - // It is important for this to happen asynchrously to help hide launch latency - // of smaller kernels - // that come in the future. + // copy memory down asynchronously. Vector copy functions are + // synchronous so we do it manually. It is important for this to happen + // asynchrously to help hide launch latency of smaller kernels that come + // in the future. cudaMemcpyAsync(cu_waves.Data(), pinned_vector.Data(), cu_waves.Dim() * sizeof(BaseFloat), cudaMemcpyHostToDevice, cudaStreamPerThread); @@ -638,7 +647,8 @@ void BatchedThreadedNnet3CudaPipeline::ComputeBatchFeatures( int32 numFrames = task_data->input_features.NumRows(); if (numFrames == 0) { - // Make this a warning for now. Need to check how this is handled + // Make this a warning for now. Need to check how this + // is handled KALDI_WARN << "Warning empty audio file"; } } @@ -684,7 +694,8 @@ void BatchedThreadedNnet3CudaPipeline::RemoveCompletedChannels( // add channel to free and completed queues completed_channels.push_back(channel); - // this was assigned earlier just making sure it is still consistent + // this was assigned earlier just making sure it is + // still consistent KALDI_ASSERT(tasks[cur]->ichannel == channel); // Rearrange queues, @@ -724,24 +735,25 @@ void BatchedThreadedNnet3CudaPipeline::PostDecodeProcessing( cuda_decoder.PrepareForGetRawLattice(completed_channels, true); // clean up datastructures for completed tasks for (int i = channels.size(); i < tasks.size(); i++) { - tasks[i]->task_data->posteriors.Resize(0,0); + tasks[i]->task_data->posteriors.Resize(0, 0); delete decodables[i]; } // Calling GetRawLattice + Determinize (optional) on a CPU worker thread for (int i = channels.size(); i < tasks.size(); i++) { - // checking that this channel is actually in the completed channels list - // order is reversed because we used push_back into completed_channel list + // checking that this channel is actually in the completed + // channels list order is reversed because we used push_back + // into completed_channel list KALDI_ASSERT(tasks[i]->ichannel == completed_channels[channels.size() + completed_channels.size() - i - 1]); - // enqueue task completion on a worker thread. 
We do not need to wait - // for sychronization on this thread as the parameters passed to this - // thread are persistent and that thread will return resources to the - // system when they free up. + // enqueue task completion on a worker thread. We do not need + // to wait for sychronization on this thread as the parameters + // passed to this thread are persistent and that thread will + // return resources to the system when they free up. work_pool_->enqueue(THREAD_POOL_NORMAL_PRIORITY, - &BatchedThreadedNnet3CudaPipeline::CompleteTask, - this, &cuda_decoder, &channel_state, tasks[i]); + &BatchedThreadedNnet3CudaPipeline::CompleteTask, this, + &cuda_decoder, &channel_state, tasks[i]); } tasks.resize(channels.size()); @@ -752,11 +764,12 @@ void BatchedThreadedNnet3CudaPipeline::PostDecodeProcessing( void BatchedThreadedNnet3CudaPipeline::CompleteTask(CudaDecoder *cuda_decoder, ChannelState *channel_state, TaskState *task) { - // Calling GetRawLattice for that channel. PrepareForGetRawLattice was already - // called + // Calling GetRawLattice for that channel. PrepareForGetRawLattice was + // already called cuda_decoder->ConcurrentGetRawLatticeSingleChannel(task->ichannel, &task->lat); - // We are done using that channel. Putting it back into the free channels + // We are done using that channel. Putting it back into the free + // channels { std::lock_guard lk(channel_state->free_channels_mutex); channel_state->free_channels.push_back(task->ichannel); @@ -778,7 +791,8 @@ void BatchedThreadedNnet3CudaPipeline::CompleteTask(CudaDecoder *cuda_decoder, std::lock_guard lk(group_tasks_mutex_); --all_group_tasks_not_done_; int32 left_in_group = --group_tasks_not_done_[task->group]; - // std::cout << "left in group " << task->group << " " << left_in_group + // std::cout << "left in group " << task->group << " " << + // left_in_group // << std::endl; if (left_in_group == 0) group_done_cv_.notify_all(); } @@ -800,12 +814,10 @@ void BatchedThreadedNnet3CudaPipeline::ExecuteWorker(int threadId) { KALDI_LOG << "CudaDecoder batch_size=" << config_.max_batch_size << " num_channels=" << config_.num_channels; - // Data structures that are reusable across decodes but unique to each thread + // Data structures that are reusable across decodes but unique to each + // thread CudaDecoder cuda_decoder(cuda_fst_, config_.decoder_opts, config_.max_batch_size, config_.num_channels); - if (config_.num_decoder_copy_threads > 0) - cuda_decoder.SetThreadPoolAndStartCPUWorkers( - work_pool_, config_.num_decoder_copy_threads); nnet3::NnetBatchComputer computer(config_.compute_opts, am_nnet_->GetNnet(), am_nnet_->Priors()); @@ -834,28 +846,26 @@ void BatchedThreadedNnet3CudaPipeline::ExecuteWorker(int threadId) { numStarted_++; // Tell master I have started - // main control loop. At each iteration a thread will see if it has been - // asked to shut - // down. If it has it will exit. This loop condition will only be processed - // if all - // other work assigned to this thread has been processed. + // main control loop. At each iteration a thread will see if it has + // been asked to shut down. If it has it will exit. This loop + // condition will only be processed if all other work assigned to this + // thread has been processed. while (!exit_) { - // main processing loop. At each iteration the thread will do the - // following: - // 1) Attempt to grab more work. 
- // 2) Initialize any new work - // do - // 3) Process work in a batch - // while(free lanes < drain_count) - // 4) Postprocess any completed work + // main processing loop. At each iteration the thread will do + // the following: 1) Attempt to grab more work. 2) Initialize + // any new work do 3) Process work in a batch while(free lanes < + // drain_count) 4) Postprocess any completed work do { // 1) attempt to fill the batch - if (tasks_front_ != tasks_back_) { // if work is available grab more work + if (tasks_front_ != tasks_back_) { // if work is available grab more + // work - int start = tasks.size(); // Save the current assigned tasks size + int start = tasks.size(); // Save the current assigned + // tasks size AquireAdditionalTasks(cuda_decoder, channel_state, tasks); - // New tasks are now in the in tasks[start,tasks.size()) + // New tasks are now in the in + // tasks[start,tasks.size()) if (start != tasks.size()) { // if there are new tasks if (config_.gpu_feature_extract) ComputeBatchFeatures(start, tasks, feature_pipeline); @@ -865,9 +875,10 @@ void BatchedThreadedNnet3CudaPipeline::ExecuteWorker(int threadId) { } // end if (tasks_front_!=tasks_back_) // check if there is no active work on this thread. - // This can happen if another thread was assigned the work. + // This can happen if another thread was assigned the + // work. if (tasks.size() == 0) { - // Thread is spinning waiting for work. Backoff. + // Thread is spinning waiting for work. Backoff. kaldi::Sleep(SLEEP_BACKOFF_S); break; } @@ -875,45 +886,55 @@ void BatchedThreadedNnet3CudaPipeline::ExecuteWorker(int threadId) { // try/catch to catch and report errors inside decoder. // errors can be recoverable or non-recoverable // unrecoverable errors will assert - // recoverable errors will cancel the batch (output empty lattice) - // and print a warning. - // There should be no errors and this is just a sanity check + // recoverable errors will cancel the batch (output + // empty lattice) and print a warning. There should be + // no errors and this is just a sanity check try { - // This is in a loop in case we want to drain the batch a little. - // Draining the batch will cause initialization tasks to be batched. + // This is in a loop in case we want to drain + // the batch a little. Draining the batch will + // cause initialization tasks to be batched. do { - // 3) Process outstanding work in a batch - // Advance decoding on all open channels + // 3) Process outstanding work in a + // batch Advance decoding on all open + // channels cuda_decoder.AdvanceDecoding(channel_state.channels, decodables); - // Adjust channel state for all completed decodes + // Adjust channel state for all + // completed decodes RemoveCompletedChannels(cuda_decoder, channel_state, decodables, tasks); - // do loop repeates until we meet drain size or run out of work + // do loop repeates until we meet drain + // size or run out of work } while (config_.max_batch_size - channel_state.channels.size() < config_.batch_drain_size && channel_state.channels.size() > 0); - // 4) Post process work. This reorders completed work to the end, - // copies results outs, and cleans up data structures + // 4) Post process work. This reorders + // completed work to the end, copies results + // outs, and cleans up data structures PostDecodeProcessing(cuda_decoder, channel_state, decodables, tasks); } catch (CudaDecoderException e) { - // Code to catch errors. 
Most errors are unrecoverable but a user can - // mark them - // recoverable which will cancel the entire batch but keep processing. + // Code to catch errors. Most errors are + // unrecoverable but a user can mark them + // recoverable which will cancel the entire + // batch but keep processing. if (!e.recoverable) { bool UNRECOVERABLE_EXCEPTION = false; - KALDI_LOG << "Error unrecoverable cuda decoder error '" << e.what() - << "'\n"; + KALDI_LOG << "Error unrecoverable cuda " + "decoder error '" + << e.what() << "'\n"; KALDI_ASSERT(UNRECOVERABLE_EXCEPTION); } else { - KALDI_LOG << "Error recoverable cuda decoder error '" << e.what() - << "'\n"; - KALDI_LOG << " Aborting batch for recovery. Canceling the " + KALDI_LOG << "Error recoverable cuda " + "decoder error '" + << e.what() << "'\n"; + KALDI_LOG << " Aborting batch for " + "recovery. Canceling the " "following decodes:\n"; // Cancel all outstanding tasks for (int i = 0; i < tasks.size(); i++) { - // move all channels to free channel queue + // move all channels to free + // channel queue ChannelId channel = channel_state.channels[i]; { std::lock_guard lk(channel_state.free_channels_mutex); @@ -929,7 +950,8 @@ void BatchedThreadedNnet3CudaPipeline::ExecuteWorker(int threadId) { // cleanup memory delete decodables[i]; - // notifiy master decode is finished + // notifiy master decode is + // finished task.finished = true; } tasks.resize(0); diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.h b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.h index be6443b8a7a..694e4728eeb 100644 --- a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.h +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline.h @@ -15,20 +15,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_ -#define KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_ +#ifndef KALDI_CUDA_DECODER_BATCHED_THREADED_NNET3_CUDA_PIPELINE_H_ +#define KALDI_CUDA_DECODER_BATCHED_THREADED_NNET3_CUDA_PIPELINE_H_ #include #include #include "cudadecoder/cuda-decoder.h" -#include "decodable-cumatrix.h" +#include "cudadecoder/decodable-cumatrix.h" +#include "cudadecoder/thread-pool.h" +#include "cudafeat/online-cuda-feature-pipeline.h" #include "feat/wave-reader.h" #include "lat/determinize-lattice-pruned.h" #include "nnet3/nnet-batch-compute.h" #include "online2/online-nnet2-feature-pipeline.h" -#include "cudafeat/online-cuda-feature-pipeline.h" -#include "thread-pool.h" + +// This pipeline is deprecated and will be removed. Please switch to +// batched-threaded-nnet3-cuda-pipeline2 // If num_channels sets to automatic, // num_channels = [this define] * max_batch_size @@ -56,7 +59,7 @@ struct BatchedThreadedNnet3CudaPipelineConfig { max_pending_tasks(4000), pending_queue_padding(10), num_decoder_copy_threads(2), - gpu_feature_extract(true) {}; + gpu_feature_extract(true){}; void Register(OptionsItf *po) { po->Register("max-batch-size", &max_batch_size, "The maximum batch size to be used by the decoder. " @@ -77,19 +80,22 @@ struct BatchedThreadedNnet3CudaPipelineConfig { "batches pre/post decode work."); po->Register("cuda-control-threads", &num_control_threads, "The number of pipeline control threads for the CUDA work. " - "e.g. 2 control threads -> 2 independent CUDA pipeline (nnet3 " + "e.g. 
2 control threads -> 2 independent CUDA pipeline " + "(nnet3 " "and decoder)."); - po->Register( - "cuda-worker-threads", &num_worker_threads, - "The total number of CPU threads launched to process CPU tasks."); + po->Register("cuda-worker-threads", &num_worker_threads, + "The total number of CPU threads launched to " + "process CPU tasks."); po->Register("determinize-lattice", &determinize_lattice, "Determinize the lattice before output."); po->Register("max-outstanding-queue-length", &max_pending_tasks, - "Number of files to allow to be outstanding at a time. When " + "Number of files to allow to be outstanding at a time. " + "When " "the number of files is larger than this handles will be " "closed before opening new ones in FIFO order."); po->Register("cuda-decoder-copy-threads", &num_decoder_copy_threads, - "Advanced - Number of worker threads used in the decoder for " + "Advanced - Number of worker threads used in the " + "decoder for " "the host to host copies."); po->Register("gpu-feature-extract", &gpu_feature_extract, "Extract features on the GPU. This reduces CPU overhead " @@ -118,179 +124,182 @@ struct BatchedThreadedNnet3CudaPipelineConfig { max_batch_size * KALDI_CUDA_DECODER_CHANNELS_BATCH_SIZE_RATIO; } - OnlineNnet2FeaturePipelineConfig feature_opts; // constant readonly - CudaDecoderConfig decoder_opts; // constant readonly - fst::DeterminizeLatticePhonePrunedOptions det_opts; // constant readonly - nnet3::NnetBatchComputerOptions compute_opts; // constant readonly + OnlineNnet2FeaturePipelineConfig feature_opts; // constant readonly + CudaDecoderConfig decoder_opts; // constant readonly + fst::DeterminizeLatticePhonePrunedOptions det_opts; // constant readonly + nnet3::NnetBatchComputerOptions compute_opts; // constant readonly }; /* - * BatchedThreadedNnet3CudaPipeline uses multiple levels of parallelism in order to - * decode quickly on CUDA GPUs. This is the primary interface for cuda decoding. - * For examples of how to use this decoder see cudadecoder/README and + * BatchedThreadedNnet3CudaPipeline uses multiple levels of parallelism in order + * to decode quickly on CUDA GPUs. This is the primary interface for cuda + * decoding. For examples of how to use this decoder see cudadecoder/README and * cudadecoderbin/batched-wav-nnet3-cuda.cc */ class BatchedThreadedNnet3CudaPipeline { -public: - BatchedThreadedNnet3CudaPipeline( - const BatchedThreadedNnet3CudaPipelineConfig &config) - : config_(config), all_group_tasks_not_done_(0) { - config_.ComputeConfig(); - }; - - // allocates reusable objects that are common across all decodings - void Initialize(const fst::Fst &decode_fst, - const nnet3::AmNnetSimple &nnet, - const TransitionModel &trans_model); - - // deallocates reusable objects - void Finalize(); - - // query a specific key to see if compute on it is complete - bool isFinished(const std::string &key); - - // remove an audio file from the decoding and clean up resources - void CloseDecodeHandle(const std::string &key); - void CloseAllDecodeHandlesForGroup(const std::string &group); - void CloseAllDecodeHandles(); - - // Adds a decoding task to the decoder - // When passing in a vector of data, the caller must ensure the data exists - // until the CloseDecodeHandle/WaitForAllTasks is called - // callback is called once task is done and we pass it the final lattice - // callback can be used to compute lattice rescoring, find best path in - // lattice, writing lattice to disk, etc. - // Important: callback is launched in the threadpool. It must be threadsafe. 
- // For instance, if writing to disk, or to stdout, - // use a lock: - // e.g. : - // { - // std::lock_guard lock(global_mutex); - // // write lattice to disk - // // lock is released in the destructor of lock_guard<> - // } - void OpenDecodeHandle( - const std::string &key, const WaveData &wave_data, - const std::string &group = std::string(), - const std::function &callback = - std::function()); - // When passing in a vector of data, the caller must ensure the data exists - // until the CloseDecodeHandle is called - void OpenDecodeHandle( - const std::string &key, const VectorBase &wave_data, - float sample_rate, const std::string &group = std::string(), - const std::function &callback = - std::function()); - - // Copies the raw lattice for decoded handle "key" into lat - bool GetRawLattice(const std::string &key, Lattice *lat); - // Determinizes raw lattice and returns a compact lattice - bool GetLattice(const std::string &key, CompactLattice *lat); - - int32 GetNumberOfTasksPending(); - - // Wait for all tasks to complete - void WaitForAllTasks(); - // Wait for all tasks in the group to complete - void WaitForGroup(const std::string &group); - // Check if a group is available. Returns if not. - bool IsGroupCompleted(const std::string &group); - // Wait for any group to complete, then returns which group completed - std::string WaitForAnyGroup(); - // Check if any group is available. If one is available, set its name in *group - bool IsAnyGroupCompleted(std::string *group); - inline int NumPendingTasks() { - return (tasks_back_ - tasks_front_ + config_.max_pending_tasks + - config_.pending_queue_padding) % - (config_.max_pending_tasks + config_.pending_queue_padding); + public: + BatchedThreadedNnet3CudaPipeline( + const BatchedThreadedNnet3CudaPipelineConfig &config) + : config_(config), all_group_tasks_not_done_(0) { + config_.ComputeConfig(); + }; + + // allocates reusable objects that are common across all decodings + void Initialize(const fst::Fst &decode_fst, + const nnet3::AmNnetSimple &nnet, + const TransitionModel &trans_model); + + // deallocates reusable objects + void Finalize(); + + // query a specific key to see if compute on it is complete + bool isFinished(const std::string &key); + + // remove an audio file from the decoding and clean up resources + void CloseDecodeHandle(const std::string &key); + void CloseAllDecodeHandlesForGroup(const std::string &group); + void CloseAllDecodeHandles(); + + // Adds a decoding task to the decoder + // When passing in a vector of data, the caller must ensure the data + // exists until the CloseDecodeHandle/WaitForAllTasks is called callback + // is called once task is done and we pass it the final lattice callback + // can be used to compute lattice rescoring, find best path in lattice, + // writing lattice to disk, etc. Important: callback is launched in the + // threadpool. It must be threadsafe. For instance, if writing to disk, + // or to stdout, use a lock: e.g. 
: + // { + // std::lock_guard lock(global_mutex); + // // write lattice to disk + // // lock is released in the destructor of lock_guard<> + // } + void OpenDecodeHandle( + const std::string &key, const WaveData &wave_data, + const std::string &group = std::string(), + const std::function &callback = + std::function()); + // When passing in a vector of data, the caller must ensure the data + // exists until the CloseDecodeHandle is called + void OpenDecodeHandle( + const std::string &key, const VectorBase &wave_data, + float sample_rate, const std::string &group = std::string(), + const std::function &callback = + std::function()); + + // Copies the raw lattice for decoded handle "key" into lat + bool GetRawLattice(const std::string &key, Lattice *lat); + // Determinizes raw lattice and returns a compact lattice + bool GetLattice(const std::string &key, CompactLattice *lat); + + int32 GetNumberOfTasksPending(); + + // Wait for all tasks to complete + void WaitForAllTasks(); + // Wait for all tasks in the group to complete + void WaitForGroup(const std::string &group); + // Check if a group is available. Returns if not. + bool IsGroupCompleted(const std::string &group); + // Wait for any group to complete, then returns which group completed + std::string WaitForAnyGroup(); + // Check if any group is available. If one is available, set its name in + // *group + bool IsAnyGroupCompleted(std::string *group); + inline int NumPendingTasks() { + return (tasks_back_ - tasks_front_ + config_.max_pending_tasks + + config_.pending_queue_padding) % + (config_.max_pending_tasks + config_.pending_queue_padding); + }; + + private: + // Task data used during computation + // Is cleared when task is completed + struct TaskData { + Vector raw_data; // Wave input data when wave_reader passed + std::shared_ptr> + wave_samples; // Used as a pointer to either the raw + // data or the samples passed + float sample_frequency; + Vector ivector_features_cpu; + Matrix input_features_cpu; + CuVector ivector_features; + CuMatrix input_features; + CuMatrix posteriors; + + TaskData(const WaveData &wave_data_in) + : wave_samples(NULL), sample_frequency(0) { + int rows = wave_data_in.Data().NumRows(); + int cols = wave_data_in.Data().NumCols(); + int stride = wave_data_in.Data().Stride(); + + raw_data.Resize(rows * cols, kUndefined); + + if (stride == cols) { + // contigious so use one large memory copy + memcpy(raw_data.Data(), wave_data_in.Data().Data(), + rows * cols * sizeof(BaseFloat)); + } else { + // data is not contigious so we need to copy one + // row at a time + for (int i = 0; i < rows; i++) { + memcpy(raw_data.Data() + i * cols, wave_data_in.Data().RowData(i), + cols * sizeof(BaseFloat)); + } + } + wave_samples = + std::make_shared>(raw_data, 0, raw_data.Dim()); + sample_frequency = wave_data_in.SampFreq(); + }; + + // Init when raw data is passed in. This data is shallow + // copied. 
+ TaskData(const VectorBase &wave_data_in, float sample_rate) { + wave_samples = std::make_shared>(wave_data_in, 0, + wave_data_in.Dim()); + sample_frequency = sample_rate; + } }; -private: - // Task data used during computation - // Is cleared when task is completed - struct TaskData { - Vector raw_data; // Wave input data when wave_reader passed - std::shared_ptr> - wave_samples; // Used as a pointer to either the raw - // data or the samples passed - float sample_frequency; - Vector ivector_features_cpu; - Matrix input_features_cpu; - CuVector ivector_features; - CuMatrix input_features; - CuMatrix posteriors; - - TaskData(const WaveData &wave_data_in) - : wave_samples(NULL), sample_frequency(0) { - int rows = wave_data_in.Data().NumRows(); - int cols = wave_data_in.Data().NumCols(); - int stride = wave_data_in.Data().Stride(); - - raw_data.Resize(rows * cols, kUndefined); - - if (stride == cols) { - // contigious so use one large memory copy - memcpy(raw_data.Data(), wave_data_in.Data().Data(), - rows * cols * sizeof(BaseFloat)); - } else { - // data is not contigious so we need to copy one row at a time - for (int i = 0; i < rows; i++) { - memcpy(raw_data.Data() + i * cols, wave_data_in.Data().RowData(i), - cols * sizeof(BaseFloat)); - } - } - wave_samples = - std::make_shared>(raw_data, 0, raw_data.Dim()); - sample_frequency = wave_data_in.SampFreq(); - }; - - // Init when raw data is passed in. This data is shallow copied. - TaskData(const VectorBase &wave_data_in, float sample_rate) { - wave_samples = std::make_shared>(wave_data_in, 0, - wave_data_in.Dim()); - sample_frequency = sample_rate; - } - }; - - // State needed for each decode task. - // This state can be passed around by reference or pointer safely - // and provides a convieniet way to store all decoding state. - struct TaskState { - std::string key; - std::string group; // group for that task. "" is default - bool error; - std::string error_string; - - std::unique_ptr task_data; - - int32 ichannel; // associated CudaDecoder channel - Lattice lat; // Raw Lattice output - CompactLattice dlat; // Determinized lattice output. Only set if - // determinize-lattice=true - std::atomic finished; // Tells master thread if task has finished - // execution - - bool determinized; - - // (optional) callback is called task is finished and we have a lattice - // ready - // that way we can compute all CPU tasks in the threadpool (lattice - // rescoring, find best path in lattice, etc.) - std::function callback; - - TaskState() : error(false), finished(false), determinized(false) {} - - // Init when wave data is passed directly in. This data is deep copied. - void Init(const std::string &key_in, const WaveData &wave_data_in) { - task_data.reset(new TaskData(wave_data_in)); - key = key_in; - }; - // Init when raw data is passed in. This data is shallow copied. - void Init(const std::string &key_in, - const VectorBase &wave_data_in, float sample_rate) { - task_data.reset(new TaskData(wave_data_in, sample_rate)); - key = key_in; - } + // State needed for each decode task. + // This state can be passed around by reference or pointer safely + // and provides a convieniet way to store all decoding state. + struct TaskState { + std::string key; + std::string group; // group for that task. "" is default + bool error; + std::string error_string; + + std::unique_ptr task_data; + + int32 ichannel; // associated CudaDecoder channel + Lattice lat; // Raw Lattice output + CompactLattice dlat; // Determinized lattice output. 
Only set + // if determinize-lattice=true + std::atomic finished; // Tells master thread if task has + // finished execution + + bool determinized; + + // (optional) callback is called task is finished and we have a + // lattice ready that way we can compute all CPU tasks in the + // threadpool (lattice rescoring, find best path in lattice, + // etc.) + std::function callback; + + TaskState() : error(false), finished(false), determinized(false) {} + + // Init when wave data is passed directly in. This data is deep + // copied. + void Init(const std::string &key_in, const WaveData &wave_data_in) { + task_data.reset(new TaskData(wave_data_in)); + key = key_in; + }; + // Init when raw data is passed in. This data is shallow + // copied. + void Init(const std::string &key_in, + const VectorBase &wave_data_in, float sample_rate) { + task_data.reset(new TaskData(wave_data_in, sample_rate)); + key = key_in; + } }; // Creating a new task in the hashmaps @@ -307,8 +316,8 @@ class BatchedThreadedNnet3CudaPipeline { // Adds task to the PendingTaskQueue void AddTaskToPendingTaskQueue(TaskState *task); - // Attempts to fill the batch from the task queue. May not fully fill the - // batch. + // Attempts to fill the batch from the task queue. May not fully fill + // the batch. void AquireAdditionalTasks(CudaDecoder &cuda_decoder, ChannelState &channel_state, std::vector &tasks); @@ -317,8 +326,7 @@ class BatchedThreadedNnet3CudaPipeline { void ComputeOneFeatureCPU(TaskState *task); // Computes features across the tasks[first,tasks.size() - void ComputeBatchFeatures(int32 first, - std::vector &tasks, + void ComputeBatchFeatures(int32 first, std::vector &tasks, OnlineCudaFeaturePipeline &feature_pipeline); // Computes Nnet across the current decode batch @@ -332,11 +340,10 @@ class BatchedThreadedNnet3CudaPipeline { // Removes all completed channels from the channel list. // Also enqueues up work for post processing - void - RemoveCompletedChannels(CudaDecoder &cuda_decoder, - ChannelState &channel_state, - std::vector &decodables, - std::vector &tasks); + void RemoveCompletedChannels( + CudaDecoder &cuda_decoder, ChannelState &channel_state, + std::vector &decodables, + std::vector &tasks); // For each completed decode perform post processing work and clean up void PostDecodeProcessing(CudaDecoder &cuda_decoder, @@ -351,8 +358,8 @@ class BatchedThreadedNnet3CudaPipeline { // Determinize one lattice void DeterminizeOneLattice(TaskState *task); - // Thread execution function. This is a single worker thread which processes - // input. + // Thread execution function. This is a single worker thread which + // processes input. 
void ExecuteWorker(int threadId); BatchedThreadedNnet3CudaPipelineConfig config_; @@ -363,19 +370,19 @@ class BatchedThreadedNnet3CudaPipeline { nnet3::DecodableNnetSimpleLoopedInfo *decodable_info_; OnlineNnet2FeaturePipelineInfo *feature_info_; - std::mutex tasks_mutex_; // protects tasks_front_ and pending_task_queue_ for - // workers - std::mutex tasks_add_mutex_; // protect OpenDecodeHandle if multiple threads - // access - std::mutex tasks_lookup_mutex_; // protext tasks_lookup map + std::mutex tasks_mutex_; // protects tasks_front_ and + // pending_task_queue_ for workers + std::mutex tasks_add_mutex_; // protect OpenDecodeHandle if multiple + // threads access + std::mutex tasks_lookup_mutex_; // protext tasks_lookup map std::condition_variable tasks_lookup_cv_; std::atomic tasks_front_, tasks_back_; TaskState **pending_task_queue_; - std::atomic exit_; // signals threads to exit - std::atomic numStarted_; // signals master how many threads have started + std::atomic exit_; // signals threads to exit + std::atomic numStarted_; // signals master how many threads have started - ThreadPool *work_pool_; // thread pool for CPU work + ThreadPool *work_pool_; // thread pool for CPU work std::map group_tasks_not_done_; int32 all_group_tasks_not_done_; std::mutex group_tasks_mutex_; @@ -383,12 +390,12 @@ class BatchedThreadedNnet3CudaPipeline { std::unordered_multimap tasks_group_lookup_; // group -> list of tasks std::unordered_map - tasks_lookup_; // Contains a map of - // utterance to TaskState - std::vector thread_contexts_; // A list of thread contexts + tasks_lookup_; // Contains a map of + // utterance to TaskState + std::vector thread_contexts_; // A list of thread contexts }; } // end namespace cuda_decoder -} // end namespace kaldi. +} // end namespace kaldi. #endif // KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_ diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc new file mode 100644 index 00000000000..95171f32fb4 --- /dev/null +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc @@ -0,0 +1,296 @@ +// cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
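+
+// Implementation outline (summary of the code below; the public API is
+// declared in batched-threaded-nnet3-cuda-pipeline2.h):
+//
+//  - DecodeWithCallback() only enqueues work: directly into outstanding_utt_
+//    when use_online_features_ is set, otherwise into
+//    preprocessing_utt_queue_, whose entries first get full-utterance
+//    features/ivectors computed by ComputeOfflineFeatures().
+//  - A dedicated control thread (online_pipeline_control_thread_) runs
+//    ComputeTasks(): AcquireTasks() assigns decoder channels to waiting
+//    utterances (TryInitCorrID + SetLatticeCallback on the online pipeline),
+//    BuildBatchFromCurrentTasks() slices one chunk per active utterance, and
+//    the batch is sent to the online pipeline's DecodeBatch().
+//  - Completion is tracked through atomic counters (n_tasks_not_done_ and the
+//    per-group counters) decremented from the lattice callbacks; these are
+//    what WaitForAllTasks()/WaitForGroup() poll on.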
+ +#if HAVE_CUDA == 1 + +#define KALDI_CUDA_DECODER_WAIT_FOR_TASKS_US 10000 +#define KALDI_CUDA_DECODER_WAIT_FOR_NEW_TASKS_US 100 + +#include "cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h" +#include + +namespace kaldi { +namespace cuda_decoder { + +void BatchedThreadedNnet3CudaPipeline2::BuildBatchFromCurrentTasks() { + batch_corr_ids_.clear(); + batch_is_last_chunk_.clear(); + batch_is_first_chunk_.clear(); + if (use_online_features_) { + batch_wave_samples_.clear(); + } else { + batch_features_.clear(); + batch_ivectors_.clear(); + batch_n_input_frames_valid_.clear(); + } + for (size_t task_id = 0; task_id < current_tasks_.size();) { + UtteranceTask &task = current_tasks_[task_id]; + int32 total_n_input; + if (use_online_features_) { + KALDI_ASSERT(task.h_wave); + SubVector &h_wave = *task.h_wave; + total_n_input = h_wave.Dim(); + } else { + total_n_input = task.d_features->NumRows(); + } + + int32 samp_offset = task.samp_offset; + int32 samp_remaining = total_n_input - samp_offset; + int32 num_samp = std::min(n_input_per_chunk_, samp_remaining); + KALDI_ASSERT(num_samp > 0); + bool is_last_chunk = (samp_remaining == num_samp); + bool is_first_chunk = (task.samp_offset == 0); + CorrelationID corr_id = task.corr_id; + task.samp_offset += num_samp; + + batch_corr_ids_.push_back(corr_id); + batch_is_last_chunk_.push_back(is_last_chunk); + batch_is_first_chunk_.push_back(is_first_chunk); + + if (use_online_features_) { + SubVector &h_wave = *task.h_wave; + SubVector wave_part(h_wave, samp_offset, num_samp); + batch_wave_samples_.push_back(wave_part); + } else { + batch_features_.push_back(task.d_features->Data() + + samp_offset * task.d_features->Stride()); + if (task_id == 0) + batch_features_frame_stride_ = task.d_features->Stride(); + else + KALDI_ASSERT(batch_features_frame_stride_ == task.d_features->Stride()); + batch_ivectors_.push_back(task.d_ivectors->Data()); + batch_n_input_frames_valid_.push_back(num_samp); + } + + // If last chunk, moving the task to tasks_last_chunk_ + if (is_last_chunk) { + tasks_last_chunk_.push_back(std::move(task)); + size_t last_task_id = current_tasks_.size() - 1; + current_tasks_[task_id] = std::move(current_tasks_[last_task_id]); + current_tasks_.pop_back(); + } else { + // If it was the last chunk, we replaced the current + // task with another one we must process that task_id + // again (because it is now another task) If it was not + // the last chunk, then we must take care of the next + // task_id + ++task_id; + } + } +} + +void BatchedThreadedNnet3CudaPipeline2::WaitForAllTasks() { + while (n_tasks_not_done_.load() != 0) { + usleep(KALDI_CUDA_DECODER_WAIT_FOR_TASKS_US); + } +} + +void BatchedThreadedNnet3CudaPipeline2::CreateTaskGroup( + const std::string &group) { + std::lock_guard lk(n_group_tasks_not_done_m_); + bool inserted; + std::unique_ptr> group_cnt; + group_cnt.reset(new std::atomic(0)); + std::tie(std::ignore, inserted) = + n_group_tasks_not_done_.emplace(group, std::move(group_cnt)); + KALDI_ASSERT("Group is already in use" && inserted); +} + +void BatchedThreadedNnet3CudaPipeline2::DestroyTaskGroup( + const std::string &group) { + std::lock_guard lk(n_group_tasks_not_done_m_); + int nerased = n_group_tasks_not_done_.erase(group); + KALDI_ASSERT("Group does not exist" && (nerased == 1)); +} + +void BatchedThreadedNnet3CudaPipeline2::WaitForGroup(const std::string &group) { + std::atomic *n_not_done; + { + std::lock_guard lk(n_group_tasks_not_done_m_); + auto it = n_group_tasks_not_done_.find(group); + KALDI_ASSERT("Group does 
not exist. Call CreateTaskGroup() first" && + (it != n_group_tasks_not_done_.end())); + n_not_done = it->second.get(); + } + + while (n_not_done->load(std::memory_order_consume) != 0) + usleep(KALDI_CUDA_DECODER_WAIT_FOR_TASKS_US); +} + +void BatchedThreadedNnet3CudaPipeline2::DecodeWithCallback( + const std::string &key, const std::shared_ptr &wave_data, + std::unique_ptr> &&h_wave, + const std::function &callback, + const std::string &group) { + if (wave_data) { + KALDI_ASSERT( + "Mismatch in model and utt frequency" && + (wave_data->SampFreq() == cuda_online_pipeline_.GetModelFrequency())); + } + + UtteranceTask task; + if (wave_data) task.wave_data = wave_data; + if (h_wave) { + task.h_wave = std::move(h_wave); + } else { + KALDI_ASSERT(wave_data); + task.h_wave.reset(new SubVector(wave_data->Data(), 0)); + } + + if (task.h_wave->Dim() == 0) return; // nothing to do + n_tasks_not_done_.fetch_add(1); + task.key = key; + task.samp_offset = 0; + task.corr_id = corr_id_cnt_.fetch_add( + 1); // at 5000 files/s, expected to overflow in ~116 million years + task.callback = callback; + + if (!group.empty()) { + // Need to add it to group + std::lock_guard lk(n_group_tasks_not_done_m_); + auto it = n_group_tasks_not_done_.find(group); + KALDI_ASSERT("Group does not exist. Call CreateTaskGroup() first" && + (it != n_group_tasks_not_done_.end())); + it->second->fetch_add(1); // adding current task + task.group_cnt = it->second.get(); // will be used to --cnt + } else { + task.group_cnt = NULL; + } + + if (use_online_features_) { + // If we use online ivectors, we can just add it to the + // outstanding queue. ivectors and mfcc will be computed in the + // online pipeline + std::lock_guard lk(outstanding_utt_m_); + outstanding_utt_.push(std::move(task)); + } else { + // Otherwise we first need to compute ivectors and mfcc for the + // full audio file Adding it to the preprocessing queue + std::lock_guard lk(preprocessing_utt_queue_m_); + preprocessing_utt_queue_.push(std::move(task)); + } +} + +void BatchedThreadedNnet3CudaPipeline2::ComputeOfflineFeatures() { + bool iterate = true; + do { + UtteranceTask task; + { + std::lock_guard lk(preprocessing_utt_queue_m_); + if (preprocessing_utt_queue_.empty()) { + iterate = false; + break; + } + + task = std::move(preprocessing_utt_queue_.front()); + preprocessing_utt_queue_.pop(); + } + KALDI_ASSERT(task.h_wave); + SubVector &h_wave = *task.h_wave; + int32 nsamp = h_wave.Dim(); + + cudaEventSynchronize(wave_buffer_->evt); + if (nsamp > wave_buffer_->size) { + wave_buffer_->Reallocate(nsamp); + } + std::memcpy(wave_buffer_->h_data, h_wave.Data(), + h_wave.Dim() * sizeof(BaseFloat)); + cudaMemcpyAsync(wave_buffer_->d_data, wave_buffer_->h_data, + sizeof(BaseFloat) * nsamp, cudaMemcpyHostToDevice, + cudaStreamPerThread); + + task.d_features.reset(new CuMatrix()); + task.d_ivectors.reset(new CuVector()); + CuSubVector wrapper(wave_buffer_->d_data, nsamp); + cuda_features_->ComputeFeatures( + wrapper, cuda_online_pipeline_.GetModelFrequency(), + task.d_features.get(), task.d_ivectors.get()); + cudaEventRecord(wave_buffer_->evt, cudaStreamPerThread); + std::swap(wave_buffer_, next_wave_buffer_); + if (task.wave_data) task.wave_data.reset(); // delete wave samples on host + { + std::lock_guard lk(outstanding_utt_m_); + outstanding_utt_.push(std::move(task)); + // We dont want to have too many files ready in + // outstanding_utt_ (using device memory) using + // max_batch_size_ as an arbitrary (large enough) value + iterate = (outstanding_utt_.size() < 
max_batch_size_); + } + } while (iterate); + cudaStreamSynchronize(cudaStreamPerThread); // to keep CuVector in scope +} + +void BatchedThreadedNnet3CudaPipeline2::AcquireTasks() { + // Trying to get new tasks + std::unique_lock lk(outstanding_utt_m_); + while (current_tasks_.size() < max_batch_size_) { + // If use_online_features_ is false, we have to fill + // outstanding_utt_ by computing features + if (!use_online_features_ && outstanding_utt_.size() == 0) { + lk.unlock(); + ComputeOfflineFeatures(); + lk.lock(); + } + // If still empty, break + if (outstanding_utt_.size() == 0) break; + UtteranceTask &task = outstanding_utt_.front(); + bool was_created = cuda_online_pipeline_.TryInitCorrID(task.corr_id); + // No channel was available. Breaking for now + if (!was_created) break; + + auto &callback = task.callback; + auto &key = task.key; + std::atomic *group_cnt = task.group_cnt; + cuda_online_pipeline_.SetLatticeCallback( + task.corr_id, [this, callback, key, group_cnt](CompactLattice &clat) { + if (callback) callback(clat); + n_tasks_not_done_.fetch_sub(1, std::memory_order_release); + if (group_cnt) group_cnt->fetch_sub(1, std::memory_order_release); + }); + current_tasks_.push_back(std::move(task)); + outstanding_utt_.pop(); + } +} + +void BatchedThreadedNnet3CudaPipeline2::ComputeTasks() { + while (threads_running_) { + if (current_tasks_.size() < max_batch_size_) AcquireTasks(); + if (current_tasks_.empty()) { + // If we still have nothing to do, let's sleep a bit + usleep(KALDI_CUDA_DECODER_WAIT_FOR_NEW_TASKS_US); + continue; + } + BuildBatchFromCurrentTasks(); + + if (use_online_features_) + cuda_online_pipeline_.DecodeBatch(batch_corr_ids_, batch_wave_samples_, + batch_is_first_chunk_, + batch_is_last_chunk_); + else + cuda_online_pipeline_.DecodeBatch( + batch_corr_ids_, batch_features_, batch_features_frame_stride_, + batch_n_input_frames_valid_, batch_ivectors_, batch_is_first_chunk_, + batch_is_last_chunk_); + // Calling the destructors, freeing memory + tasks_last_chunk_.clear(); + } +} + +} // end namespace cuda_decoder +} // end namespace kaldi. + +#endif // HAVE_CUDA diff --git a/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h new file mode 100644 index 00000000000..e7f00910222 --- /dev/null +++ b/src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h @@ -0,0 +1,268 @@ +// cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
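+
+// Minimal usage sketch (illustration only -- option parsing, model/FST and
+// WaveData loading are assumed to be handled by the caller; "my_group" and
+// "waves" are placeholders):
+//
+//   BatchedThreadedNnet3CudaPipeline2Config config;
+//   // ... fill config, e.g. through config.Register() + ParseOptions ...
+//   BatchedThreadedNnet3CudaPipeline2 pipeline(config, *decode_fst, am_nnet,
+//                                              trans_model);
+//   pipeline.CreateTaskGroup("my_group");
+//   for (const std::shared_ptr<WaveData> &wave : waves) {
+//     // The callback runs on a worker thread once the final lattice is
+//     // ready; it must be threadsafe. The wave sampling frequency must
+//     // match the model frequency.
+//     pipeline.DecodeWithCallback(
+//         wave, [](CompactLattice &clat) { /* rescore/write clat */ },
+//         "my_group");
+//   }
+//   pipeline.WaitForGroup("my_group");
+//   pipeline.DestroyTaskGroup("my_group");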
+ +#if HAVE_CUDA == 1 + +#ifndef KALDI_CUDA_DECODER_BATCHED_THREADED_NNET3_CUDA_PIPELINE2_H_ +#define KALDI_CUDA_DECODER_BATCHED_THREADED_NNET3_CUDA_PIPELINE2_H_ + +#define KALDI_CUDA_DECODER_AUDIO_HOST_DEVICE_BUFFER_SIZE 16000 * 50 + +#include +#include + +#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h" +#include "cudadecoder/cuda-decoder.h" +#include "cudafeat/online-cuda-feature-pipeline.h" +#include "feat/wave-reader.h" + +// +// Offline wrapper for the online pipeline. +// Supports non-greedy features (such as non-greedy ivectors) +// + +namespace kaldi { +namespace cuda_decoder { +struct BatchedThreadedNnet3CudaPipeline2Config { + BatchedThreadedNnet3CudaPipeline2Config() : use_online_features(false) {} + BatchedThreadedNnet3CudaOnlinePipelineConfig cuda_online_pipeline_opts; + bool use_online_features; + void Register(OptionsItf *po) { + po->Register("use-online-features", &use_online_features, + "Run feature extraction in an online manner (greedy)"); + + cuda_online_pipeline_opts.Register(po); + } +}; + +class BatchedThreadedNnet3CudaPipeline2 { + const BatchedThreadedNnet3CudaPipeline2Config &config_; + BatchedThreadedNnet3CudaOnlinePipeline cuda_online_pipeline_; + using CorrelationID = BatchedThreadedNnet3CudaOnlinePipeline::CorrelationID; + + struct UtteranceTask { + UtteranceTask &operator=(const UtteranceTask &) = delete; + UtteranceTask(const UtteranceTask &) = delete; + UtteranceTask(UtteranceTask &&) = default; + UtteranceTask &operator=(UtteranceTask &&) = default; + UtteranceTask() = default; + + std::shared_ptr wave_data; + std::unique_ptr> + h_wave; // (task.wave_data->Data(), 0) + std::string key; + int32 samp_offset; + CorrelationID corr_id; + std::atomic *group_cnt; + std::function callback; + bool auto_close_after_callback; + + std::unique_ptr> + d_features; // Used only when use_online_features == false + std::unique_ptr> + d_ivectors; // Used only when use_online_features == false + }; + + bool use_online_features_; + int n_input_per_chunk_; + std::atomic corr_id_cnt_; + + // Tasks added to the queue, but not yet used + std::queue preprocessing_utt_queue_; + std::mutex preprocessing_utt_queue_m_; + std::queue outstanding_utt_; + std::mutex outstanding_utt_m_; + + // Tasks currently being decoded by the cuda pipeline + std::vector current_tasks_; + + // Contains the ID of the tasks that are being completed + // (we are decoding their last chunk) + std::vector tasks_last_chunk_; + + // Batch sent to online pipeline + std::vector batch_corr_ids_; + std::vector batch_is_first_chunk_; + std::vector batch_is_last_chunk_; + // Used when use_online_features_ + std::vector> batch_wave_samples_; + // Used when !use_online_features_ + std::vector batch_features_; + int batch_features_frame_stride_; + std::vector batch_ivectors_; + std::vector batch_n_input_frames_valid_; + + int32 max_batch_size_; + // Thread responsible of feeding the online pipeline + bool threads_running_; + std::thread online_pipeline_control_thread_; + + // Number of tasks currently running + std::atomic n_tasks_not_done_; + + // Number of tasks currently running (per group) + std::unordered_map>> + n_group_tasks_not_done_; + std::mutex n_group_tasks_not_done_m_; + + // If auto_close_after_callback is false, we will store the completed + // lattices + // there + // They will be explicitely deleted by CloseDecodeHandle + struct Output { + Output() : is_clat_set(false) {} + std::atomic is_clat_set; // using a separate atomic because + // std::atomic only exists + // with C++20 + 
std::shared_ptr clat; + }; + std::unique_ptr cuda_features_; + + struct HostDeviceVector { + cudaEvent_t evt; + BaseFloat *h_data; + BaseFloat *d_data; + size_t size; + + HostDeviceVector() + : h_data(NULL), + d_data(NULL), + size(KALDI_CUDA_DECODER_AUDIO_HOST_DEVICE_BUFFER_SIZE) { + cudaEventCreate(&evt); + Reallocate(size); + } + + virtual ~HostDeviceVector() { + Deallocate(); + cudaEventDestroy(evt); + } + + void Reallocate(size_t new_size) { + KALDI_ASSERT(new_size > 0); + Deallocate(); + cudaMalloc(&d_data, new_size * sizeof(*d_data)); + cudaMallocHost(&h_data, new_size * sizeof(*d_data)); + new_size = size; + } + void Deallocate() { + if (d_data) cudaFree(d_data); + if (h_data) cudaFreeHost(h_data); + } + }; + + std::unique_ptr wave_buffer_, next_wave_buffer_; + + public: + BatchedThreadedNnet3CudaPipeline2( + const BatchedThreadedNnet3CudaPipeline2Config &config, + const fst::Fst &decode_fst, + const nnet3::AmNnetSimple &am_nnet, const TransitionModel &trans_model) + : config_(config), + cuda_online_pipeline_(config.cuda_online_pipeline_opts, decode_fst, + am_nnet, trans_model), + use_online_features_(config_.use_online_features), + corr_id_cnt_(0), + max_batch_size_(config_.cuda_online_pipeline_opts.max_batch_size), + threads_running_(true), + online_pipeline_control_thread_( + &BatchedThreadedNnet3CudaPipeline2::ComputeTasks, this), + n_tasks_not_done_(0) { + KALDI_ASSERT( + "CPU feature extraction is only available when " + "use-online-features is set" && + (config_.cuda_online_pipeline_opts.use_gpu_feature_extraction || + config_.use_online_features)); + batch_corr_ids_.reserve(max_batch_size_); + batch_wave_samples_.reserve(max_batch_size_); + batch_is_last_chunk_.reserve(max_batch_size_); + batch_is_first_chunk_.reserve(max_batch_size_); + tasks_last_chunk_.reserve(max_batch_size_); + if (use_online_features_) { + n_input_per_chunk_ = cuda_online_pipeline_.GetNSampsPerChunk(); + } else { + n_input_per_chunk_ = cuda_online_pipeline_.GetNInputFramesPerChunk(); + cuda_features_.reset(new OnlineCudaFeaturePipeline( + config_.cuda_online_pipeline_opts.feature_opts)); + wave_buffer_.reset(new HostDeviceVector()); + next_wave_buffer_.reset(new HostDeviceVector()); + } + } + + ~BatchedThreadedNnet3CudaPipeline2() { + threads_running_ = false; + online_pipeline_control_thread_.join(); + } + + // Will decode wave_data. Then when done, will call the callback with + // the final lattice. It does not create a handle, so you don't need to + // call CloseDecodeHandle, and GetLattice cannot be used with + // DecodeWithCallback (the lattice is provided through the callback) + // Should be preferred to OpenDecodeHandle/GetLattice/CloseDecodeHandle + // when possible The callback function is called in a multithreaded + // environment. It must be threadsafe To wait for those tasks to + // complete you can use WaitForGroup or WaitForAllTasks + void DecodeWithCallback(const std::shared_ptr &wave_data, + const std::function &callback, + const std::string &group = std::string()) { + DecodeWithCallback(std::string(), wave_data, + std::unique_ptr>(), callback, + group); + } + + void DecodeWithCallback(const VectorBase &wave_data, + float sample_rate, + const std::function &callback, + const std::string &group = std::string()) { + KALDI_ASSERT(sample_rate == cuda_online_pipeline_.GetModelFrequency()); + std::unique_ptr> h_wave( + new SubVector(wave_data, 0, wave_data.Dim())); + DecodeWithCallback(std::string(), std::shared_ptr(), + std::move(h_wave), callback, group); + } + + // Create a Task Group. 
Tasks can be associated with a group. + // It is then possible to sync only on those tasks using WaitForGroup + // (instead of WaitForAllTasks) + void CreateTaskGroup(const std::string &group); + void DestroyTaskGroup(const std::string &group); + // Wait for all tasks in that group to complete + void WaitForGroup(const std::string &group); + + void WaitForAllTasks(); + + // Used for debug + void SetSymbolTable(fst::SymbolTable *word_syms) { + cuda_online_pipeline_.SetSymbolTable(word_syms); + } + + private: + void DecodeWithCallback(const std::string &key, + const std::shared_ptr &wave_data, + std::unique_ptr> &&h_wave, + const std::function &callback, + const std::string &group = std::string()); + void BuildBatchFromCurrentTasks(); + void AcquireTasks(); + void ComputeTasks(); + void ComputeOfflineFeatures(); +}; + +} // end namespace cuda_decoder +} // end namespace kaldi. + +#endif // KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_ +#endif // HAVE_CUDA diff --git a/src/cudadecoder/cuda-decodable-itf.h b/src/cudadecoder/cuda-decodable-itf.h index 98d0619b6eb..939983dc258 100644 --- a/src/cudadecoder/cuda-decodable-itf.h +++ b/src/cudadecoder/cuda-decodable-itf.h @@ -15,6 +15,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +// Important: This file is deprecated and will be removed in a future release +// + #ifndef KALDI_CUDA_DECODER_DECODABLE_ITF_H #define KALDI_CUDA_DECODER_DECODABLE_ITF_H @@ -24,7 +28,7 @@ namespace kaldi { namespace cuda_decoder { class CudaDecodableInterface : public DecodableInterface { -public: + public: virtual BaseFloat *GetLogLikelihoodsCudaPointer(int32 subsampled_frame) = 0; }; diff --git a/src/cudadecoder/cuda-decoder.cc b/src/cudadecoder/cuda-decoder.cc index 89b1ef5f099..39bb93a087e 100644 --- a/src/cudadecoder/cuda-decoder.cc +++ b/src/cudadecoder/cuda-decoder.cc @@ -55,11 +55,11 @@ CudaDecoder::CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config, cudaStreamCreate(©_st_); // For all the allocating/initializing process // We create a special channel - // containing the exact state a channel should have when starting a new decode - // It contains fst.Start(), the non-emitting tokens created by fst.Start(), - // and all the data used by the decoder. - // When calling InitDecoding() on a new channel, we simply clone this special - // channel into that new channel + // containing the exact state a channel should have when starting a new + // decode It contains fst.Start(), the non-emitting tokens created by + // fst.Start(), and all the data used by the decoder. 
When calling + // InitDecoding() on a new channel, we simply clone this special channel + // into that new channel ++nchannels_; // adding the special initial channel init_channel_id_ = nchannels_ - 1; // Using last one as init_channel_params AllocateHostData(); @@ -273,14 +273,13 @@ void CudaDecoder::InitDeviceParams() { h_device_params_->init_cost = StdWeight::One().Value(); h_device_params_->hashmap_capacity = hashmap_capacity_; h_device_params_->max_active = max_active_; - // For the first static_beam_q_length elements of the queue, we will keep the - // beam static + // For the first static_beam_q_length elements of the queue, we will + // keep the beam static adaptive_beam_static_segment_ = aux_q_capacity_ / KALDI_CUDA_DECODER_ADAPTIVE_BEAM_STATIC_SEGMENT; - // For the last adaptive_beam_q_length elements of the queue, we will decrease - // the beam, segment by segment - // For more information, please refer to the definition of GetAdaptiveBeam in - // cuda-decoder-kernels.cu + // For the last adaptive_beam_q_length elements of the queue, we will + // decrease the beam, segment by segment For more information, please + // refer to the definition of GetAdaptiveBeam in cuda-decoder-kernels.cu int32 adaptive_beam_q_length = (aux_q_capacity_ - adaptive_beam_static_segment_); int32 adaptive_beam_bin_width = @@ -377,7 +376,8 @@ void CudaDecoder::ComputeInitialChannel() { } void CudaDecoder::InitDecoding(const std::vector &channels) { - // Cloning the init_channel_id_ channel into all channels in the channels vec + // Cloning the init_channel_id_ channel into all channels in the + // channels vec const int nlanes_used = channels.size(); // Getting *h_kernel_params ready to use LoadChannelsStateToLanes(channels); @@ -419,12 +419,13 @@ void CudaDecoder::InitDecoding(const std::vector &channels) { h_all_argmin_cost_[ichannel] = {-1, 0.0f}; frame_offsets_[ichannel].clear(); frame_offsets_[ichannel].push_back(n_initial_tokens); - if (thread_pool_) - thread_pool_->enqueue(THREAD_POOL_HIGH_PRIORITY, - &CudaDecoder::InitDecodingH2HCopies, this, - ichannel); - else - InitDecodingH2HCopies(ichannel); + // TODO put it back + // if (thread_pool_) { + // thread_pool_->post([ichannel, this] { + // InitDecodingH2HCopies(ichannel); + // }); + //} else + InitDecodingH2HCopies(ichannel); } } @@ -545,10 +546,9 @@ void CudaDecoder::MoveConcatenatedCopyToVector( void CudaDecoder::ApplyMaxActiveAndReduceBeam(enum QUEUE_ID queue_id) { // Checking if we should activate max active for the current frame - // once it is active, it is active for the whole frame (for all non emitting - // iterations) - // If at least one lane queue is bigger than max_active, - // we'll apply a topk on that queue (k=max_active_) + // once it is active, it is active for the whole frame (for all non + // emitting iterations) If at least one lane queue is bigger than + // max_active, we'll apply a topk on that queue (k=max_active_) bool use_aux_q = (queue_id == AUX_Q); ComputeCostsHistogramKernel(KaldiCudaDecoderNumBlocks(nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, compute_st_, @@ -559,32 +559,6 @@ void CudaDecoder::ApplyMaxActiveAndReduceBeam(enum QUEUE_ID queue_id) { compute_st_, *h_device_params_, *h_kernel_params_, use_aux_q); } -int32 CudaDecoder::NumFramesToDecode( - const std::vector &channels, - std::vector &decodables, int32 max_num_frames) { - int32 nframes_to_decode = INT_MAX; - // std::vector debug_ntokens; - // std::vector debug_narcs; - for (int32 ilane = 0; ilane < nlanes_used_; ++ilane) { - const ChannelId ichannel = 
channels[ilane]; - const int32 num_frames_decoded = num_frames_decoded_[ichannel]; - KALDI_ASSERT(num_frames_decoded >= 0 && - "You must call InitDecoding() before AdvanceDecoding()"); - int32 num_frames_ready = decodables[ilane]->NumFramesReady(); - // num_frames_ready must be >= num_frames_decoded, or else - // the number of frames ready must have decreased (which doesn't - // make sense) or the decodable object changed between calls - // (which isn't allowed). - KALDI_ASSERT(num_frames_ready >= num_frames_decoded); - int32 channel_nframes_to_decode = num_frames_ready - num_frames_decoded; - nframes_to_decode = std::min(nframes_to_decode, channel_nframes_to_decode); - } - if (max_num_frames >= 0) - nframes_to_decode = std::min(nframes_to_decode, max_num_frames); - - return nframes_to_decode; -} - void CudaDecoder::ExpandArcsEmitting() { ExpandArcsKernel(KaldiCudaDecoderNumBlocks(nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, compute_st_, @@ -659,8 +633,8 @@ void CudaDecoder::CopyMainQueueDataToHost() { cudaEventRecord(concatenated_data_ready_evt_, compute_st_); cudaStreamWaitEvent(copy_st_, concatenated_data_ready_evt_, 0); // the copies on copy_st will wait on compute_st_ - cudaEventSynchronize( - lane_offsets_ready_evt_); // we need the total size of each segments + cudaEventSynchronize(lane_offsets_ready_evt_); // we need the total + // size of each segments LaunchD2HCopies(); // Making sure the previous H2H copies are done @@ -754,108 +728,134 @@ void CudaDecoder::ConcatenateData() { void CudaDecoder::AdvanceDecoding( const std::vector &channels, std::vector &decodables, int32 max_num_frames) { - if (channels.size() == 0) return; // nothing to do + int nframes_to_decode = INT_MAX; + for (int32 ilane = 0; ilane < channels.size(); ++ilane) { + const ChannelId ichannel = channels[ilane]; + const int32 num_frames_decoded = num_frames_decoded_[ichannel]; + KALDI_ASSERT(num_frames_decoded >= 0 && + "You must call InitDecoding() before AdvanceDecoding()"); + int32 num_frames_ready = decodables[ilane]->NumFramesReady(); + // num_frames_ready must be >= num_frames_decoded, or else + // the number of frames ready must have decreased (which doesn't + // make sense) or the decodable object changed between calls + // (which isn't allowed). 
+ KALDI_ASSERT(num_frames_ready >= num_frames_decoded); + int32 channel_nframes_to_decode = num_frames_ready - num_frames_decoded; + nframes_to_decode = std::min(nframes_to_decode, channel_nframes_to_decode); + } + if (max_num_frames >= 0) + nframes_to_decode = std::min(nframes_to_decode, max_num_frames); + + std::vector> lanes_assignements; + for (int f = 0; f < nframes_to_decode; ++f) { + lanes_assignements.clear(); + for (int32 ilane = 0; ilane < channels.size(); ++ilane) { + const ChannelId ichannel = channels[ilane]; + int32 iframe = num_frames_decoded_[ichannel]; + BaseFloat *ptr = decodables[ilane]->GetLogLikelihoodsCudaPointer(iframe); + lanes_assignements.push_back({ichannel, ptr}); + } + AdvanceDecoding(lanes_assignements); + } +} + +void CudaDecoder::AdvanceDecoding( + const std::vector> &lanes_assignements) { + if (lanes_assignements.size() == 0) return; // nothing to do // Context switch : Loading the channels state in lanes + + // Looping over the frames that we will compute + // Loglikelihoods from the acoustic model + // Setting the loglikelihoods pointers for that frame + std::vector channels; // TODO + channels.reserve(lanes_assignements.size()); + for (LaneId ilane = 0; ilane < lanes_assignements.size(); ++ilane) { + ChannelId ichannel = lanes_assignements[ilane].first; + channels.push_back(ichannel); + channel_to_compute_[ilane] = ichannel; + h_lanes_counters_.lane(ilane)->loglikelihoods = + lanes_assignements[ilane].second; + } LoadChannelsStateToLanes(channels); KALDI_ASSERT(nlanes_used_ > 0); + cudaMemcpyAsync(d_lanes_counters_.MutableData(), h_lanes_counters_.lane(0), + nlanes_used_ * sizeof(*h_lanes_counters_.lane(0)), + cudaMemcpyHostToDevice, compute_st_); + // compute_st_ will wait for nnet3 to complete + cudaEventRecord(nnet3_done_evt_, cudaStreamPerThread); + cudaStreamWaitEvent(compute_st_, nnet3_done_evt_, 0); - // We'll decode nframes_to_decode, such as all channels have at least that - // number - // of frames available - int32 nframes_to_decode = - NumFramesToDecode(channels, decodables, max_num_frames); - - // Looping over the frames that we will compute - for (int32 iframe = 0; iframe < nframes_to_decode; ++iframe) { - // Loglikelihoods from the acoustic model - // Setting the loglikelihoods pointers for that frame - for (LaneId ilane = 0; ilane < nlanes_used_; ++ilane) { - ChannelId ichannel = channel_to_compute_[ilane]; - int32 frame = num_frames_decoded_[ichannel]; - h_lanes_counters_.lane(ilane)->loglikelihoods = - decodables[ilane]->GetLogLikelihoodsCudaPointer(frame); - } - cudaMemcpyAsync(d_lanes_counters_.MutableData(), h_lanes_counters_.lane(0), - nlanes_used_ * sizeof(*h_lanes_counters_.lane(0)), - cudaMemcpyHostToDevice, compute_st_); - // compute_st_ will wait for nnet3 to complete - cudaEventRecord(nnet3_done_evt_, cudaStreamPerThread); - cudaStreamWaitEvent(compute_st_, nnet3_done_evt_, 0); - - // Estimating cutoff using argmin from last frame - ResetForFrameAndEstimateCutoffKernel( - KaldiCudaDecoderNumBlocks(1, nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, - compute_st_, *h_device_params_, *h_kernel_params_); - // Reset max active status. If necessary, ApplyMaxActiveAndReduceBeam will - // switch it back on - compute_max_active_ = false; - - // Processing emitting arcs. We've done the preprocess stage at the end of - // the previous frame - ExpandArcsEmitting(); - // We'll loop until we have a small enough number of non-emitting arcs - // in the token queue. 
We'll then break the loop - for (int i = 0; i < KALDI_CUDA_DECODER_N_NON_EMITTING_MAIN_ITERATIONS; - ++i) { - // If one of the aux_q contains more than max_active_ tokens, - // we'll reduce the beam to only keep max_active_ tokens - ApplyMaxActiveAndReduceBeam(AUX_Q); - // Prune the aux_q. Apply the latest beam (using the one from - // ApplyMaxActiveAndReduceBeam if triggered) - // move the survival tokens to the main queue - // and do the preprocessing necessary for the next ExpandArcs - PruneAndPreprocess(); - - // "heavy duty" kernel for non-emitting. The long tail of small - // non-emitting iterations will be done in - // FinalizeProcessNonEmittingKernel - ExpandArcsNonEmitting(); - } + // Estimating cutoff using argmin from last frame + ResetForFrameAndEstimateCutoffKernel( + KaldiCudaDecoderNumBlocks(1, nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, + compute_st_, *h_device_params_, *h_kernel_params_); + // Reset max active status. If necessary, ApplyMaxActiveAndReduceBeam + // will switch it back on + compute_max_active_ = false; + + // Processing emitting arcs. We've done the preprocess stage at the end + // of the previous frame + ExpandArcsEmitting(); + // We'll loop until we have a small enough number of non-emitting arcs + // in the token queue. We'll then break the loop + for (int i = 0; i < KALDI_CUDA_DECODER_N_NON_EMITTING_MAIN_ITERATIONS; ++i) { + // If one of the aux_q contains more than max_active_ tokens, + // we'll reduce the beam to only keep max_active_ tokens ApplyMaxActiveAndReduceBeam(AUX_Q); + // Prune the aux_q. Apply the latest beam (using the one from + // ApplyMaxActiveAndReduceBeam if triggered) + // move the survival tokens to the main queue + // and do the preprocessing necessary for the next ExpandArcs PruneAndPreprocess(); - // Finalizing process non emitting. Takes care of the long tail, - // the final iterations with a small numbers of arcs. Do the work inside a - // single CTA (per lane), - FinalizeProcessNonEmittingKernel(KaldiCudaDecoderNumBlocks(1, nlanes_used_), - KALDI_CUDA_DECODER_LARGEST_1D_BLOCK, - compute_st_, *h_device_params_, - *h_kernel_params_); - - // We now have our final token main queues for that frame - - // Post processing the tokens for that frame - // - do the preprocess necessary for the next emitting expand (will happen - // with next frame) - // - if a state S has more than one token associated to it, generate the - // list of those tokens - // It allows to backtrack efficiently in GetRawLattice - // - compute the extra costs - PostProcessingMainQueue(); - - // Waiting on previous d2h before writing on same device memory - cudaStreamWaitEvent(compute_st_, d2h_copy_extra_prev_tokens_evt_, 0); - // Concatenating the data that will be moved to host into large arrays - ConcatenateData(); - // Copying the final lane counters for that frame - CopyLaneCountersToHostSync(); - CheckOverflow(); - - // Moving the data necessary for GetRawLattice/GetBestPath back to host for - // storage - CopyMainQueueDataToHost(); - - for (LaneId ilane = 0; ilane < nlanes_used_; ++ilane) { - const ChannelId ichannel = channel_to_compute_[ilane]; - // We're done processing that frame - ++num_frames_decoded_[ichannel]; - const int32 main_q_end = - h_lanes_counters_.lane(ilane)->main_q_narcs_and_end.y; - // Saving frame offsets for GetRawLattice - frame_offsets_[ichannel].push_back(frame_offsets_[ichannel].back() + - main_q_end); - } + + // "heavy duty" kernel for non-emitting. 
The long tail of small + // non-emitting iterations will be done in + // FinalizeProcessNonEmittingKernel + ExpandArcsNonEmitting(); } + ApplyMaxActiveAndReduceBeam(AUX_Q); + PruneAndPreprocess(); + // Finalizing process non emitting. Takes care of the long tail, + // the final iterations with a small numbers of arcs. Do the work inside + // a single CTA (per lane), + FinalizeProcessNonEmittingKernel(KaldiCudaDecoderNumBlocks(1, nlanes_used_), + KALDI_CUDA_DECODER_LARGEST_1D_BLOCK, + compute_st_, *h_device_params_, + *h_kernel_params_); + + // We now have our final token main queues for that frame + + // Post processing the tokens for that frame + // - do the preprocess necessary for the next emitting expand (will + // happen with next frame) + // - if a state S has more than one token associated to it, generate the + // list of those tokens + // It allows to backtrack efficiently in GetRawLattice + // - compute the extra costs + PostProcessingMainQueue(); + + // Waiting on previous d2h before writing on same device memory + cudaStreamWaitEvent(compute_st_, d2h_copy_extra_prev_tokens_evt_, 0); + // Concatenating the data that will be moved to host into large arrays + ConcatenateData(); + // Copying the final lane counters for that frame + CopyLaneCountersToHostSync(); + CheckOverflow(); + + // Moving the data necessary for GetRawLattice/GetBestPath back to host + // for storage + CopyMainQueueDataToHost(); + for (LaneId ilane = 0; ilane < nlanes_used_; ++ilane) { + const ChannelId ichannel = channel_to_compute_[ilane]; + // We're done processing that frame + ++num_frames_decoded_[ichannel]; + const int32 main_q_end = + h_lanes_counters_.lane(ilane)->main_q_narcs_and_end.y; + // Saving frame offsets for GetRawLattice + frame_offsets_[ichannel].push_back(frame_offsets_[ichannel].back() + + main_q_end); + } SaveChannelsStateFromLanes(); } @@ -865,26 +865,32 @@ void CudaDecoder::CheckOverflow() { bool q_overflow = lane_counters->q_overflow; if (q_overflow != OVERFLOW_NONE) { // An overflow was prevented in a kernel - // The algorithm can still go on but quality of the result can be reduced - // (less tokens were generated) + // The algorithm can still go on but quality of the + // result can be reduced (less tokens were generated) if ((q_overflow & OVERFLOW_MAIN_Q) == OVERFLOW_MAIN_Q) { // overflowed main_q - KALDI_WARN - << "Preventing overflow of main_q. Continuing " - << "execution but the quality of the output may be decreased. " - << "To prevent this from happening, please increase the parameter " - "--main-q-capacity" - << " and/or decrease --max-active"; + KALDI_WARN << "Preventing overflow of main_q. " + "Continuing " + << "execution but the quality of " + "the output may be decreased. " + << "To prevent this from happening, " + "please increase the " + "parameter " + "--main-q-capacity" + << " and/or decrease --max-active"; } if ((q_overflow & OVERFLOW_AUX_Q) == OVERFLOW_AUX_Q) { // overflowed aux_q - KALDI_WARN - << "Preventing overflow of aux_q. Continuing " - << "execution but the quality of the output may be decreased. " - << "To prevent this from happening, please increase the parameter " - "--aux-q-capacity" - << " and/or decrease --beam"; + KALDI_WARN << "Preventing overflow of aux_q. " + "Continuing " + << "execution but the quality of " + "the output may be decreased. 
" + << "To prevent this from happening, " + "please increase the " + "parameter " + "--aux-q-capacity" + << " and/or decrease --beam"; } KALDI_ASSERT(lane_counters->main_q_narcs_and_end.y < main_q_capacity_); @@ -922,27 +928,26 @@ void CudaDecoder::GetBestCost(const std::vector &channels, }; int32 max_main_q_end = GetMaxForAllLanes(func_main_q_end); - // Step1 : Finding the best cost in the last token queue, with and without - // final costs. - // Also saving the indexes of those min. + // Step1 : Finding the best cost in the last token queue, with and + // without final costs. Also saving the indexes of those min. GetBestCostStep1Kernel( KaldiCudaDecoderNumBlocks(max_main_q_end, nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, compute_st_, *h_device_params_, *h_kernel_params_, use_final_costs, StdWeight::Zero().Value()); - // Step2: Now that we now what the minimum cost is, we list all tokens within + // Step2: Now that we now what the minimum cost is, we list all tokens + // within // [min_cost; min_cost+lattice_beam] - // min_cost takes into account the final costs if use_final_costs is true, - // AND if a final state is is present in the last token queue + // min_cost takes into account the final costs if use_final_costs is + // true, AND if a final state is is present in the last token queue GetBestCostStep2Kernel( KaldiCudaDecoderNumBlocks(max_main_q_end, nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, compute_st_, *h_device_params_, *h_kernel_params_, use_final_costs, StdWeight::Zero().Value()); - // Step3 : Moves some data to host. We are moving the data that couldn't be - // moved - // directly in step 2, e.g. results of atomics (we don't know which one is - // last) + // Step3 : Moves some data to host. We are moving the data that couldn't + // be moved directly in step 2, e.g. 
results of atomics (we don't know + // which one is last) GetBestCostStep3Kernel( KaldiCudaDecoderNumBlocks(max_main_q_end, nlanes_used_), KALDI_CUDA_DECODER_1D_BLOCK, compute_st_, *h_device_params_, @@ -966,8 +971,8 @@ void CudaDecoder::GetBestCost(const std::vector &channels, int32 arg = minarg.y; // Saving both in output argmins->push_back({arg, min_cost}); - // Whether or not the last token queue contains at least one token - // associated with a final FST state + // Whether or not the last token queue contains at least one + // token associated with a final FST state has_reached_final->push_back( h_lanes_counters_.lane(ilane)->has_reached_final); // Number of tokens within [min_cost; min_cost+lattice_beam] @@ -999,8 +1004,8 @@ void CudaDecoder::GetBestPath(const std::vector &channels, const ChannelId ichannel = channels[ilane]; const int32 token_with_best_cost = argmins_[ilane].first; std::unique_lock channel_lk(channel_lock_[ichannel]); - // If that token in that frame f is available, then all tokens in that frame - // f are available + // If that token in that frame f is available, then all tokens + // in that frame f are available WaitForH2HCopies(); const bool isfinal = has_reached_final_[ilane]; TokenId token_idx = token_with_best_cost; @@ -1019,7 +1024,8 @@ void CudaDecoder::GetBestPath(const std::vector &channels, int32 arc_idx; TokenId prev_token_idx; if (token.IsUniqueTokenForStateAndFrame()) { - // If we have only one, it is an arc with extra_cost == 0 + // If we have only one, it is an arc with + // extra_cost == 0 arc_idx = token.arc_idx; prev_token_idx = token.prev_token; } else { @@ -1031,8 +1037,10 @@ void CudaDecoder::GetBestPath(const std::vector &channels, CostType arc_extra_cost = h_all_tokens_extra_prev_tokens_extra_and_acoustic_cost_[ichannel] [offset + - i].x; - // Picking one arc on the best path (extra_cost == 0) + i] + .x; + // Picking one arc on the best path + // (extra_cost == 0) if (arc_extra_cost == 0.0f) { InfoToken list_token = h_all_tokens_extra_prev_tokens_[ichannel][offset + i]; @@ -1146,10 +1154,11 @@ void CudaDecoder::AddFinalTokensToLattice( // Total number of tokens for that utterance. Used in // GetLatticeStateInternalId const int32 total_ntokens = h_all_tokens_info_[ichannel].size(); - // Reading the overall best_cost for that utterance's last frame. Was set by - // GetBestCost + // Reading the overall best_cost for that utterance's last frame. 
Was + // set by GetBestCost const CostType best_cost = h_all_argmin_cost_[ichannel].second; - // Iterating through tokens associated with a final state in the last frame + // Iterating through tokens associated with a final state in the last + // frame for (auto &p : h_all_final_tokens_list_[ichannel]) { // This final token has a final cost of final_token_cost CostType final_token_cost = p.second; @@ -1170,47 +1179,48 @@ void CudaDecoder::AddFinalTokensToLattice( decltype(curr_f_raw_lattice_state->end()) map_it; bool inserted; - // We need to create the fst_lattice_state linked to our internal id in the - // lattice if it doesn't already exists + // We need to create the fst_lattice_state linked to our + // internal id in the lattice if it doesn't already exists // Inserts only if the key doesn't exist in the map std::tie(map_it, inserted) = curr_f_raw_lattice_state->insert( {state_internal_id, {FLT_MAX, -1, false}}); - // If we've inserted the element, it means that that state didn't exist in - // the map - // Because this is a final state, we need to do a bit of extra work to add - // the final_cost to it + // If we've inserted the element, it means that that state + // didn't exist in the map Because this is a final state, we + // need to do a bit of extra work to add the final_cost to it if (inserted) { - // We want to figure out which FST state this token is associated to - // We don't have that info anymore, it wasn't transfered from the GPU - // We still need it for final tokens, because we need to know which - // final cost to add in the lattice. - // To find that original FST state, we need the id of an arc going to - // that state, - // then we'll look in the graph and figure out next_state[arc_idx] - // we just need a valid arc_idx + // We want to figure out which FST state this token is + // associated to We don't have that info anymore, it + // wasn't transfered from the GPU We still need it for + // final tokens, because we need to know which final + // cost to add in the lattice. To find that original FST + // state, we need the id of an arc going to that state, + // then we'll look in the graph and figure out + // next_state[arc_idx] we just need a valid arc_idx int32 arc_idx; if (final_token.IsUniqueTokenForStateAndFrame()) { // If unique, we can directly use this arc_idx arc_idx = final_token.arc_idx; } else { - // If we have multiple tokens associated to that fst state, just pick - // the first one - // from the list + // If we have multiple tokens associated to that + // fst state, just pick the first one from the + // list int32 offset, size; std::tie(offset, size) = final_token.GetSameFSTStateTokensList(); InfoToken prev_token = h_all_tokens_extra_prev_tokens_[ichannel][offset]; arc_idx = prev_token.arc_idx; } - // Creating the state associated with our internal id in the lattice + // Creating the state associated with our internal id in + // the lattice OutputLatticeState fst_lattice_final_state = fst_out->AddState(); map_it->second.fst_lattice_state = fst_lattice_final_state; q_curr_frame_todo->push_back({final_token_idx, final_token}); if (h_all_has_reached_final_[ichannel]) { - // If we have reached final states, adding the final cost - // We now have a valid arc_idx. We can read the FST state + // If we have reached final states, adding the + // final cost We now have a valid arc_idx. 
We + // can read the FST state StateId fst_next_state = fst_.h_arc_nextstate_[arc_idx]; fst_out->SetFinal(fst_lattice_final_state, @@ -1243,14 +1253,14 @@ void CudaDecoder::AddArcToLattice( // We will now add this arc to the output lattice // We know the destination state of the arc (to_fst_lattice_state) // We need to figure out its source - // And propagate the extra cost from the destination to the source of that arc - // (we go backward) + // And propagate the extra cost from the destination to the source of + // that arc (we go backward) OutputLatticeState from_fst_lattice_state; // Having the predecessor in the previous frame // <=> that token is associated to an emiting arc bool emitting = (list_prev_token_idx < curr_frame_offset); - // Checking if the source of that arc is the start state (original state at - // the beginning of the decode) + // Checking if the source of that arc is the start state (original state + // at the beginning of the decode) if (list_prev_token_idx != 0) { // Selecting the right map // - emitting arc -> previous frame map @@ -1281,13 +1291,12 @@ void CudaDecoder::AddArcToLattice( // We found a new min CostType diff = (prev_token_extra_cost - this_arc_prev_token_extra_cost); // If the change is large enough, - // and if the state that we're writing to was already closed, - // then we need to replay that frame. - // if the source state is already closed it means we've - // read its extra_cost value. Now we're writing again to it. - // We have to do the first read again, to get the updated - // value - // that's why we're replaying that frame + // and if the state that we're writing to was already + // closed, then we need to replay that frame. if the + // source state is already closed it means we've read + // its extra_cost value. Now we're writing again to it. + // We have to do the first read again, to get the + // updated value that's why we're replaying that frame // (between frames everything is in topological order) if (diff > extra_cost_min_delta_ && from_map_it->second.is_state_closed) { *must_replay_frame = true; @@ -1296,7 +1305,8 @@ void CudaDecoder::AddArcToLattice( from_map_it->second.token_extra_cost = prev_token_extra_cost; } - // Reading the OutputLatticeState of the source state in the output lattice + // Reading the OutputLatticeState of the source state in the + // output lattice from_fst_lattice_state = from_map_it->second.fst_lattice_state; } else { from_fst_lattice_state = @@ -1438,7 +1448,8 @@ void CudaDecoder::SwapPrevAndCurrLatticeMap( if (iframe > 0) { KALDI_ASSERT(!q_curr_frame_todo->empty()); if (!dbg_found_best_path) { - KALDI_WARN << "Warning didn't find exact best path in GetRawLattice"; + KALDI_WARN << "Warning didn't find exact best path in " + "GetRawLattice"; } } } @@ -1474,19 +1485,18 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, // Allocating the datastructures that we need // [prev|curr]_f_raw_lattice_state - // Used to get information about a lattice state (i.e. a (iframe, fst_state) - // pair) - // using its LatticeStateInternalId (its ID inside of the decoder) - // It gives us the OutputLatticeState (its ID in the output lattice) - // alongside with the extra_cost of that state in the lattice + // Used to get information about a lattice state (i.e. 
a (iframe, + // fst_state) pair) using its LatticeStateInternalId (its ID inside of + // the decoder) It gives us the OutputLatticeState (its ID in the output + // lattice) alongside with the extra_cost of that state in the lattice // Those maps are used to build the external lattice using what we know // internally - // Using one map per frame. We always know to which frame a token belongs. - // Using one big map slows everything down + // Using one map per frame. We always know to which frame a token + // belongs. Using one big map slows everything down std::unordered_map prev_f_raw_lattice_state, curr_f_raw_lattice_state; - // We want the unicity of each arc_idx for one frame. Important because we - // can replay a frame (and possibly add multiple time the same arc) + // We want the unicity of each arc_idx for one frame. Important because + // we can replay a frame (and possibly add multiple time the same arc) std::unordered_set f_arc_idx_added; // When backtracking, we read tokens in the current frame (in // q_curr_frame_todo_), @@ -1504,7 +1514,7 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, h_all_tokens_acoustic_cost_[ichannel].shrink_to_fit(); h_all_tokens_extra_prev_tokens_[ichannel].shrink_to_fit(); h_all_tokens_extra_prev_tokens_extra_and_acoustic_cost_[ichannel] - .shrink_to_fit(); + .shrink_to_fit(); best_cost_idx = h_all_argmin_cost_[ichannel].first; } KALDI_ASSERT( @@ -1514,10 +1524,11 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, const int32 nframes = NumFramesDecoded(ichannel); // Making sure that this token is available for this channel. // We're going to read storage data from this channel. Locking it - // If that token in that frame f is available, then all tokens in that frame - // f are available + // If that token in that frame f is available, then all tokens in that + // frame f are available WaitForH2HCopies(); std::unique_lock channel_lk(channel_lock_[ichannel]); + // Total number of tokens generated by the utterance on channel ichannel const int32 total_ntokens = h_all_tokens_info_[ichannel].size(); @@ -1537,38 +1548,40 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, // For each frame we're going to process tokens that need to be inserted // into the output lattice // and add their predecessors to the todo list - // iframe == -1 contains the start state and the first non emitting tokens. - // It is not linked to a real frame + // iframe == -1 contains the start state and the first non emitting + // tokens. It is not linked to a real frame for (int32 iframe = nframes - 1; iframe >= -1; --iframe) { - // Tokens for the current frame were inserted after this offset in the - // token list + // Tokens for the current frame were inserted after this offset + // in the token list const int32 curr_frame_offset = (iframe >= 0) ? 
frame_offsets_[ichannel][iframe] : 0; // bool must_replay_frame - // In some cases we can update an extra_cost that has already been used - // For instance we process arcs in that order : - // 1) a -> b, which updates extra_cost[b] using extra_cost[a] - // 2) c -> a, which updates extra-cost[a] (using extra_cost[c]) - // because the arcs were not considered in topological order, we need to + // In some cases we can update an extra_cost that has already + // been used For instance we process arcs in that order : 1) a + // -> b, which updates extra_cost[b] using extra_cost[a] 2) c -> + // a, which updates extra-cost[a] (using extra_cost[c]) because + // the arcs were not considered in topological order, we need to // run // again the step 1, - // to get the correct extra_cost[b] (using the latest extra_cost[a]) - // However, we only re-run the step 1 if the value extra_cost[a] has - // changed more than extra_cost_min_delta_ + // to get the correct extra_cost[b] (using the latest + // extra_cost[a]) However, we only re-run the step 1 if the + // value extra_cost[a] has changed more than + // extra_cost_min_delta_ bool must_replay_frame; - // dbg_found_best_path is used in an useful assert, making sure the best - // path is still there for each frame - // if something went wrong in the kernels, it's not likely we respect that + // dbg_found_best_path is used in an useful assert, making sure + // the best path is still there for each frame if something went + // wrong in the kernels, it's not likely we respect that // property out of luck bool dbg_found_best_path = false; do { must_replay_frame = false; // Reading something to do. We are pushing stuff back in // q_curr_frame_todo while reading it, - // so it's important to always read q_curr_frame_todo_.size() directly - // not using a queue, because we may need to recompute the frame (if + // so it's important to always read + // q_curr_frame_todo_.size() directly not using a queue, + // because we may need to recompute the frame (if // must_replay_frame is true) for (int32 u = 0; u < q_curr_frame_todo.size(); ++u) { TokenId token_idx; @@ -1587,24 +1600,24 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, InfoToken *tok_beg; float2 *extra_extra_and_acoustic_cost_beg; int32 nsamestate; - // Getting the list of the tokens linked to the same FST state, in the - // same frame - // In the GPU decoder a token is linked to a single arc, but we can - // generate - // multiple token for a same fst_nextstate in the same frame. - // In the CPU decoder we would use the forward_links list to store - // everything in the same metatoken - // GetSameFSTStateTokenList returns the list of tokens linked to the - // same FST state than token - // (in the current frame) + // Getting the list of the tokens linked to the + // same FST state, in the same frame In the GPU + // decoder a token is linked to a single arc, + // but we can generate multiple token for a same + // fst_nextstate in the same frame. In the CPU + // decoder we would use the forward_links list + // to store everything in the same metatoken + // GetSameFSTStateTokenList returns the list of + // tokens linked to the same FST state than + // token (in the current frame) GetSameFSTStateTokenList(ichannel, token, &tok_beg, &extra_extra_and_acoustic_cost_beg, &nsamestate); - // dbg_found_zero used for debugging. For each FST state, we have a - // token with the - // best cost for that FST state - // that token has an extra_cost of 0.0f. 
This is a sanity check + // dbg_found_zero used for debugging. For each + // FST state, we have a token with the best cost + // for that FST state that token has an + // extra_cost of 0.0f. This is a sanity check bool dbg_found_zero = false; for (int32 iprev = 0; iprev < nsamestate; ++iprev) { InfoToken list_prev_token; @@ -1636,9 +1649,10 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, } if (must_replay_frame) { - // We need to replay the frame. Because all states will be read again, - // we can reopen them (and they will be closed again when being read - // from again) + // We need to replay the frame. Because all + // states will be read again, we can reopen them + // (and they will be closed again when being + // read from again) for (auto it = curr_f_raw_lattice_state.begin(); it != curr_f_raw_lattice_state.end(); ++it) { it->second.is_state_closed = false; @@ -1646,8 +1660,8 @@ void CudaDecoder::ConcurrentGetRawLatticeSingleChannel(const ChannelId ichannel, } } while (must_replay_frame); - // Done processing this frame. Swap the datastructures, move on to - // previous frame (we go --iframe) + // Done processing this frame. Swap the datastructures, move on + // to previous frame (we go --iframe) SwapPrevAndCurrLatticeMap(iframe, dbg_found_best_path, &q_curr_frame_todo, &q_prev_frame_todo, &curr_f_raw_lattice_state, &prev_f_raw_lattice_state, &f_arc_idx_added); @@ -1692,9 +1706,8 @@ int32 CudaDecoder::NumFramesDecoded(ChannelId ichannel) const { void CudaDecoder::CheckStaticAsserts() { // Checking if all constants look ok - // We need that because we need to be able to do the scan in one pass in the - // kernel - // update_beam_using_histogram_kernel + // We need that because we need to be able to do the scan in one pass in + // the kernel update_beam_using_histogram_kernel KALDI_ASSERT(KALDI_CUDA_DECODER_HISTO_NBINS < KALDI_CUDA_DECODER_1D_BLOCK); KALDI_ASSERT(KALDI_CUDA_DECODER_NONEM_LT_MAX_NARCS > 0); } @@ -1730,7 +1743,8 @@ void CudaDecoder::ComputeH2HCopiesCPUWorker() { } void CudaDecoder::ComputeH2HCopies() { - // Waiting for either something to do or the instruction to stop the threads + // Waiting for either something to do or the instruction to stop the + // threads { std::unique_lock n_h2h_lk(n_h2h_main_task_todo_mutex_); n_h2h_main_task_todo_cv_.wait(n_h2h_lk, [this] { @@ -1738,9 +1752,8 @@ void CudaDecoder::ComputeH2HCopies() { }); --n_h2h_main_task_todo_; } - // If we are done, stop the wait and return now. ComputeH2HCopiesCPUWorker - // will also return, - // stopping the thread + // If we are done, stop the wait and return now. + // ComputeH2HCopiesCPUWorker will also return, stopping the thread if (!h2h_threads_running_) return; // Waiting for the D2H copies. 
This is threadsafe // Step 1: acoustic costs @@ -1788,8 +1801,8 @@ void CudaDecoder::ComputeH2HCopies() { &h_all_tokens_extra_prev_tokens_extra_and_acoustic_cost_); } - // If we're the last cpu thread to complete the current tasks, notify the main - // thread + // If we're the last cpu thread to complete the current tasks, notify + // the main thread bool all_done; { std::lock_guard lk_not_done(n_h2h_task_not_done_mutex_); @@ -1800,7 +1813,7 @@ void CudaDecoder::ComputeH2HCopies() { } } -void CudaDecoder::SetThreadPoolAndStartCPUWorkers(ThreadPool *thread_pool, +void CudaDecoder::SetThreadPoolAndStartCPUWorkers(ThreadPoolLight *thread_pool, int32 nworkers) { KALDI_ASSERT(nworkers > 0); n_threads_used_ = nworkers; @@ -1810,7 +1823,7 @@ void CudaDecoder::SetThreadPoolAndStartCPUWorkers(ThreadPool *thread_pool, this); } -} // end namespace cuda_decoder +} // namespace cuda_decoder } // end namespace kaldi #endif // HAVE_CUDA == 1 diff --git a/src/cudadecoder/cuda-decoder.h b/src/cudadecoder/cuda-decoder.h index 83ef1f49d8d..95bc7cac130 100644 --- a/src/cudadecoder/cuda-decoder.h +++ b/src/cudadecoder/cuda-decoder.h @@ -21,8 +21,8 @@ #include "cudadecoder/cuda-decodable-itf.h" #include "cudadecoder/cuda-decoder-common.h" #include "cudadecoder/cuda-fst.h" +#include "cudadecoder/thread-pool-light.h" #include "nnet3/decodable-online-looped.h" -#include "thread-pool.h" #include #include @@ -41,7 +41,7 @@ struct CudaDecoderConfig { CudaDecoderConfig() : default_beam(15.0), lattice_beam(10.0), - ntokens_pre_allocated(2000000), + ntokens_pre_allocated(1000000), main_q_capacity(-1), aux_q_capacity(-1), max_active(10000) {} @@ -57,23 +57,32 @@ struct CudaDecoderConfig { opts->Register("max-active", &max_active, "At the end of each frame computation, we keep only its " "best max-active tokens. One token is the instantiation of " - "a single arc. Typical values are within the 5k-10k range."); + "a single arc. Typical values are within the 5k-10k " + "range."); opts->Register("ntokens-pre-allocated", &ntokens_pre_allocated, - "Advanced - Number of tokens pre-allocated in host buffers. " + "Advanced - Number of tokens pre-allocated in host " + "buffers. " "If this size is exceeded the buffer will reallocate, " "reducing performance."); std::ostringstream main_q_capacity_desc; main_q_capacity_desc - << "Advanced - Capacity of the main queue : Maximum number of " - "tokens that can be stored *after* pruning for each frame. " + << "Advanced - Capacity of the main queue : Maximum number " + "of " + "tokens that can be stored *after* pruning for each " + "frame. " "Lower -> less memory usage, Higher -> More accurate. " "Tokens stored in the main queue were already selected " - "through a max-active pre-selection. It means that for each " + "through a max-active pre-selection. It means that for " + "each " "emitting/non-emitting iteration, we can add at most " - "~max-active tokens to the main queue. Typically only the " - "emitting iteration creates a large number of tokens. Using " - "main-q-capacity=k*max-active with k=4..10 should be safe. " - "If main-q-capacity is too small, we will print a warning " + "~max-active tokens to the main queue. Typically only " + "the " + "emitting iteration creates a large number of tokens. " + "Using " + "main-q-capacity=k*max-active with k=4..10 should be " + "safe. " + "If main-q-capacity is too small, we will print a " + "warning " "but prevent the overflow. 
The computation can safely " "continue, but the quality of the output may decrease " "(-1 = set to " @@ -84,16 +93,26 @@ struct CudaDecoderConfig { std::ostringstream aux_q_capacity_desc; aux_q_capacity_desc << "Advanced - Capacity of the auxiliary queue : Maximum " - "number of raw tokens that can be stored *before* pruning " - "for each frame. Lower -> less memory usage, Higher -> More " - "accurate. During the tokens generation, if we detect that " - "we are getting close to saturating that capacity, we will " - "reduce the beam dynamically (adaptive beam) to keep only " - "the best tokens in the remaining space. If the aux queue " - "is still too small, we will print an overflow warning, but " - "prevent the overflow. The computation can safely continue, " - "but the quality of the output may decrease. We strongly " - "recommend keeping aux-q-capacity large (>400k), to avoid " + "number of raw tokens that can be stored *before* " + "pruning " + "for each frame. Lower -> less memory usage, Higher -> " + "More " + "accurate. During the tokens generation, if we detect " + "that " + "we are getting close to saturating that capacity, we " + "will " + "reduce the beam dynamically (adaptive beam) to keep " + "only " + "the best tokens in the remaining space. If the aux " + "queue " + "is still too small, we will print an overflow warning, " + "but " + "prevent the overflow. The computation can safely " + "continue, " + "but the quality of the output may decrease. We " + "strongly " + "recommend keeping aux-q-capacity large (>400k), to " + "avoid " "triggering the adaptive beam and/or the overflow " "(-1 = set to " << KALDI_CUDA_DECODER_AUX_Q_MAIN_Q_CAPACITIES_FACTOR @@ -133,122 +152,119 @@ class CudaDecoder { // we pick an available channel, call InitDecoding on that channel // (with that ChannelId in the channels vector in the arguments) // then call AdvanceDecoding whenever frames are ready for the decoder - // for that utterance (also passing the same ChannelId to AdvanceDecoding) + // for that utterance (also passing the same ChannelId to + // AdvanceDecoding) // // A decoder lane is where the computation actually happens // a decoder lane is channel, and perform the actual decoding // of that channel. // If we have 200 lanes, we can compute 200 utterances (channels) - // at the same time. We need many lanes in parallel to saturate the big GPUs + // at the same time. We need many lanes in parallel to saturate the big + // GPUs // // An analogy would be lane -> a CPU core, channel -> a software thread - // A channel saves the current state of the decoding for a given utterance. - // It can be kept idle until more frames are ready to be processed - // - // We will use as many lanes as necessary to saturate the GPU, but not more. - // A lane has an higher memory usage than a channel. If you just want to be - // able to - // keep more audio channels open at the same time (when I/O is the bottleneck - // for instance, - // typically in the context of online decoding), you should instead use more - // channels. - // - // A channel is typically way smaller in term of memory usage, and can be used - // to oversubsribe lanes in the context of online decoding - // For instance, we could choose nlanes=200 because it gives us good + // A channel saves the current state of the decoding for a given + // utterance. It can be kept idle until more frames are ready to be + // processed + // + // We will use as many lanes as necessary to saturate the GPU, but not + // more. 
A lane has an higher memory usage than a channel. If you just + // want to be able to keep more audio channels open at the same time + // (when I/O is the bottleneck for instance, typically in the context of + // online decoding), you should instead use more channels. + // + // A channel is typically way smaller in term of memory usage, and can + // be used to oversubsribe lanes in the context of online decoding For + // instance, we could choose nlanes=200 because it gives us good // performance - // on a given GPU. It gives us an end-to-end performance of 3000 XRTF. We are - // doing online, - // so we only get audio at realtime speed for a given utterance/channel. - // We then decide to receive audio from 2500 audio channels at the same time - // (each at realtime speed), - // and as soon as we have frames ready for nlanes=200 channels, we call + // on a given GPU. It gives us an end-to-end performance of 3000 XRTF. + // We are doing online, so we only get audio at realtime speed for a + // given utterance/channel. We then decide to receive audio from 2500 + // audio channels at the same time (each at realtime speed), and as soon + // as we have frames ready for nlanes=200 channels, we call // AdvanceDecoding on those channels // In that configuration, we have nlanes=200 (for performance), and // nchannels=2500 (to have enough audio // available at a given time). - // Using nlanes=2500 in that configuration would first not be possible (out of - // memory), but also not necessary. - // Increasing the number of lanes is only useful if it increases performance. - // If the GPU is saturated at nlanes=200, - // you should not increase that number + // Using nlanes=2500 in that configuration would first not be possible + // (out of memory), but also not necessary. Increasing the number of + // lanes is only useful if it increases performance. If the GPU is + // saturated at nlanes=200, you should not increase that number CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config, int32 nlanes, int32 nchannels); // Reads the config from config void ReadConfig(const CudaDecoderConfig &config); - // Special constructor for nlanes = nchannels. Here for the non-advanced user - // Here we can consider nchannels = batch size. If we want to decode 10 - // utterances at a time, - // we can use nchannels = 10 + // Special constructor for nlanes = nchannels. Here for the non-advanced + // user Here we can consider nchannels = batch size. If we want to + // decode 10 utterances at a time, we can use nchannels = 10 CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config, int32 nchannels) : CudaDecoder(fst, config, nchannels, nchannels) {} - ~CudaDecoder(); + virtual ~CudaDecoder(); // InitDecoding initializes the decoding, and should only be used if you // intend to call AdvanceDecoding() on the channels listed in channels void InitDecoding(const std::vector &channels); - // Computes the heavy H2H copies of InitDecoding. Usually launched on the - // threadpool + // Computes the heavy H2H copies of InitDecoding. 
Usually launched on + // the threadpool void InitDecodingH2HCopies(ChannelId ichannel); // AdvanceDecoding on a given batch // a batch is defined by the channels vector // We can compute N channels at the same time (in the same batch) // where N = number of lanes, as defined in the constructor - // AdvanceDecoding will compute as many frames as possible while running the - // full batch - // when at least one channel has no more frames ready to be computed, - // AdvanceDecoding returns - // The user then decides what to do, i.e.: + // AdvanceDecoding will compute as many frames as possible while running + // the full batch when at least one channel has no more frames ready to + // be computed, AdvanceDecoding returns The user then decides what to + // do, i.e.: // // 1) either remove the empty channel from the channels list // and call again AdvanceDecoding // 2) or swap the empty channel with another one that has frames ready // and call again AdvanceDecoding // - // Solution 2) should be preferred because we need to run full, big batches to - // saturate the GPU + // Solution 2) should be preferred because we need to run full, big + // batches to saturate the GPU // // If max_num_frames is >= 0 it will decode no more than // that many frames. + void AdvanceDecoding( + const std::vector> &lanes_assignements); + + // Version with deprecated API - will be removed at some point void AdvanceDecoding(const std::vector &channels, std::vector &decodables, int32 max_num_frames = -1); // Returns the number of frames already decoded in a given channel int32 NumFramesDecoded(ChannelId ichannel) const; - // GetBestPath gets the one-best decoding traceback. If "use_final_probs" is - // true - // AND we reached a final state, it limits itself to final states; - // otherwise it gets the most likely token not taking into account - // final-probs. + // GetBestPath gets the one-best decoding traceback. If + // "use_final_probs" is true AND we reached a final state, it limits + // itself to final states; otherwise it gets the most likely token not + // taking into account final-probs. void GetBestPath(const std::vector &channels, std::vector &fst_out_vec, bool use_final_probs = true); // It is possible to use a threadsafe version of GetRawLattice, which is // ConcurrentGetRawLatticeSingleChannel() // Which will do the heavy CPU work associated with GetRawLattice - // It is necessary to first call PrepareForGetRawLattice *on the main thread* - // on the channels. - // The main thread is the one we use to call all other functions, like - // InitDecoding or AdvanceDecoding - // We usually call it "cuda control thread", but it is a CPU thread - // For example: - // on main cpu thread : Call PrepareForGetRawLattice on channel 8,6,3 - // then: - // on some cpu thread : Call ConcurrentGetRawLatticeSingleChannel on channel 3 - // on some cpu thread : Call ConcurrentGetRawLatticeSingleChannel on channel 8 - // on some cpu thread : Call ConcurrentGetRawLatticeSingleChannel on channel 6 + // It is necessary to first call PrepareForGetRawLattice *on the main + // thread* on the channels. 
The main thread is the one we use to call + // all other functions, like InitDecoding or AdvanceDecoding We usually + // call it "cuda control thread", but it is a CPU thread For example: on + // main cpu thread : Call PrepareForGetRawLattice on channel 8,6,3 then: + // on some cpu thread : Call ConcurrentGetRawLatticeSingleChannel on + // channel 3 on some cpu thread : Call + // ConcurrentGetRawLatticeSingleChannel on channel 8 on some cpu thread + // : Call ConcurrentGetRawLatticeSingleChannel on channel 6 void PrepareForGetRawLattice(const std::vector &channels, bool use_final_probs); void ConcurrentGetRawLatticeSingleChannel(ChannelId ichannel, Lattice *fst_out); - // GetRawLattice gets the lattice decoding traceback (using the lattice-beam - // in the CudaConfig parameters). - // If "use_final_probs" is true - // AND we reached a final state, it limits itself to final states; + // GetRawLattice gets the lattice decoding traceback (using the + // lattice-beam in the CudaConfig parameters). If "use_final_probs" is + // true AND we reached a final state, it limits itself to final states; // otherwise it gets the most likely token not taking into account // final-probs. void GetRawLattice(const std::vector &channels, @@ -259,20 +275,21 @@ class CudaDecoder { // finding the minimum cost // We list all tokens that have a cost within [best; best+lattice_beam] // in list_lattice_tokens. - // We alsos set has_reached_final[ichannel] to true if token associated to a - // final state - // exists in the last token queue of that channel + // We alsos set has_reached_final[ichannel] to true if token associated + // to a final state exists in the last token queue of that channel void GetBestCost( const std::vector &channels, bool isfinal, std::vector> *argmins, std::vector>> *list_lattice_tokens, std::vector *has_reached_final); + // (optional) Giving the decoder access to the cpu thread pool - // We will use it to compute specific CPU work, such as InitDecodingH2HCopies - // For recurrent CPU work, such as ComputeH2HCopies, we will use dedicated CPU - // threads - // We will launch nworkers of those threads - void SetThreadPoolAndStartCPUWorkers(ThreadPool *thread_pool, int32 nworkers); + // We will use it to compute specific CPU work, such as + // InitDecodingH2HCopies For recurrent CPU work, such as + // ComputeH2HCopies, we will use dedicated CPU threads We will launch + // nworkers of those threads + void SetThreadPoolAndStartCPUWorkers(ThreadPoolLight *thread_pool, + int32 nworkers); private: // Data allocation. Called in constructor @@ -291,50 +308,41 @@ class CudaDecoder { void SetChannelsInKernelParams(const std::vector &channels); void ResetChannelsInKernelParams(); // Context-switch functions - // Used to perform the context-switch of load/saving the state of a channels - // into a lane. When a channel will be executed on a lane, we load that - // channel into that lane (same idea than when we load a software threads into - // the registers of a CPU) + // Used to perform the context-switch of load/saving the state of a + // channels into a lane. When a channel will be executed on a lane, we + // load that channel into that lane (same idea than when we load a + // software threads into the registers of a CPU) void LoadChannelsStateToLanes(const std::vector &channels); void SaveChannelsStateFromLanes(); - // We compute the decodes by batch. 
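
For illustration only (not part of this patch): the threadsafe lattice-retrieval pattern described above, with the per-channel CPU work spread over plain std::threads. It assumes a `decoder` that has already finished decoding channels 8, 6 and 3.

    #include <thread>
    #include <vector>
    #include "cudadecoder/cuda-decoder.h"

    using namespace kaldi;
    using namespace kaldi::cuda_decoder;

    void ExampleConcurrentLattices(CudaDecoder &decoder) {
      std::vector<ChannelId> channels = {8, 6, 3};
      // Main ("cuda control") thread only:
      decoder.PrepareForGetRawLattice(channels, /*use_final_probs=*/true);
      // Heavy CPU work, one channel per worker thread:
      std::vector<Lattice> lats(channels.size());
      std::vector<std::thread> workers;
      for (size_t i = 0; i < channels.size(); ++i)
        workers.emplace_back([&decoder, &channels, &lats, i] {
          decoder.ConcurrentGetRawLatticeSingleChannel(channels[i], &lats[i]);
        });
      for (std::thread &t : workers) t.join();
    }
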
Each decodable in the batch has a - // different number of frames ready - // We compute the min number of frames ready (so that the full batch is - // executing). If max_num_frames - // is > 0, we apply that ceiling to the NumFramesToDecode. - int32 NumFramesToDecode(const std::vector &channels, - std::vector &decodables, - int32 max_num_frames); // Expand the arcs, emitting stage. Must be called after // a preprocess_in_place, which happens in PostProcessingMainQueue. // ExpandArcsEmitting is called first when decoding a frame, - // using the preprocessing that happened at the end of the previous frame, - // in PostProcessingMainQueue + // using the preprocessing that happened at the end of the previous + // frame, in PostProcessingMainQueue void ExpandArcsEmitting(); - // ExpandArcs, non-emitting stage. Must be called after PruneAndPreprocess. + // ExpandArcs, non-emitting stage. Must be called after + // PruneAndPreprocess. void ExpandArcsNonEmitting(); // If we have more than max_active_ tokens in the queue (either after an // expand, or at the end of the frame) - // we will compute a new beam that will only keep a number of tokens as close - // as possible to max_active_ tokens - // (that number is >= max_active_) (soft topk) - // All ApplyMaxActiveAndReduceBeam is find the right beam for that topk and - // set it. - // We need to then call PruneAndPreprocess (explicitly pruning tokens with - // cost > beam) - // Or PostProcessingMainQueue (ignoring tokens with cost > beam in the next + // we will compute a new beam that will only keep a number of tokens as + // close as possible to max_active_ tokens (that number is >= + // max_active_) (soft topk) All ApplyMaxActiveAndReduceBeam is find the + // right beam for that topk and set it. We need to then call + // PruneAndPreprocess (explicitly pruning tokens with cost > beam) Or + // PostProcessingMainQueue (ignoring tokens with cost > beam in the next // frame) void ApplyMaxActiveAndReduceBeam(enum QUEUE_ID queue_id); - // Called after an ExpandArcs. Prune the aux_q (output of the ExpandArcs), - // move the survival tokens to the main_q, do the preprocessing at the same - // time - // We don't need it after the last ExpandArcsNonEmitting. + // Called after an ExpandArcs. Prune the aux_q (output of the + // ExpandArcs), move the survival tokens to the main_q, do the + // preprocessing at the same time We don't need it after the last + // ExpandArcsNonEmitting. void PruneAndPreprocess(); // Once the non-emitting is done, the main_q is final for that frame. - // We now generate all the data associated with that main_q, such as listing - // the different tokens sharing the same token.next_state - // we also preprocess for the ExpandArcsEmitting of the next frame - // Once PostProcessingMainQueue, all working data is back to its original + // We now generate all the data associated with that main_q, such as + // listing the different tokens sharing the same token.next_state we + // also preprocess for the ExpandArcsEmitting of the next frame Once + // PostProcessingMainQueue, all working data is back to its original // state, to make sure we're ready for the next context switch void PostProcessingMainQueue(); // Moving the relevant data to host, ie the data that will be needed in @@ -344,62 +352,56 @@ class CudaDecoder { // CheckOverflow // If a kernel sets the flag h_q_overflow, we send a warning to stderr // Overflows are detected and prevented on the device. 
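
The "soft top-k" that ApplyMaxActiveAndReduceBeam is described as performing can be illustrated without any GPU code: bucket the token costs into a histogram over [best, best+beam), then pick the smallest cost window that still keeps at least max_active tokens. A simplified, std-only CPU sketch (not the actual kernel):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> costs = {0.1f, 0.3f, 0.35f, 0.9f, 1.2f, 2.5f, 3.0f, 4.2f};
      const float beam = 5.0f;
      const int max_active = 4, nbins = 10;
      const float best = *std::min_element(costs.begin(), costs.end());
      // Histogram of (cost - best) over [0, beam).
      std::vector<int> hist(nbins, 0);
      for (float c : costs)
        ++hist[std::min(nbins - 1, (int)((c - best) / beam * nbins))];
      // Smallest bin boundary keeping at least max_active tokens (soft top-k).
      int kept = 0, cutoff_bin = nbins - 1;
      for (int b = 0; b < nbins; ++b) {
        kept += hist[b];
        if (kept >= max_active) { cutoff_bin = b; break; }
      }
      const float new_beam = beam * (cutoff_bin + 1) / nbins;
      // Tokens with cost > best + new_beam are then pruned or ignored.
      std::printf("keeping %d token(s) with reduced beam %.2f (was %.2f)\n",
                  kept, new_beam, beam);
      return 0;
    }
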
It only means - // that we've discarded the tokens that were created after the queue was full - // That's why we only send a warning. It is not a fatal error + // that we've discarded the tokens that were created after the queue was + // full That's why we only send a warning. It is not a fatal error void CheckOverflow(); - // Evaluates the function func for each lane, returning the max of all return - // values - // (func returns int32) - // Used for instance to ge the max number of arcs for all lanes - // func is called with h_lanes_counters_[ilane] for each lane. - // h_lanes_counters_ - // must be ready to be used when calling GetMaxForAllLanes (you might want to + // Evaluates the function func for each lane, returning the max of all + // return values (func returns int32) Used for instance to ge the max + // number of arcs for all lanes func is called with + // h_lanes_counters_[ilane] for each lane. h_lanes_counters_ must be + // ready to be used when calling GetMaxForAllLanes (you might want to // call - // CopyLaneCountersToHost[A|]sync to make sure everything is ready first) + // CopyLaneCountersToHost[A|]sync to make sure everything is ready + // first) int32 GetMaxForAllLanes(std::function func); // Copy the lane counters back to host, async or sync - // The lanes counters contain all the information such as main_q_end (number - // of tokens in the main_q) - // main_q_narcs (number of arcs) during the computation. That's why we - // frequently copy it back to host - // to know what to do next + // The lanes counters contain all the information such as main_q_end + // (number of tokens in the main_q) main_q_narcs (number of arcs) during + // the computation. That's why we frequently copy it back to host to + // know what to do next void CopyLaneCountersToHostAsync(); void CopyLaneCountersToHostSync(); - // The selected tokens for each frame will be copied back to host. We will - // store them on host memory, and we wil use them to create the final lattice - // once we've reached the last frame - // We will also copy information on those tokens that we've generated on the - // device, such as which tokens are associated to the same FST state in the - // same frame, or their extra cost. - // We cannot call individuals Device2Host copies for each channel, because it - // would lead to a lot of small copies, reducing performance. Instead we - // concatenate all channels data into a single - // continuous array, copy that array to host, then unpack it to the individual - // channel vectors - // The first step (pack then copy to host, async) is done in - // ConcatenateData - // The second step is done in LaunchD2H and sLaunchH2HCopies - // A sync on cudaStream st has to happen between the two functions to make - // sure that the copy is done + // The selected tokens for each frame will be copied back to host. We + // will store them on host memory, and we wil use them to create the + // final lattice once we've reached the last frame We will also copy + // information on those tokens that we've generated on the device, such + // as which tokens are associated to the same FST state in the same + // frame, or their extra cost. We cannot call individuals Device2Host + // copies for each channel, because it would lead to a lot of small + // copies, reducing performance. 
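
The copy-batching scheme this comment describes (and continues just below) boils down to: compute per-lane offsets, pack everything into one contiguous buffer, issue a single large copy, then unpack per channel on the host. A std-only sketch of that pattern, with names chosen to echo ConcatenateData / MoveConcatenatedCopyToVector (the real code operates on device memory and async CUDA copies):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // Per-lane token data (stands in for what lives on the device).
      std::vector<std::vector<int>> lane_data = {{1, 2, 3}, {4}, {5, 6}};
      // Exclusive prefix sum of per-lane sizes = offsets into the packed buffer.
      std::vector<size_t> offsets(lane_data.size() + 1, 0);
      for (size_t i = 0; i < lane_data.size(); ++i)
        offsets[i + 1] = offsets[i] + lane_data[i].size();
      // "ConcatenateData": one contiguous buffer, so one big copy replaces many
      // small ones.
      std::vector<int> concat(offsets.back());
      for (size_t i = 0; i < lane_data.size(); ++i)
        std::copy(lane_data[i].begin(), lane_data[i].end(),
                  concat.begin() + offsets[i]);
      // "MoveConcatenatedCopyToVector": unpack per lane/channel on the host.
      for (size_t i = 0; i < lane_data.size(); ++i) {
        std::printf("lane %d:", (int)i);
        for (size_t j = offsets[i]; j < offsets[i + 1]; ++j)
          std::printf(" %d", concat[j]);
        std::printf("\n");
      }
      return 0;
    }
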
Instead we concatenate all channels + // data into a single continuous array, copy that array to host, then + // unpack it to the individual channel vectors The first step (pack then + // copy to host, async) is done in ConcatenateData The second step is + // done in LaunchD2H and sLaunchH2HCopies A sync on cudaStream st has to + // happen between the two functions to make sure that the copy is done // // Each lane contains X elements to be copied, where X = func(ilane) - // That data is contained in the array (pointer, X), with pointer = src[ilane] - // It will be concatenated in d_concat on device, then copied async into - // h_concat - // That copy is launched on stream st - // The offset of the data of each lane in the concatenate array is saved in + // That data is contained in the array (pointer, X), with pointer = + // src[ilane] It will be concatenated in d_concat on device, then copied + // async into h_concat That copy is launched on stream st The offset of + // the data of each lane in the concatenate array is saved in // *lanes_offsets_ptr // it will be used for unpacking in MoveConcatenatedCopyToVector // // func is called with h_lanes_counters_[ilane] for each lane. // h_lanes_counters_ - // must be ready to be used when calling GetMaxForAllLanes (you might want to - // call - // CopyLaneCountersToHost[A|]sync to make sure everything is ready first) - // Concatenate data on device before calling the D2H copies + // must be ready to be used when calling GetMaxForAllLanes (you might + // want to call CopyLaneCountersToHost[A|]sync to make sure everything + // is ready first) Concatenate data on device before calling the D2H + // copies void ConcatenateData(); - // Start the D2H copies used to send data back to host at the end of each - // frames + // Start the D2H copies used to send data back to host at the end of + // each frames void LaunchD2HCopies(); // ComputeH2HCopies // At the end of each frame, we copy data back to host @@ -408,8 +410,8 @@ class CudaDecoder { // This is done by ComputeH2HCopies void ComputeH2HCopies(); // Takes care of preparing the data for ComputeH2HCopies - // and check whether we can use the threadpool or we have to do the work on - // the current thread + // and check whether we can use the threadpool or we have to do the work + // on the current thread void LaunchH2HCopies(); // Function called by the CPU worker threads // Calls ComputeH2HCopies when triggered @@ -426,8 +428,8 @@ class CudaDecoder { // Computes a set of static asserts on the static values // In theory we should do them at compile time void CheckStaticAsserts(); - // Can be called in GetRawLattice to do a bunch of deep asserts on the data - // Slow, so disabled by default + // Can be called in GetRawLattice to do a bunch of deep asserts on the + // data Slow, so disabled by default void DebugValidateLattice(); // @@ -439,19 +441,19 @@ class CudaDecoder { const CudaFst fst_; // Counters used by a decoder lane // Contains all the single values generated during computation, - // such as the current size of the main_q, the number of arcs currently in - // that queue - // We load data from the channel state during context-switch (for instance the - // size of the last token queue for that channel) + // such as the current size of the main_q, the number of arcs currently + // in that queue We load data from the channel state during + // context-switch (for instance the size of the last token queue for + // that channel) HostLaneMatrix h_lanes_counters_; // Counters of channels - // 
Contains all the single values saved to remember the state of a channel - // not used during computation. Those values are loaded/saved into/from a lane - // during context switching + // Contains all the single values saved to remember the state of a + // channel not used during computation. Those values are loaded/saved + // into/from a lane during context switching ChannelCounters *h_channels_counters_; - // Contain the various counters used by lanes/channels, such as main_q_end, - // main_q_narcs. On device memory (equivalent of h_channels_counters on - // device) + // Contain the various counters used by lanes/channels, such as + // main_q_end, main_q_narcs. On device memory (equivalent of + // h_channels_counters on device) DeviceChannelMatrix d_channels_counters_; DeviceLaneMatrix d_lanes_counters_; // Number of lanes and channels, as defined in the constructor arguments @@ -463,20 +465,19 @@ class CudaDecoder { // - the auxiliary queue // // The auxiliary queue is used to store the raw output of ExpandArcs. - // We then prune that aux queue (and apply max-active) and move the survival - // tokens in the main queue. - // Tokens stored in the main q can then be used to generate new tokens (using - // ExpandArcs) - // We also generate more information about what's in the main_q at the end of - // a frame (in PostProcessingMainQueue) + // We then prune that aux queue (and apply max-active) and move the + // survival tokens in the main queue. Tokens stored in the main q can + // then be used to generate new tokens (using ExpandArcs) We also + // generate more information about what's in the main_q at the end of a + // frame (in PostProcessingMainQueue) // // As a reminder, here's the data structure of a token : // // struct Token { state, cost, prev_token, arc_idx } // // Please keep in mind that this structure is also used in the context - // of lattice decoding. We are not storing a list of forward links like in the - // CPU decoder. A token stays an instanciation of an single arc. + // of lattice decoding. We are not storing a list of forward links like + // in the CPU decoder. A token stays an instanciation of an single arc. // // For performance reasons, we split the tokens in three parts : // { state } , { cost }, { prev_token, arc_idx } @@ -484,55 +485,52 @@ class CudaDecoder { // For instance, d_main_q_state[i], d_main_q_cost[i], d_main_q_info[i] // all refer to the same token (at index i) // The data structure InfoToken contains { prev_token, arc_idx } - // We also store the acoustic costs independently in d_main_q_acoustic_cost_ + // We also store the acoustic costs independently in + // d_main_q_acoustic_cost_ // // The data is eiher linked to a channel, or to a lane. // // Channel data (DeviceChannelMatrix): // - // The data linked with a channel contains the data of frame i we need to - // remember - // to compute frame i+1. It is the list of tokens from frame i, with some - // additional info - // (ie the prefix sum of the emitting arcs degrees from those tokens). - // We are only storing d_main_q_state_and_cost_ as channel data because that's - // all we need in a token to compute - // frame i+1. We don't need token.arc_idx or token.prev_token. - // The reason why we also store that prefix sum is because we do the emitting - // preprocessing - // at the end of frame i. The reason for that is that we need infos from the - // hashmap to do that preprocessing. - // The hashmap is always cleared at the end of a frame. 
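
The "split the tokens in three parts" layout above is a plain structure-of-arrays: index i of each array refers to the same logical token, and each kernel only reads the arrays it needs. A minimal std-only sketch (field names follow the comment; the real containers are the DeviceLaneMatrix/DeviceChannelMatrix types used in this header):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct InfoToken { int32_t prev_token; int32_t arc_idx; };

    int main() {
      // Index i in each vector describes the same token (SoA layout).
      std::vector<int32_t> main_q_state;
      std::vector<float> main_q_cost;
      std::vector<InfoToken> main_q_info;
      // One token: FST state 42, cost 1.5, reached from token 7 through arc 13.
      main_q_state.push_back(42);
      main_q_cost.push_back(1.5f);
      main_q_info.push_back({7, 13});
      std::printf("token 0: state=%d cost=%.2f prev=%d arc=%d\n",
                  main_q_state[0], main_q_cost[0], main_q_info[0].prev_token,
                  main_q_info[0].arc_idx);
      return 0;
    }
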
So we need to do the - // preprocessing at the end of frame i, - // and then save d_main_q_degrees_prefix_sum_. d_main_q_arc_offsets is - // generated also during preprocessing. + // The data linked with a channel contains the data of frame i we need + // to remember to compute frame i+1. It is the list of tokens from frame + // i, with some additional info (ie the prefix sum of the emitting arcs + // degrees from those tokens). We are only storing + // d_main_q_state_and_cost_ as channel data because that's all we need + // in a token to compute frame i+1. We don't need token.arc_idx or + // token.prev_token. The reason why we also store that prefix sum is + // because we do the emitting preprocessing at the end of frame i. The + // reason for that is that we need infos from the hashmap to do that + // preprocessing. The hashmap is always cleared at the end of a frame. + // So we need to do the preprocessing at the end of frame i, and then + // save d_main_q_degrees_prefix_sum_. d_main_q_arc_offsets is generated + // also during preprocessing. // // Lane data (DeviceLaneMatrix): // - // The lane data is everything we use during computation, but which we reset - // at the end of each frame. - // For instance we use a hashmap at some point during the computation, but at - // the end of each frame we reset it. That - // way that hashmap is able to compute whichever channel the next time - // AdvanceDecoding is called. The reasons why we do that is : - // - // - We use context switching. Before and after every frames, we can do a - // context switching. Which means that a lane cannot save a channel's state - // in any way once AdvanceDecoding returns. e.g., during a call of - // AdvanceDecoding, ilane=2 may compute 5 frames from channel=57 (as defined - // in the std::vector channels). - // In the next call, the same ilane=2 may compute 10 frames from channel=231. - // A lane data has to be reset to its original state at the end of each + // The lane data is everything we use during computation, but which we + // reset at the end of each frame. For instance we use a hashmap at some + // point during the computation, but at the end of each frame we reset + // it. That way that hashmap is able to compute whichever channel the + // next time AdvanceDecoding is called. The reasons why we do that is : + // + // - We use context switching. Before and after every frames, we can do + // a context switching. Which means that a lane cannot save a channel's + // state in any way once AdvanceDecoding returns. e.g., during a call of + // AdvanceDecoding, ilane=2 may compute 5 frames from channel=57 (as + // defined in the std::vector channels). In the next call, + // the same ilane=2 may compute 10 frames from channel=231. A lane data + // has to be reset to its original state at the end of each // AdvanceDecoding call. - // If somehow some data has to be saved, it needs to be declared as channel - // data. - // - // - The reason why we make the distinction between lane and channel data (in - // theory everything could be consider channel data), is because - // a lane uses more memory than a channel. In the context of online decoding, - // we need to create a lot channels, and we need them to be as small as - // possible in memory. - // Everything that can be reused between channels is stored as lane data. + // If somehow some data has to be saved, it needs to be declared as + // channel data. 
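
A toy, std-only sketch of the context-switch idea described above: channel state is small and persistent, lane state is large and scratch, and a channel is loaded into a lane before a batch and saved back afterwards (mirroring LoadChannelsStateToLanes / SaveChannelsStateFromLanes in spirit; the real state is of course much richer):

    #include <cstdio>
    #include <vector>

    struct ChannelState { int num_frames_decoded = 0; };  // small, one per stream
    struct LaneState {
      ChannelState loaded;             // the channel currently running here
      std::vector<float> big_scratch;  // large working buffers, reset per frame
    };

    int main() {
      std::vector<ChannelState> channels(2500);  // cheap to keep open
      std::vector<LaneState> lanes(200);         // expensive, sized for the GPU
      const int channel = 57, lane = 2;
      lanes[lane].loaded = channels[channel];      // "LoadChannelsStateToLanes"
      lanes[lane].loaded.num_frames_decoded += 5;  // AdvanceDecoding-style work
      channels[channel] = lanes[lane].loaded;      // "SaveChannelsStateFromLanes"
      std::printf("channel %d is now at %d decoded frame(s)\n", channel,
                  channels[channel].num_frames_decoded);
      return 0;
    }
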
+ // + // - The reason why we make the distinction between lane and channel + // data (in theory everything could be consider channel data), is + // because a lane uses more memory than a channel. In the context of + // online decoding, we need to create a lot channels, and we need them + // to be as small as possible in memory. Everything that can be reused + // between channels is stored as lane data. // // Channel data members: @@ -544,10 +542,9 @@ class CudaDecoder { // preprocess_in_place in PostProcessingMainQueue) DeviceChannelMatrix d_main_q_degrees_prefix_sum_; // d_main_q_arc_offsets[i] = fst_.arc_offsets[d_main_q_state[i]] - // we pay the price for the random memory accesses of fst_.arc_offsets in the - // preprocess kernel - // we cache the results in d_main_q_arc_offsets which will be read in a - // coalesced fashion in expand + // we pay the price for the random memory accesses of fst_.arc_offsets + // in the preprocess kernel we cache the results in d_main_q_arc_offsets + // which will be read in a coalesced fashion in expand DeviceChannelMatrix d_main_q_arc_offsets_; // @@ -566,43 +563,43 @@ class CudaDecoder { DeviceLaneMatrix d_main_q_acoustic_cost_; // At the end of a frame, we use a hashmap to detect the tokens that are // associated with the same FST state S - // We do it that the very end, to only use the hashmap on post-prune, post-max - // active tokens + // We do it that the very end, to only use the hashmap on post-prune, + // post-max active tokens DeviceLaneMatrix d_hashmap_values_; // Reminder: in the GPU lattice decoder, a token is always associated // to a single arc. Which means that multiple tokens in the same frame // can be associated with the same FST state. // - // We are NOT listing those duplicates as ForwardLinks in an unique meta-token - // like in the CPU lattice decoder + // We are NOT listing those duplicates as ForwardLinks in an unique + // meta-token like in the CPU lattice decoder // // When more than one token is associated to a single FST state, - // we will list those tokens into another list : d_main_q_extra_prev_tokens - // we will also save data useful in such a case, such as the extra_cost of a - // token compared to the best for that state + // we will list those tokens into another list : + // d_main_q_extra_prev_tokens we will also save data useful in such a + // case, such as the extra_cost of a token compared to the best for that + // state DeviceLaneMatrix d_main_q_extra_prev_tokens_; DeviceLaneMatrix d_main_q_extra_and_acoustic_cost_; // Histogram. Used to perform the histogram of the token costs // in the main_q. Used to perform a soft topk of the main_q (max-active) DeviceLaneMatrix d_histograms_; - // When filling the hashmap in PostProcessingMainQueue, we create a hashmap - // value for each FST state - // presents in the main_q (if at least one token is associated with that - // state) - // d_main_q_state_hash_idx_[token_idx] is the index of the state token.state - // in the hashmap - // Stored into a FSTStateHashIndex, which is actually a int32. - // FSTStateHashIndex should only - // be accessed through [Get|Set]FSTStateHashIndex, because it uses the bit - // sign to also remember if that token is the representative of that state. 
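
The sign-bit encoding mentioned here (FSTStateHashIndex, accessed through the [Get|Set]FSTStateHashIndex helpers) can be pictured as ones'-complement packing: a negative value marks the token as the representative of its state, and flipping the bits back recovers the hashmap index. The actual helpers may be implemented differently; this is only a sketch of the idea:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Pack "is_representative" into the sign bit of a non-negative hashmap index.
    int32_t SetPacked(int32_t hash_idx, bool is_representative) {
      return is_representative ? ~hash_idx : hash_idx;  // ~x < 0 for x >= 0
    }
    void GetPacked(int32_t packed, int32_t *hash_idx, bool *is_representative) {
      *is_representative = (packed < 0);
      *hash_idx = *is_representative ? ~packed : packed;
    }

    int main() {
      const int32_t packed = SetPacked(12345, true);
      int32_t idx;
      bool rep;
      GetPacked(packed, &idx, &rep);
      assert(idx == 12345 && rep);
      std::printf("hash_idx=%d is_representative=%d\n", idx, (int)rep);
      return 0;
    }
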
- // If only one token is associated with S, its representative will be itself + // When filling the hashmap in PostProcessingMainQueue, we create a + // hashmap value for each FST state presents in the main_q (if at least + // one token is associated with that state) + // d_main_q_state_hash_idx_[token_idx] is the index of the state + // token.state in the hashmap Stored into a FSTStateHashIndex, which is + // actually a int32. FSTStateHashIndex should only be accessed through + // [Get|Set]FSTStateHashIndex, because it uses the bit sign to also + // remember if that token is the representative of that state. If only + // one token is associated with S, its representative will be itself DeviceLaneMatrix d_main_q_state_hash_idx_; // local_idx of the extra cost list for a state - // For a given state S, first token associated with S will have local_idx=0 - // the second one local_idx=1, etc. The order of the local_idxs is random + // For a given state S, first token associated with S will have + // local_idx=0 the second one local_idx=1, etc. The order of the + // local_idxs is random DeviceLaneMatrix d_main_q_n_extra_prev_tokens_local_idx_; - // Where to write the extra_prev_tokens in the d_main_q_extra_prev_tokens_ - // queue + // Where to write the extra_prev_tokens in the + // d_main_q_extra_prev_tokens_ queue DeviceLaneMatrix d_main_q_extra_prev_tokens_prefix_sum_; // Used when computing the prefix_sums in preprocess_in_place. Stores // the local_sums per CTA @@ -616,13 +613,11 @@ class CudaDecoder { DeviceLaneMatrix d_extra_prev_tokens_concat_matrix_; DeviceLaneMatrix d_acoustic_cost_concat_matrix_; DeviceLaneMatrix d_infotoken_concat_matrix_; - // We will list in d_list_final_tokens_in_main_q all tokens within [min_cost; - // min_cost+lattice_beam] - // It is used when calling GetBestCost - // We only use an interface here because we will actually reuse data from - // d_aux_q_state_and_cost - // We are done using the aux_q when GetBestCost is called, so we can reuse - // that memory + // We will list in d_list_final_tokens_in_main_q all tokens within + // [min_cost; min_cost+lattice_beam] It is used when calling GetBestCost + // We only use an interface here because we will actually reuse data + // from d_aux_q_state_and_cost We are done using the aux_q when + // GetBestCost is called, so we can reuse that memory HostLaneMatrix h_list_final_tokens_in_main_q_; // Parameters used by the kernels // DeviceParams contains all the parameters that won't change @@ -652,8 +647,8 @@ class CudaDecoder { // Static segment of the adaptive beam. Cf InitDeviceParams int32 adaptive_beam_static_segment_; // The first index of all the following vectors (or vector) - // is the ChannelId. e.g., to get the number of frames decoded in channel 2, - // look into num_frames_decoded_[2]. + // is the ChannelId. e.g., to get the number of frames decoded in + // channel 2, look into num_frames_decoded_[2]. // Keep track of the number of frames decoded in the current file. std::vector num_frames_decoded_; @@ -673,8 +668,9 @@ class CudaDecoder { // channel) bool worker_threads_running_; // For each channel, set by PrepareForGetRawLattice - // argmin cost, list of the tokens within [best_cost;best_cost+lattice_beam] - // and if we've reached a final token. Set by PrepareForGetRawLattice. + // argmin cost, list of the tokens within + // [best_cost;best_cost+lattice_beam] and if we've reached a final + // token. Set by PrepareForGetRawLattice. 
std::vector> h_all_argmin_cost_; std::vector>> h_all_final_tokens_list_; std::vector h_all_has_reached_final_; @@ -714,12 +710,11 @@ class CudaDecoder { // // A lattice state is defined by the pair (iframe, fst_state) // A token is associated to a lattice state (iframe, token.next_state) - // Multiple token in the same frame can be associated to the same lattice - // state - // (they all go to the same token.next_state) - // We need to quickly identify what is the lattice state of a token. - // We are able to do that through GetLatticeStateInternalId(token), - // which returns the internal unique ID for each lattice state for a token + // Multiple token in the same frame can be associated to the same + // lattice state (they all go to the same token.next_state) We need to + // quickly identify what is the lattice state of a token. We are able to + // do that through GetLatticeStateInternalId(token), which returns the + // internal unique ID for each lattice state for a token // // When we build the output lattice, we a get new lattice state // output_lattice_state = fst_out->AddState() @@ -733,32 +728,28 @@ class CudaDecoder { TokenId token_idx, InfoToken token); // Keeping track of a variety of info about states in the lattice - // - token_extra_cost. A path going from the current lattice_state to the - // end has an extra cost - // compared to the best path (which has an extra cost of 0). - // token_extra_cost is the minimum of the extra_cost of all paths going from - // the current lattice_state - // to the final frame. - // - fst_lattice_state is the StateId of the lattice_state in fst_out (in - // the output lattice). lattice_state is an internal state used in + // - token_extra_cost. A path going from the current lattice_state to + // the end has an extra cost compared to the best path (which has an + // extra cost of 0). token_extra_cost is the minimum of the extra_cost + // of all paths going from the current lattice_state to the final frame. + // - fst_lattice_state is the StateId of the lattice_state in fst_out + // (in the output lattice). lattice_state is an internal state used in // GetRawLattice. // - is_state_closed is true if the token_extra_cost has been read by // another token. It means that the - // token_extra_cost value has been used, and if we modify token_extra_cost - // again, we may need to recompute the current frame (so that everyone uses - // the latest - // token_extra_cost value) + // token_extra_cost value has been used, and if we modify + // token_extra_cost again, we may need to recompute the current frame + // (so that everyone uses the latest token_extra_cost value) struct RawLatticeState { CostType token_extra_cost; OutputLatticeState fst_lattice_state; bool is_state_closed; }; - // extra_cost_min_delta_ used in the must_replay_frame situation. Please read - // comments - // associated with must_replay_frame in GetRawLattice to understand what it - // does + // extra_cost_min_delta_ used in the must_replay_frame situation. 
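
A lattice state is the pair (iframe, fst_state), and GetLatticeStateInternalId is described as mapping a token to a unique id for that pair. One simple way to picture such a mapping (not necessarily how the decoder implements it) is a hash map keyed on the packed pair:

    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>

    int main() {
      std::unordered_map<uint64_t, int32_t> internal_id;
      auto GetInternalId = [&internal_id](int32_t iframe, int32_t fst_state) {
        const uint64_t key =
            (uint64_t(uint32_t(iframe)) << 32) | uint32_t(fst_state);
        // New (iframe, fst_state) pairs get the next id; known pairs reuse theirs.
        return internal_id.emplace(key, (int32_t)internal_id.size()).first->second;
      };
      const int32_t a = GetInternalId(0, 42);
      const int32_t b = GetInternalId(1, 42);
      const int32_t c = GetInternalId(0, 42);  // same pair as `a`, same id
      std::printf("%d %d %d\n", a, b, c);      // prints "0 1 0"
      return 0;
    }
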
Please + // read comments associated with must_replay_frame in GetRawLattice to + // understand what it does CostType extra_cost_min_delta_; - ThreadPool *thread_pool_; + ThreadPoolLight *thread_pool_; std::vector cpu_dedicated_threads_; int32 n_threads_used_; std::vector lanes2channels_todo_; @@ -777,11 +768,10 @@ class CudaDecoder { std::condition_variable init_decoding_h2h_done_; std::atomic active_wait_; bool h2h_threads_running_; - // Using the output from GetBestPath, we add the best tokens (as selected in - // GetBestCost) - // from the final frame to the output lattice. We also fill the data - // structures - // (such as q_curr_frame_todo_, or curr_f_raw_lattice_state_) accordingly + // Using the output from GetBestPath, we add the best tokens (as + // selected in GetBestCost) from the final frame to the output lattice. + // We also fill the data structures (such as q_curr_frame_todo_, or + // curr_f_raw_lattice_state_) accordingly void AddFinalTokensToLattice( ChannelId ichannel, std::vector> *q_curr_frame_todo, @@ -798,8 +788,8 @@ class CudaDecoder { InfoToken *list_prev_token, CostType *this_arc_prev_token_extra_cost, CostType *acoustic_cost, OutputLatticeState *lattice_src_state, bool *keep_arc, bool *dbg_found_zero); - // Add the arc to the lattice. Also updates what needs to be updated in the - // GetRawLattice datastructures. + // Add the arc to the lattice. Also updates what needs to be updated in + // the GetRawLattice datastructures. void AddArcToLattice( int32 list_arc_idx, TokenId list_prev_token_idx, InfoToken list_prev_token, int32 curr_frame_offset, @@ -822,10 +812,9 @@ class CudaDecoder { *curr_f_raw_lattice_state, CostType *token_extra_cost, OutputLatticeState *to_fst_lattice_state); - // A token is an instance of an arc. It goes to a FST state (token.next_state) - // Multiple token in the same frame can go to the same FST state. - // GetSameFSTStateTokenList - // returns that list + // A token is an instance of an arc. It goes to a FST state + // (token.next_state) Multiple token in the same frame can go to the + // same FST state. GetSameFSTStateTokenList returns that list void GetSameFSTStateTokenList(ChannelId ichannel, InfoToken &token, InfoToken **tok_beg, float2 **arc_extra_cost_beg, int32 *nprevs); diff --git a/src/cudadecoder/decodable-cumatrix.cc b/src/cudadecoder/decodable-cumatrix.cc index d7c1d0359a5..a4362c83b9d 100644 --- a/src/cudadecoder/decodable-cumatrix.cc +++ b/src/cudadecoder/decodable-cumatrix.cc @@ -16,9 +16,13 @@ * limitations under the License. */ +// +// Important: This file is deprecated and will be removed in a future release +// + #if HAVE_CUDA == 1 -#include "decodable-cumatrix.h" +#include "cudadecoder/decodable-cumatrix.h" namespace kaldi { namespace cuda_decoder { @@ -48,8 +52,8 @@ int32 DecodableCuMatrixMapped::NumIndices() const { } // returns cuda pointer to nnet3 output -BaseFloat * -DecodableCuMatrixMapped::GetLogLikelihoodsCudaPointer(int32 subsampled_frame) { +BaseFloat *DecodableCuMatrixMapped::GetLogLikelihoodsCudaPointer( + int32 subsampled_frame) { BaseFloat *frame_nnet3_out = (BaseFloat *)likes_->Data() + (subsampled_frame - frame_offset_) * likes_->Stride(); diff --git a/src/cudadecoder/decodable-cumatrix.h b/src/cudadecoder/decodable-cumatrix.h index d34079cc9c7..7f42151ed0f 100644 --- a/src/cudadecoder/decodable-cumatrix.h +++ b/src/cudadecoder/decodable-cumatrix.h @@ -16,6 +16,10 @@ * limitations under the License. 
*/ +// +// Important: This file is deprecated and will be removed in a future release +// + #ifndef KALDI_CUDA_DECODER_DECODABLE_CUMATRIX_H_ #define KALDI_CUDA_DECODER_DECODABLE_CUMATRIX_H_ @@ -31,10 +35,11 @@ namespace cuda_decoder { an interface similar to the Decodable Interface */ class DecodableCuMatrixMapped : public CudaDecodableInterface { -public: - // This constructor creates an object that will not delete "likes" when done. - // the frame_offset is the frame the row 0 of 'likes' corresponds to, would be - // greater than one if this is not the first chunk of likelihoods. + public: + // This constructor creates an object that will not delete "likes" when + // done. the frame_offset is the frame the row 0 of 'likes' corresponds + // to, would be greater than one if this is not the first chunk of + // likelihoods. DecodableCuMatrixMapped(const TransitionModel &tm, const CuMatrixBase &likes, int32 frame_offset = 0); @@ -56,8 +61,8 @@ class DecodableCuMatrixMapped : public CudaDecodableInterface { // returns cuda pointer to nnet3 output virtual BaseFloat *GetLogLikelihoodsCudaPointer(int32 subsampled_frame); -private: - const TransitionModel &trans_model_; // for tid to pdf mapping + private: + const TransitionModel &trans_model_; // for tid to pdf mapping const CuMatrixBase *likes_; int32 frame_offset_; diff --git a/src/cudadecoder/thread-pool-light.h b/src/cudadecoder/thread-pool-light.h new file mode 100644 index 00000000000..7a1c2adb8e2 --- /dev/null +++ b/src/cudadecoder/thread-pool-light.h @@ -0,0 +1,169 @@ +// cudadecoder/cuda-decoder.h +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun, Justin Luitjens, Ryan Leary +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
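
For illustration only: one way the ThreadPoolLight defined below could be used. Tasks are a raw function pointer plus an object pointer and two integer arguments, so callers typically go through a small static trampoline. `MyDecoder` is a made-up stand-in for a class like CudaDecoder, and the short sleep only gives a worker time to run the task before this toy program tears the pool down.

    #include <chrono>
    #include <cstdint>
    #include <cstdio>
    #include <thread>
    #include "cudadecoder/thread-pool-light.h"

    struct MyDecoder {  // hypothetical user of the pool
      void InitDecodingH2HCopies(int32_t ichannel) {
        std::printf("H2H copies for channel %d\n", ichannel);
      }
      static void Trampoline(void *obj, uint64_t arg1, uint64_t /*arg2*/) {
        static_cast<MyDecoder *>(obj)->InitDecodingH2HCopies((int32_t)arg1);
      }
    };

    int main() {
      kaldi::cuda_decoder::ThreadPoolLight pool(4);  // 4 worker threads
      MyDecoder decoder;
      kaldi::cuda_decoder::ThreadPoolLightTask task;
      task.func_ptr = &MyDecoder::Trampoline;
      task.obj_ptr = &decoder;
      task.arg1 = 7;  // e.g. a channel id
      task.arg2 = 0;
      pool.Push(task);  // retries with short sleeps only if every queue is full
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      return 0;
    }
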
+ +#ifndef KALDI_CUDA_DECODER_THREAD_POOL_LIGHT_H_ +#define KALDI_CUDA_DECODER_THREAD_POOL_LIGHT_H_ + +#define KALDI_CUDA_DECODER_THREAD_POOL_QUEUE_FULL_WAIT_FOR_US 1000 + +#include +#include +#include +#include "util/stl-utils.h" + +namespace kaldi { +namespace cuda_decoder { + +struct ThreadPoolLightTask { + void (*func_ptr)(void *, uint64_t, uint64_t); + void *obj_ptr; + uint64_t arg1; + uint64_t arg2; +}; + +template +// Single producer, multiple consumer +class ThreadPoolLightSPMCQueue { + static const unsigned int QUEUE_MASK = QUEUE_SIZE - 1; + std::vector tasks_; + std::atomic back_; + std::atomic front_; + int inc(int curr) { return ((curr + 1) & QUEUE_MASK); } + + public: + ThreadPoolLightSPMCQueue() { + KALDI_ASSERT(QUEUE_SIZE > 1); + bool is_power_of_2 = ((QUEUE_SIZE & (QUEUE_SIZE - 1)) == 0); + KALDI_ASSERT(is_power_of_2); // validity of QUEUE_MASK + tasks_.resize(QUEUE_SIZE); + front_.store(0); + back_.store(0); + } + + bool TryPush(const ThreadPoolLightTask &task) { + int back = back_.load(std::memory_order_relaxed); + int next = inc(back); + if (next == front_.load(std::memory_order_acquire)) { + return false; // queue is full + } + tasks_[back] = task; + back_.store(next, std::memory_order_release); + + return true; + } + + bool TryPop(ThreadPoolLightTask *front_task) { + while (true) { + int front = front_.load(std::memory_order_relaxed); + if (front == back_.load(std::memory_order_acquire)) + return false; // queue is empty + *front_task = tasks_[front]; + if (front_.compare_exchange_weak(front, inc(front), + std::memory_order_release)) + return true; + } + } +}; + +class ThreadPoolLightWorker { + // Multi consumer queue, because worker can steal work + ThreadPoolLightSPMCQueue<512> queue_; + // If this thread has no more work to do, it will try to steal work from + // other + std::unique_ptr thread_; + bool run_thread_; + ThreadPoolLightTask curr_task_; + std::shared_ptr other_; + + void Work() { + while (run_thread_) { + if (queue_.TryPop(&curr_task_) || other_->TrySteal(&curr_task_)) { + // Not calling func_ptr as a member function, + // because we need to specialize the arguments + // anyway (we may want to ignore arg2, for + // instance) Using a wrapper func + (curr_task_.func_ptr)(curr_task_.obj_ptr, curr_task_.arg1, + curr_task_.arg2); + } else { + usleep(1000); // TODO + } + } + } + + protected: + // Another worker can steal a task from this queue + // This is done so that a very long task computed by one thread does not + // hold the entire threadpool to complete a time-sensitive task + bool TrySteal(ThreadPoolLightTask *task) { return queue_.TryPop(task); } + + public: + ThreadPoolLightWorker() : run_thread_(true), other_(NULL) {} + virtual ~ThreadPoolLightWorker() { Stop(); } + bool TryPush(const ThreadPoolLightTask &task) { return queue_.TryPush(task); } + void SetOtherWorkerToStealFrom( + const std::shared_ptr other) { + other_ = other; + } + void Start() { + KALDI_ASSERT("Please call SetOtherWorkerToStealFrom() first" && other_); + thread_.reset(new std::thread(&ThreadPoolLightWorker::Work, this)); + } + void Stop() { + run_thread_ = false; + thread_->join(); + } +}; + +class ThreadPoolLight { + std::vector> workers_; + int curr_iworker_; // next call on tryPush will post work on this + // worker + int nworkers_; + + public: + ThreadPoolLight(int32 nworkers = std::thread::hardware_concurrency()) + : curr_iworker_(0), nworkers_(nworkers) { + KALDI_ASSERT(nworkers > 1); + workers_.resize(nworkers); + for (int i = 0; i < workers_.size(); ++i) + workers_[i] = 
std::make_shared(); + + for (int i = 0; i < workers_.size(); ++i) { + int iother = (i + nworkers / 2) % nworkers; + workers_[i]->SetOtherWorkerToStealFrom(workers_[iother]); + workers_[i]->Start(); + } + } + + bool TryPush(const ThreadPoolLightTask &task) { + if (!workers_[curr_iworker_]->TryPush(task)) return false; + ++curr_iworker_; + if (curr_iworker_ == nworkers_) curr_iworker_ = 0; + return true; + } + + void Push(const ThreadPoolLightTask &task) { + // Could try another curr_iworker_ + while (!TryPush(task)) + usleep(KALDI_CUDA_DECODER_THREAD_POOL_QUEUE_FULL_WAIT_FOR_US); + } +}; + +} // end namespace cuda_decoder +} // end namespace kaldi + +#endif // KALDI_CUDA_DECODER_THREAD_POOL_H_ diff --git a/src/cudadecoder/thread-pool.h b/src/cudadecoder/thread-pool.h index 920ea6d3300..8c864ddb6b9 100644 --- a/src/cudadecoder/thread-pool.h +++ b/src/cudadecoder/thread-pool.h @@ -1,6 +1,6 @@ // cudadecoder/thread-pool.h // Source: https://github.com/progschj/ThreadPool -// Modified to add a priority queue +// Modified to add a priority queue // Ubtained under this license: /* Copyright (c) 2012 Jakob Progsch, Václav Zeman @@ -25,8 +25,12 @@ freely, subject to the following restrictions: distribution. */ -#ifndef KALDI_CUDA_DECODER_THREAD_POOL_H_ -#define KALDI_CUDA_DECODER_THREAD_POOL_H_ +// +// Important: This file is deprecated and will be removed in a future release +// + +#ifndef KALDI_CUDA_DECODER_DEPRECATED_THREAD_POOL_H_ +#define KALDI_CUDA_DECODER_DEPRECATED_THREAD_POOL_H_ #include #include @@ -43,10 +47,14 @@ namespace kaldi { namespace cuda_decoder { // C++ indexes enum 0,1,2... -enum ThreadPoolPriority { THREAD_POOL_LOW_PRIORITY, THREAD_POOL_NORMAL_PRIORITY, THREAD_POOL_HIGH_PRIORITY }; +enum ThreadPoolPriority { + THREAD_POOL_LOW_PRIORITY, + THREAD_POOL_NORMAL_PRIORITY, + THREAD_POOL_HIGH_PRIORITY +}; class ThreadPool { -public: + public: ThreadPool(size_t); template auto enqueue(ThreadPoolPriority priority, F &&f, Args &&... args) @@ -61,10 +69,11 @@ class ThreadPool { std::vector workers; // the task queue struct Task { - std::function func; - // Ordered first by priority, then FIFO order - // tasks created first will have a higher priority_with_fifo.second - std::pair priority_with_fifo; + std::function func; + // Ordered first by priority, then FIFO order + // tasks created first will have a higher + // priority_with_fifo.second + std::pair priority_with_fifo; }; friend bool operator<(const ThreadPool::Task &lhs, const ThreadPool::Task &rhs); @@ -92,7 +101,7 @@ inline ThreadPool::ThreadPool(size_t threads) for (;;) { Task task; - { + { std::unique_lock lock(this->queue_mutex); this->condition.wait( lock, [this] { return this->stop || !this->tasks.empty(); }); @@ -100,8 +109,8 @@ inline ThreadPool::ThreadPool(size_t threads) if (!tasks.empty()) { task = std::move(this->tasks.top()); this->tasks.pop(); + } } - } task.func(); } }); @@ -111,7 +120,8 @@ inline ThreadPool::ThreadPool(size_t threads) template auto ThreadPool::enqueue(F &&f, Args &&... args) -> std::future::type> { - return enqueue(THREAD_POOL_NORMAL_PRIORITY, std::forward(f), std::forward(args)...); + return enqueue(THREAD_POOL_NORMAL_PRIORITY, std::forward(f), + std::forward(args)...); } // add new work item to the pool @@ -128,8 +138,7 @@ auto ThreadPool::enqueue(ThreadPoolPriority priority, F &&f, Args &&... 
args) std::unique_lock lock(queue_mutex); // don't allow enqueueing after stopping the pool - if (stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); + if (stop) throw std::runtime_error("enqueue on stopped ThreadPool"); Task task; task.func = [func]() { (*func)(); }; long long task_fifo_id = task_counter--; @@ -151,12 +160,10 @@ inline ThreadPool::~ThreadPool() { stop = true; } condition.notify_all(); - for (std::thread &worker : workers) - worker.join(); + for (std::thread &worker : workers) worker.join(); } } // end namespace cuda_decoder } // end namespace kaldi - #endif // KALDI_CUDA_DECODER_THREAD_POOL_H_ diff --git a/src/cudadecoderbin/Makefile b/src/cudadecoderbin/Makefile index 6d086c14fc7..9c7ec7837a7 100644 --- a/src/cudadecoderbin/Makefile +++ b/src/cudadecoderbin/Makefile @@ -13,7 +13,7 @@ endif LDFLAGS += $(CUDA_LDFLAGS) LDLIBS += $(CUDA_LDLIBS) -BINFILES = batched-wav-nnet3-cuda +BINFILES = batched-wav-nnet3-cuda batched-wav-nnet3-cuda2 batched-wav-nnet3-cuda-online OBJFILES = diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc new file mode 100644 index 00000000000..f27eb54be6e --- /dev/null +++ b/src/cudadecoderbin/batched-wav-nnet3-cuda-online.cc @@ -0,0 +1,423 @@ +// cudadecoderbin/batched-wav-nnet3-cuda-online.cc +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if HAVE_CUDA == 1 + +#include +#include +#include +#include +#include +#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h" +#include "cudamatrix/cu-allocator.h" +#include "fstext/fstext-lib.h" +#include "lat/lattice-functions.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-utils.h" +#include "util/kaldi-thread.h" + +using namespace kaldi; +using namespace cuda_decoder; + +// +// Binary for the online pipeline BatchedThreadedNnet3CudaOnlinePipeline +// Can serve both as a benchmarking tool and an example on how to call +// BatchedThreadedNnet3CudaOnlinePipeline +// + +// Prints some statistics based on latencies stored in latencies +void PrintLatencyStats(std::vector &latencies) { + if (latencies.empty()) return; + double total = std::accumulate(latencies.begin(), latencies.end(), 0.); + double avg = total / latencies.size(); + std::sort(latencies.begin(), latencies.end()); + + double nresultsf = static_cast(latencies.size()); + size_t per90i = static_cast(std::floor(90. * nresultsf / 100.)); + size_t per95i = static_cast(std::floor(95. * nresultsf / 100.)); + size_t per99i = static_cast(std::floor(99. 
* nresultsf / 100.)); + + double lat_90 = latencies[per90i]; + double lat_95 = latencies[per95i]; + double lat_99 = latencies[per99i]; + + KALDI_LOG << "Latencies (s):\tAvg\t\t90%\t\t95%\t\t99%"; + KALDI_LOG << std::fixed << std::setprecision(3) << "\t\t\t" << avg << "\t\t" + << lat_90 << "\t\t" << lat_95 << "\t\t" << lat_99; +} + +// time with arbitrary reference +double inline gettime_monotonic() { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + double time = ts.tv_sec; + time += (double)(ts.tv_nsec) / 1e9; + return time; +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Reads in wav file(s) and simulates online " + "decoding with " + "neural nets\n" + "(nnet3 setup). Note: some configuration values " + "and inputs " + "are\n" + "set via config files whose filenames are passed " + "as " + "options\n" + "\n" + "Usage: batched-wav-nnet3-cuda [options] " + " " + " " + " \n"; + + std::string word_syms_rxfilename; + + bool write_lattice = true; + int num_todo = -1; + int niterations = 3; + int num_streaming_channels = 2000; + ParseOptions po(usage); + po.Register("write-lattice", &write_lattice, + "Output lattice to a file. Setting to " + "false is useful when " + "benchmarking"); + po.Register("word-symbol-table", &word_syms_rxfilename, + "Symbol table for words [for debug output]"); + po.Register("file-limit", &num_todo, + "Limits the number of files that are processed by " + "this driver. " + "After N files are processed the remaining files " + "are ignored. " + "Useful for profiling"); + po.Register("iterations", &niterations, + "Number of times to decode the corpus. Output will " + "be written " + "only once."); + po.Register("num-parallel-streaming-channels", &num_streaming_channels, + "Number of channels streaming in parallel"); + + // Multi-threaded CPU and batched GPU decoder + BatchedThreadedNnet3CudaOnlinePipelineConfig batched_decoder_config; + CuDevice::RegisterDeviceOptions(&po); + RegisterCuAllocatorOptions(&po); + batched_decoder_config.Register(&po); + + po.Read(argc, argv); + batched_decoder_config.num_channels = std::max( + batched_decoder_config.num_channels, 2 * num_streaming_channels); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + return 1; + } + + g_cuda_allocator.SetOptions(g_allocator_options); + CuDevice::Instantiate().SelectGpuId("yes"); + CuDevice::Instantiate().AllowMultithreading(); + + std::string nnet3_rxfilename = po.GetArg(1), fst_rxfilename = po.GetArg(2), + wav_rspecifier = po.GetArg(3), clat_wspecifier = po.GetArg(4); + TransitionModel trans_model; + nnet3::AmNnetSimple am_nnet; + + // read transition model and nnet + bool binary; + Input ki(nnet3_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); + + CompactLatticeWriter clat_writer(clat_wspecifier); + std::mutex clat_writer_m; + + fst::Fst *decode_fst = + fst::ReadFstKaldiGeneric(fst_rxfilename); + + BatchedThreadedNnet3CudaOnlinePipeline cuda_pipeline( + batched_decoder_config, *decode_fst, am_nnet, trans_model); + + delete decode_fst; + + fst::SymbolTable *word_syms = NULL; + if (word_syms_rxfilename != "") { + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) + KALDI_ERR << "Could not read symbol " + "table 
from file " + << word_syms_rxfilename; + else { + // cuda_pipeline.SetSymbolTable(word_syms); + } + } + + int32 num_task_submitted = 0, num_err = 0; + double tot_like = 0.0; + int64 num_frames = 0; + double total_audio_not_starved = 0; + double total_compute_time_not_starved = 0; + + int chunk_length = cuda_pipeline.GetNSampsPerChunk(); + double chunk_seconds = cuda_pipeline.GetSecondsPerChunk(); + double seconds_per_sample = chunk_seconds / chunk_length; + + // pre-loading data + // we don't want to measure I/O + double total_audio = 0; + SequentialTableReader wav_reader(wav_rspecifier); + std::vector> all_wav; + std::vector all_wav_keys; + { + std::cout << "Loading eval dataset..." << std::flush; + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + std::shared_ptr wave_data = std::make_shared(); + wave_data->Swap(&wav_reader.Value()); + all_wav.push_back(wave_data); + all_wav_keys.push_back(utt); + total_audio += wave_data->Duration(); + } + std::cout << "done" << std::endl; + } + total_audio *= niterations; + + struct Stream { + std::shared_ptr wav; + BatchedThreadedNnet3CudaOnlinePipeline::CorrelationID corr_id; + int offset; + double send_next_chunk_at; + double *latency_ptr; + + Stream(const std::shared_ptr &_wav, + BatchedThreadedNnet3CudaOnlinePipeline::CorrelationID _corr_id, + double *_latency_ptr) + : wav(_wav), corr_id(_corr_id), offset(0), latency_ptr(_latency_ptr) { + send_next_chunk_at = gettime_monotonic(); + } + + bool operator<(const Stream &other) { + return (send_next_chunk_at < other.send_next_chunk_at); + } + }; + nvtxRangePush("Global Timer"); + // starting timer here so we + // can measure throughput + // without allocation + // overheads + // using kaldi timer, which starts counting in the + // constructor + Timer timer; + double this_iteration_timer = timer.Elapsed(); + std::vector iteration_timer; + std::vector> curr_tasks, next_tasks; + curr_tasks.reserve(num_streaming_channels); + next_tasks.reserve(num_streaming_channels); + size_t all_wav_i = 0; + size_t all_wav_max = all_wav.size() * niterations; + std::vector latencies(all_wav_max); + BatchedThreadedNnet3CudaOnlinePipeline::CorrelationID correlation_id_cnt = + 0; + // Batch sent to online pipeline + std::vector + batch_corr_ids; + std::vector batch_is_first_chunk; + std::vector batch_is_last_chunk; + // Used when use_online_ivectors_ + std::vector> batch_wave_samples; + + double batch_valid_at = gettime_monotonic(); + bool pipeline_starved_warning_printed = false; + while (true) { + int this_iteration_total_samples = 0; + batch_valid_at = 0.; + while (curr_tasks.size() < num_streaming_channels && + all_wav_i < all_wav_max) { + // Creating new tasks + uint64_t corr_id = correlation_id_cnt++; + size_t all_wav_i_modulo = all_wav_i % (all_wav.size()); + double *latency_ptr = &latencies[all_wav_i]; + std::unique_ptr ptr( + new Stream(all_wav[all_wav_i_modulo], corr_id, latency_ptr)); + curr_tasks.emplace_back(std::move(ptr)); + + // If no channels are available, we will wait up + // to INT_MAX microseconds for a channel to + // become available. 
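
The send_next_chunk_at bookkeeping in this loop is what keeps the benchmark honest: a stream may only submit its next chunk one chunk-duration after the previous one, so audio never arrives faster than realtime. A std-only sketch of that pacing, independent of the pipeline:

    #include <chrono>
    #include <cstdio>
    #include <thread>

    int main() {
      const double chunk_seconds = 0.5;
      const auto now = [] {
        return std::chrono::duration<double>(
                   std::chrono::steady_clock::now().time_since_epoch())
            .count();
      };
      const double t0 = now();
      double send_next_chunk_at = t0;
      for (int chunk = 0; chunk < 3; ++chunk) {
        const double wait_for = send_next_chunk_at - now();
        if (wait_for > 0)
          std::this_thread::sleep_for(std::chrono::duration<double>(wait_for));
        std::printf("submitting chunk %d at t=%.2fs\n", chunk, now() - t0);
        send_next_chunk_at += chunk_seconds;  // pace the next submission
      }
      return 0;
    }
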
The reason why we can in + // theory have no channel available is because a + // channel is still in used when the last chunk + // has been processed but the lattice is still + // being generated This is why we set + // batched_decoder_config.num_channels strictly + // higher than num_streaming_channels + // If we want to ensure that we are never using + // more channels than num_streaming_channels, we + // can call WaitForLatticeCallbacks after each + // DecodeBatch. That way, we know TryInitCorrID + // will always have a channel available right + // away if batched_decoder_config.num_channels + // >= num_streaming_channels + KALDI_ASSERT(cuda_pipeline.TryInitCorrID(corr_id, INT_MAX)); + const std::string &utt = all_wav_keys[all_wav_i_modulo]; + size_t iteration = all_wav_i / all_wav.size(); + std::string key = + (iteration == 0) ? utt : (std::to_string(iteration) + "-" + utt); + cuda_pipeline.SetLatticeCallback( + corr_id, [&clat_writer, &clat_writer_m, key, write_lattice, + latency_ptr](CompactLattice &clat) { + if (write_lattice) { + std::lock_guard lk(clat_writer_m); + clat_writer.Write(key, clat); + } + double now = gettime_monotonic(); + *latency_ptr = now - *latency_ptr; + }); + ++all_wav_i; + ++num_task_submitted; + } + // If still empty, done + if (curr_tasks.empty()) break; + + std::sort(curr_tasks.begin(), curr_tasks.end()); + + for (size_t itask = 0; itask < curr_tasks.size(); ++itask) { + Stream &task = *(curr_tasks[itask]); + + SubVector data(task.wav->Data(), 0); + int32 samp_offset = task.offset; + int32 nsamp = data.Dim(); + int32 samp_remaining = nsamp - samp_offset; + int32 num_samp = + chunk_length < samp_remaining ? chunk_length : samp_remaining; + bool is_last_chunk = (chunk_length >= samp_remaining); + SubVector wave_part(data, samp_offset, num_samp); + bool is_first_chunk = (samp_offset == 0); + + task.offset += num_samp; + batch_valid_at = std::max(task.send_next_chunk_at, batch_valid_at); + this_iteration_total_samples += num_samp; + + batch_corr_ids.push_back(task.corr_id); + batch_is_first_chunk.push_back(is_first_chunk); + batch_is_last_chunk.push_back(is_last_chunk); + batch_wave_samples.push_back(wave_part); + + if (!is_last_chunk) { + next_tasks.push_back(std::move(curr_tasks[itask])); + } else { + *task.latency_ptr = task.send_next_chunk_at; + } + + task.send_next_chunk_at += chunk_seconds; + if (batch_corr_ids.size() == batched_decoder_config.max_batch_size || + (itask == (curr_tasks.size() - 1))) { + // Wait for batch to be valid + double now = gettime_monotonic(); + double wait_for = batch_valid_at - now; + if (wait_for > 0) usleep(wait_for * 1e6); + + cuda_pipeline.DecodeBatch(batch_corr_ids, batch_wave_samples, + batch_is_first_chunk, batch_is_last_chunk); + batch_corr_ids.clear(); + batch_is_first_chunk.clear(); + batch_is_last_chunk.clear(); + batch_wave_samples.clear(); + } + } + bool pipeline_starved = (curr_tasks.size() < num_streaming_channels); + if (pipeline_starved && !pipeline_starved_warning_printed) { + std::cout << "\nNote: Streaming the end of the " + "last " + "utterances. " + "Not enough unprocessed " + "utterances available to stream " + << num_streaming_channels + << " channels in parallel. The " + "pipeline is starved. Will now " + "stream partial batches while " + "still limiting I/O at realtime " + "speed. RTFX will drop. 
\n" + << std::endl; + pipeline_starved_warning_printed = true; + } + double curr_timer = timer.Elapsed(); + double diff = curr_timer - this_iteration_timer; + this_iteration_timer = curr_timer; + double this_iteration_total_seconds = + this_iteration_total_samples * seconds_per_sample; + if (!pipeline_starved) { + total_audio_not_starved += this_iteration_total_seconds; + total_compute_time_not_starved += diff; + } + double this_iteration_rtfx = this_iteration_total_seconds / diff; + if (pipeline_starved) std::cout << "STARVED: "; + std::cout << "Number of active streaming channels: " << std::setw(5) + << curr_tasks.size() << "\tInstant RTFX: " << std::setw(6) + << std::fixed << std::setprecision(1) << this_iteration_rtfx + << std::endl; + + curr_tasks.swap(next_tasks); + next_tasks.clear(); + } + cuda_pipeline.WaitForLatticeCallbacks(); + nvtxRangePop(); + + KALDI_LOG << "Decoded " << num_task_submitted << " utterances, " << num_err + << " with errors."; + KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames) + << " per frame over " << num_frames << " frames."; + + KALDI_LOG << "NON-STARVED:"; + KALDI_LOG << "\tThis section only concerns the part of the " + "computation " + "where we had enough active utterances to simulate " + << num_streaming_channels << " parallel clients. "; + KALDI_LOG << "\tIt corresponds to the throughput an online instance " + "can handle with all channels in use."; + KALDI_LOG << "\tTotal Compute Time: " << total_compute_time_not_starved; + KALDI_LOG << "\tTotal Audio Decoded: " << total_audio_not_starved; + KALDI_LOG << "\tRealTimeX: " + << total_audio_not_starved / total_compute_time_not_starved; + + KALDI_LOG << "OVERALL:"; + KALDI_LOG << "\tTotal Utterances Decoded: " << num_task_submitted; + KALDI_LOG << "\tTotal Audio Decoded: " << total_audio << " seconds"; + KALDI_LOG << "\tLatency stats:"; + PrintLatencyStats(latencies); + + delete word_syms; // will delete if non-NULL. + + clat_writer.Close(); + + cudaDeviceSynchronize(); + + return 0; + } catch (const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} // main() + +#endif // if HAVE_CUDA == 1 diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda.cc index bfe8d8a2ce6..46138116bd8 100644 --- a/src/cudadecoderbin/batched-wav-nnet3-cuda.cc +++ b/src/cudadecoderbin/batched-wav-nnet3-cuda.cc @@ -36,6 +36,9 @@ using namespace cuda_decoder; // Not using a semaphore because it is usually not necessary to wait #define KALDI_CUDA_DECODER_BIN_PIPELINE_FULL_SLEEP ((double)1 / 1e5) +// This pipeline is deprecated and will be removed. Please switch to +// batched-wav-nnet3-cuda2 + void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, @@ -111,13 +114,18 @@ int main(int argc, char *argv[]) { typedef kaldi::int64 int64; const char *usage = - "Reads in wav file(s) and simulates online decoding with neural nets\n" - "(nnet3 setup), with optional iVector-based speaker adaptation and\n" - "optional endpointing. Note: some configuration values and inputs " + "Reads in wav file(s) and simulates online decoding with " + "neural nets\n" + "(nnet3 setup), with optional iVector-based speaker " + "adaptation and\n" + "optional endpointing. 
Note: some configuration values " + "and inputs " "are\n" - "set via config files whose filenames are passed as options\n" + "set via config files whose filenames are passed as " + "options\n" "\n" - "Usage: batched-wav-nnet3-cuda [options] " + "Usage: batched-wav-nnet3-cuda [options] " + " " " \n"; std::string word_syms_rxfilename; @@ -137,8 +145,10 @@ int main(int argc, char *argv[]) { po.Register("word-symbol-table", &word_syms_rxfilename, "Symbol table for words [for debug output]"); po.Register("file-limit", &num_todo, - "Limits the number of files that are processed by this driver. " - "After N files are processed the remaining files are ignored. " + "Limits the number of files that are processed by " + "this driver. " + "After N files are processed the remaining files " + "are ignored. " "Useful for profiling"); po.Register("iterations", &iterations, "Number of times to decode the corpus."); @@ -226,7 +236,7 @@ int main(int argc, char *argv[]) { std::string utt = wav_reader.Key(); std::string key = utt; - if (iterations > 0) { + if (iter > 0) { // make key unique for each iteration key = std::to_string(iter) + "-" + key; } @@ -234,42 +244,51 @@ int main(int argc, char *argv[]) { const WaveData &wave_data = wav_reader.Value(); if (iter == 0) { - // calculating number of utterances per iteration - // calculating total audio time per iteration + // calculating number of utterances per + // iteration calculating total audio + // time per iteration total_audio += wave_data.Duration(); } - // Creating a function alias for the callback function of that utterance + // Creating a function alias for the callback + // function of that utterance auto finish_one_decode_lamba = [ - // Capturing the arguments that will change by copy - utt, key, - // Capturing the const/global args by reference + // Capturing the arguments that will + // change by copy + utt, key, + // Capturing the const/global args by + // reference &word_syms, &cuda_pipeline, &stdout_mutex, &num_frames, - &clat_write_mutex, &clat_writer, &write_lattice, - &tot_like] - // The callback function receive the compact lattice as argument - // if determinize_lattice is true, it is a determinized lattice - // otherwise, it is a raw lattice converted to compact format + &clat_write_mutex, &clat_writer, &write_lattice, &tot_like] + // The callback function receive the compact + // lattice as argument if + // determinize_lattice is true, it is a + // determinized lattice otherwise, it is a + // raw lattice converted to compact format // through ConvertLattice (CompactLattice & clat_in) { - // Content of our callback function. Calling the general - // FinishOneDecode function with the proper arguments + // Content of our callback function. + // Calling the general + // FinishOneDecode function with the + // proper arguments FinishOneDecode( - // Captured arguments used to specialize FinishOneDecode for - // this task + // Captured arguments used to + // specialize FinishOneDecode + // for this task utt, key, word_syms, &cuda_pipeline, &num_frames, &tot_like, - &clat_writer, &clat_write_mutex, &stdout_mutex, - write_lattice, - // Generated lattice that will be passed once the task is + &clat_writer, &clat_write_mutex, &stdout_mutex, write_lattice, + // Generated lattice that will + // be passed once the task is // complete clat_in); }; - // Adding a new task. Once the output lattice is ready, it will call - // finish_one_decode_lamba - // Important : finish_one_decode_lamba is called in the threadpool. 
We - // need it to be threadsafe - // (use locks around relevant parts, like writing to I/O) + // Adding a new task. Once the output lattice is + // ready, it will call finish_one_decode_lamba + // Important : finish_one_decode_lamba is called + // in the threadpool. We need it to be + // threadsafe (use locks around relevant parts, + // like writing to I/O) cuda_pipeline.OpenDecodeHandle(key, wave_data, task_group, finish_one_decode_lamba); num_task_submitted++; @@ -277,7 +296,7 @@ int main(int argc, char *argv[]) { nvtxRangePop(); if (num_todo != -1 && num_task_submitted >= num_todo) break; } // end utterance loop - + std::string group_done; // Non-blocking way to check if a group is done // returns false if zero groups are ready @@ -294,9 +313,11 @@ int main(int argc, char *argv[]) { } // end iterations loop // We've submitted all tasks. Now waiting for them to complete - // We could also have called WaitForAllTasks and CloseAllDecodeHandles + // We could also have called WaitForAllTasks and + // CloseAllDecodeHandles while (num_groups_done < iterations) { - // WaitForAnyGroup is blocking. It will hold until one group is ready + // WaitForAnyGroup is blocking. It will hold until one + // group is ready std::string group_done = cuda_pipeline.WaitForAnyGroup(); cuda_pipeline.CloseAllDecodeHandlesForGroup(group_done); double total_time = timer.Elapsed(); @@ -324,7 +345,7 @@ int main(int argc, char *argv[]) { cuda_pipeline.Finalize(); cudaDeviceSynchronize(); - + delete word_syms; // will delete if non-NULL. return 0; diff --git a/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc b/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc new file mode 100644 index 00000000000..83f5b6a0650 --- /dev/null +++ b/src/cudadecoderbin/batched-wav-nnet3-cuda2.cc @@ -0,0 +1,203 @@ +// cudadecoderbin/batched-wav-nnet3-cuda2.cc +// +// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Hugo Braun +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#if HAVE_CUDA == 1 + +#include +#include +#include +#include +#include "cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h" +#include "cudamatrix/cu-allocator.h" +#include "fstext/fstext-lib.h" +#include "lat/lattice-functions.h" +#include "nnet3/am-nnet-simple.h" +#include "nnet3/nnet-utils.h" +#include "util/kaldi-thread.h" + +using namespace kaldi; +using namespace cuda_decoder; + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace fst; + + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Reads in wav file(s) and decodes them with " + "neural nets\n" + "(nnet3 setup). 
Note: some configuration values " + "and inputs " + "are\n" + "set via config files whose filenames are passed as " + "options\n" + "\n" + "Usage: batched-wav-nnet3-cuda [options] " + " " + " \n"; + + std::string word_syms_rxfilename; + + bool write_lattice = true; + int num_todo = -1; + int iterations = 1; + ParseOptions po(usage); + po.Register("write-lattice", &write_lattice, + "Output lattice to a file. Setting to false is useful when " + "benchmarking"); + po.Register("word-symbol-table", &word_syms_rxfilename, + "Symbol table for words [for debug output]"); + po.Register("file-limit", &num_todo, + "Limits the number of files that are processed by " + "this driver. " + "After N files are processed the remaining files " + "are ignored. " + "Useful for profiling"); + po.Register("iterations", &iterations, + "Number of times to decode the corpus. Output will " + "be written " + "only once."); + + // Multi-threaded CPU and batched GPU decoder + BatchedThreadedNnet3CudaPipeline2Config batched_decoder_config; + CuDevice::RegisterDeviceOptions(&po); + RegisterCuAllocatorOptions(&po); + batched_decoder_config.Register(&po); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + return 1; + } + + g_cuda_allocator.SetOptions(g_allocator_options); + CuDevice::Instantiate().SelectGpuId("yes"); + CuDevice::Instantiate().AllowMultithreading(); + + std::string nnet3_rxfilename = po.GetArg(1), fst_rxfilename = po.GetArg(2), + wav_rspecifier = po.GetArg(3), clat_wspecifier = po.GetArg(4); + TransitionModel trans_model; + nnet3::AmNnetSimple am_nnet; + + // read transition model and nnet + bool binary; + Input ki(nnet3_rxfilename, &binary); + trans_model.Read(ki.Stream(), binary); + am_nnet.Read(ki.Stream(), binary); + SetBatchnormTestMode(true, &(am_nnet.GetNnet())); + SetDropoutTestMode(true, &(am_nnet.GetNnet())); + nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet())); + + CompactLatticeWriter clat_writer(clat_wspecifier); + std::mutex clat_writer_m; + + fst::Fst *decode_fst = + fst::ReadFstKaldiGeneric(fst_rxfilename); + + BatchedThreadedNnet3CudaPipeline2 cuda_pipeline( + batched_decoder_config, *decode_fst, am_nnet, trans_model); + + delete decode_fst; + + fst::SymbolTable *word_syms = NULL; + if (word_syms_rxfilename != "") { + if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) + KALDI_ERR << "Could not read symbol table from file " + << word_syms_rxfilename; + else { + // cuda_pipeline.SetSymbolTable(word_syms); + } + } + + int32 num_task_submitted = 0, num_err = 0; + double tot_like = 0.0; + int64 num_frames = 0; + double total_audio = 0; + + nvtxRangePush("Global Timer"); + // starting timer here so we + // can measure throughput + // without allocation + // overheads + // using kaldi timer, which starts counting in the constructor + Timer timer; + std::vector iteration_timer; + for (int iter = 0; iter < iterations; iter++) { + num_task_submitted = 0; + SequentialTableReader wav_reader(wav_rspecifier); + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + std::string key = utt; + if (iter > 0) key = std::to_string(iter) + "-" + key; + std::shared_ptr wave_data = std::make_shared(); + wave_data->Swap(&wav_reader.Value()); + if (iter == 0) { + // calculating number of utterances per + // iteration calculating total audio + // time per iteration + total_audio += wave_data->Duration(); + } + + cuda_pipeline.DecodeWithCallback( + wave_data, [&clat_writer, &clat_writer_m, key, + 
write_lattice](CompactLattice &clat) {
+            if (write_lattice) {
+              std::lock_guard<std::mutex> lk(clat_writer_m);
+              clat_writer.Write(key, clat);
+            }
+          });
+
+      num_task_submitted++;
+      if (num_todo != -1 && num_task_submitted >= num_todo) break;
+    }  // end utterance loop
+  }    // end iterations loop
+
+  cuda_pipeline.WaitForAllTasks();
+
+  // number of seconds elapsed since the creation of timer
+  double total_time = timer.Elapsed();
+  nvtxRangePop();
+
+  KALDI_LOG << "Decoded " << num_task_submitted << " utterances, " << num_err
+            << " with errors.";
+  KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
+            << " per frame over " << num_frames << " frames.";
+
+  KALDI_LOG << "Overall: "
+            << " Aggregate Total Time: " << total_time
+            << " Total Audio: " << total_audio * iterations
+            << " RealTimeX: " << total_audio * iterations / total_time;
+
+  delete word_syms;  // will delete if non-NULL.
+
+  clat_writer.Close();
+
+  cudaDeviceSynchronize();
+
+  return 0;
+  } catch (const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}  // main()
+
+#endif  // if HAVE_CUDA == 1
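
The online benchmark above (batched-wav-nnet3-cuda-online.cc) exercises the new online pipeline through four calls: TryInitCorrID to claim a decoding channel, SetLatticeCallback to register the lattice consumer, DecodeBatch for every batch of audio chunks, and WaitForLatticeCallbacks before shutdown. The sketch below condenses that flow to a single utterance and is not part of the patch: the pipeline class is assumed to be named BatchedThreadedNnet3CudaOnlinePipeline, matching the batched-threaded-nnet3-cuda-online-pipeline sources added by this patch, correlation IDs are assumed to be 64-bit integers, and the TryInitCorrID result is checked outside KALDI_ASSERT so the call is not dropped in builds where asserts compile to no-ops.

#include <algorithm>
#include <climits>
#include <vector>

#include "base/kaldi-common.h"
#include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h"
#include "lat/kaldi-lattice.h"

using namespace kaldi;
using namespace cuda_decoder;

// Streams one utterance through the online pipeline in fixed-size chunks.
// The real driver batches chunks from many correlation IDs into every
// DecodeBatch call; a batch of one is used here only for clarity.
void DecodeOneUtteranceOnline(
    BatchedThreadedNnet3CudaOnlinePipeline &cuda_pipeline,
    const SubVector<BaseFloat> &data, int32 chunk_length, uint64_t corr_id) {
  // Claim a decoding channel for this correlation ID (INT_MAX: wait for one,
  // as in the benchmark). Checked outside the assert so the call remains in
  // NDEBUG builds.
  bool channel_acquired = cuda_pipeline.TryInitCorrID(corr_id, INT_MAX);
  KALDI_ASSERT(channel_acquired);

  // Called by the pipeline once the last chunk of this utterance is decoded.
  cuda_pipeline.SetLatticeCallback(corr_id, [](CompactLattice &clat) {
    KALDI_LOG << "Lattice ready, " << clat.NumStates() << " states";
  });

  for (int32 offset = 0; offset < data.Dim();) {
    int32 num_samp = std::min(chunk_length, data.Dim() - offset);
    SubVector<BaseFloat> wave_part(data, offset, num_samp);
    std::vector<uint64_t> corr_ids = {corr_id};
    std::vector<SubVector<BaseFloat>> wave_samples = {wave_part};
    std::vector<bool> is_first = {offset == 0};
    std::vector<bool> is_last = {offset + num_samp >= data.Dim()};
    cuda_pipeline.DecodeBatch(corr_ids, wave_samples, is_first, is_last);
    offset += num_samp;
  }
  cuda_pipeline.WaitForLatticeCallbacks();  // all lattice callbacks have run
}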
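
The online benchmark does not push audio as fast as possible: each simulated client may only send a chunk every chunk_seconds, so a batch is held back until the latest send_next_chunk_at of its members (batch_valid_at), and the instant RTFX printed per iteration is the audio submitted in that pass divided by the wall-clock time it took. A minimal, standalone rendering of that pacing logic, with an illustrative PacedStream standing in for the benchmark's Stream struct and std::chrono standing in for gettime_monotonic():

#include <algorithm>
#include <chrono>
#include <thread>
#include <vector>

// Stand-in for the benchmark's per-client bookkeeping; only the field
// relevant to pacing is kept.
struct PacedStream {
  double send_next_chunk_at = 0.0;  // monotonic time, in seconds
};

static double NowSeconds() {
  using namespace std::chrono;
  return duration<double>(steady_clock::now().time_since_epoch()).count();
}

// Holds a batch until every member client is due to send its next chunk,
// then schedules each client's following chunk one chunk further in the
// future. DecodeBatch would be called where indicated.
void SendBatchAtRealtimePace(std::vector<PacedStream *> &batch,
                             double chunk_seconds) {
  double batch_valid_at = 0.0;
  for (const PacedStream *s : batch)
    batch_valid_at = std::max(batch_valid_at, s->send_next_chunk_at);
  double wait_for = batch_valid_at - NowSeconds();
  if (wait_for > 0)
    std::this_thread::sleep_for(std::chrono::duration<double>(wait_for));
  // cuda_pipeline.DecodeBatch(...);
  for (PacedStream *s : batch) s->send_next_chunk_at += chunk_seconds;
}

Once fewer than num_streaming_channels utterances remain, batches still leave at real-time pace but carry less audio, which is why the STARVED iterations report a lower instant RTFX and are excluded from the NON-STARVED throughput summary.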
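
Both binaries stress that the lattice callbacks run on the pipeline's worker threads, so anything they share (the CompactLatticeWriter, stdout) must be taken under a lock. The sketch below shows the kind of work such a callback can safely do with the CompactLattice it receives; it follows the GetDiagnosticsAndPrintOutput pattern used by these binaries, but the function and its parameters are illustrative rather than taken from the patch.

#include <iostream>
#include <mutex>
#include <string>
#include <vector>

#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"

using namespace kaldi;
using namespace fst;

// Runs on one of the pipeline's worker threads, so every shared resource
// (lattice writer, stdout) is accessed under a lock.
void OnLatticeReady(const std::string &key, CompactLattice &clat,
                    const SymbolTable *word_syms,
                    CompactLatticeWriter *clat_writer,
                    std::mutex *clat_write_mutex, std::mutex *stdout_mutex) {
  // Best path of the (already compact) lattice -> word sequence.
  CompactLattice best_path_clat;
  CompactLatticeShortestPath(clat, &best_path_clat);
  Lattice best_path_lat;
  ConvertLattice(best_path_clat, &best_path_lat);
  LatticeWeight weight;
  std::vector<int32> alignment, words;
  GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);

  if (word_syms) {
    std::lock_guard<std::mutex> lk(*stdout_mutex);
    std::cout << key;
    for (int32 w : words) std::cout << ' ' << word_syms->Find(w);
    std::cout << std::endl;
  }
  // Writing the lattice touches shared I/O as well.
  std::lock_guard<std::mutex> lk(*clat_write_mutex);
  clat_writer->Write(key, clat);
}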
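
For the new offline binary, the whole life cycle of batched-wav-nnet3-cuda2 reduces to: construct BatchedThreadedNnet3CudaPipeline2 from its config, the decoding FST, the nnet3 acoustic model and the transition model; hand each utterance to DecodeWithCallback as a shared_ptr<WaveData>; and call WaitForAllTasks once everything has been submitted. A hedged sketch of that flow (the helper function and its argument passing are illustrative; the pipeline calls themselves appear verbatim in the patch):

#include <memory>
#include <mutex>
#include <string>

#include "cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h"
#include "feat/wave-reader.h"
#include "hmm/transition-model.h"
#include "lat/kaldi-lattice.h"
#include "nnet3/am-nnet-simple.h"
#include "util/common-utils.h"

using namespace kaldi;
using namespace cuda_decoder;

// Builds the pipeline once, submits every utterance asynchronously, then
// blocks until all lattice callbacks have fired.
void DecodeCorpusOffline(const BatchedThreadedNnet3CudaPipeline2Config &config,
                         const fst::Fst<fst::StdArc> &decode_fst,
                         const nnet3::AmNnetSimple &am_nnet,
                         const TransitionModel &trans_model,
                         const std::string &wav_rspecifier,
                         CompactLatticeWriter *clat_writer,
                         std::mutex *clat_writer_m) {
  BatchedThreadedNnet3CudaPipeline2 cuda_pipeline(config, decode_fst, am_nnet,
                                                  trans_model);
  SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
  for (; !wav_reader.Done(); wav_reader.Next()) {
    std::string key = wav_reader.Key();
    // The wave is handed over as a shared_ptr so it stays alive until the
    // asynchronous task has consumed it.
    auto wave_data = std::make_shared<WaveData>();
    wave_data->Swap(&wav_reader.Value());
    cuda_pipeline.DecodeWithCallback(
        wave_data, [key, clat_writer, clat_writer_m](CompactLattice &clat) {
          // Worker-thread context: serialize writes to the shared writer.
          std::lock_guard<std::mutex> lk(*clat_writer_m);
          clat_writer->Write(key, clat);
        });
  }
  cuda_pipeline.WaitForAllTasks();
}

Because the pipeline only needs the FST while it is being constructed, the binary can delete decode_fst immediately after building cuda_pipeline, exactly as the patch does.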