diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc
index 49c179b3673..140275d3b6e 100644
--- a/src/cudamatrix/cu-device.cc
+++ b/src/cudamatrix/cu-device.cc
@@ -110,6 +110,18 @@ void CuDevice::Initialize() {
     // Initialize CUBLAS.
     CUBLAS_SAFE_CALL(cublasCreate(&cublas_handle_));
     CUBLAS_SAFE_CALL(cublasSetStream(cublas_handle_, cudaStreamPerThread));
+
+#if CUDA_VERSION >= 9000
+    // cublasSetMathMode() and CUBLAS_TENSOR_OP_MATH were introduced in CUDA 9.0;
+    // when building against an older toolkit the option is ignored.
+    if (device_options_.use_tensor_cores) {
+      // Enable tensor cores in CUBLAS
+      // Note if the device does not support tensor cores this will fall back to normal math mode
+      CUBLAS_SAFE_CALL(cublasSetMathMode(cublas_handle_,
+                                         CUBLAS_TENSOR_OP_MATH));
+    }
+#endif
+
     // Initialize the cuSPARSE library
     CUSPARSE_SAFE_CALL(cusparseCreate(&cusparse_handle_));
     CUSPARSE_SAFE_CALL(cusparseSetStream(cusparse_handle_, cudaStreamPerThread));
@@ -525,6 +537,8 @@ CuDevice::~CuDevice() {
 // Each thread has its own copy of the CuDevice object.
 // Note: this was declared "static".
 thread_local CuDevice CuDevice::this_thread_device_;
+
+CuDevice::CuDeviceOptions CuDevice::device_options_;
 
 // define and initialize the static members of the CuDevice object.
 int32 CuDevice::device_id_ = -1;
diff --git a/src/cudamatrix/cu-device.h b/src/cudamatrix/cu-device.h
index dc3df7e347d..8816f9d223b 100644
--- a/src/cudamatrix/cu-device.h
+++ b/src/cudamatrix/cu-device.h
@@ -184,8 +184,34 @@ class CuDevice {
   /// (i.e. from outside the class), call this only if Enabled() returns true.
   bool IsComputeExclusive();
 
+  // Register command line options for CUDA device.
+  // This must be done before calling CuDevice::Initialize()
+  // Example:
+  //  CuDevice::RegisterDeviceOptions(&po);
+  //  po.Read(argc, argv);
+  //  CuDevice::Initialize();
+  static void RegisterDeviceOptions(OptionsItf *po) {
+    CuDevice::device_options_.Register(po);
+  }
   ~CuDevice();
  private:
+
+  // Options controlling how the CUDA device is set up; they are exposed on
+  // the command line through RegisterDeviceOptions().
+  struct CuDeviceOptions {
+    bool use_tensor_cores; // Enable tensor cores
+    CuDeviceOptions() : use_tensor_cores(false) {}
+    void Register(OptionsItf *po) {
+      po->Register("cuda-use-tensor-cores", &use_tensor_cores,
+                   "Enable FP16 tensor math. "
+                   "This is higher performance but less accuracy. "
+                   "This is only recommended for inference.");
+    }
+  };
+
+  // Static, so a single options object is shared by every CuDevice instance.
+  static CuDeviceOptions device_options_;
+
   // Default constructor used to initialize this_thread_device_
   CuDevice();
   CuDevice(CuDevice&); // Disallow.
diff --git a/src/nnet3bin/nnet3-compute-batch.cc b/src/nnet3bin/nnet3-compute-batch.cc
index b0001c96f57..5d4b9b1db48 100644
--- a/src/nnet3bin/nnet3-compute-batch.cc
+++ b/src/nnet3bin/nnet3-compute-batch.cc
@@ -80,6 +80,10 @@ int main(int argc, char *argv[]) {
                 "priors stored with the model (in this case, "
                 "a .mdl file is expected as input).");
 
+#if HAVE_CUDA==1
+    CuDevice::RegisterDeviceOptions(&po);
+#endif
+
     po.Read(argc, argv);
 
     if (po.NumArgs() != 3) {
diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc
index 45fde99a4f5..cf133025aae 100644
--- a/src/nnet3bin/nnet3-compute.cc
+++ b/src/nnet3bin/nnet3-compute.cc
@@ -78,6 +78,10 @@ int main(int argc, char *argv[]) {
                 "priors stored with the model (in this case, "
                 "a .mdl file is expected as input).");
 
+#if HAVE_CUDA==1
+    CuDevice::RegisterDeviceOptions(&po);
+#endif
+
     po.Read(argc, argv);
 
     if (po.NumArgs() != 3) {
diff --git a/src/nnet3bin/nnet3-latgen-faster-batch.cc b/src/nnet3bin/nnet3-latgen-faster-batch.cc
index fad2d5ed356..ec52cff9776 100644
--- a/src/nnet3bin/nnet3-latgen-faster-batch.cc
+++ b/src/nnet3bin/nnet3-latgen-faster-batch.cc
@@ -108,6 +108,10 @@ int main(int argc, char *argv[]) {
     po.Register("use-gpu", &use_gpu,
                 "yes|no|optional|wait, only has effect if compiled with CUDA");
 
+#if HAVE_CUDA==1
+    CuDevice::RegisterDeviceOptions(&po);
+#endif
+
     po.Read(argc, argv);
 
     if (po.NumArgs() != 4) {
diff --git a/src/nnet3bin/nnet3-xvector-compute.cc b/src/nnet3bin/nnet3-xvector-compute.cc
index a4bc89a7def..e327681cf9b 100644
--- a/src/nnet3bin/nnet3-xvector-compute.cc
+++ b/src/nnet3bin/nnet3-xvector-compute.cc
@@ -113,6 +113,10 @@ int main(int argc, char *argv[]) {
     po.Register("pad-input", &pad_input, "If true, duplicate the first and "
                 "last frames of the input features as required to equal min-chunk-size.");
 
+#if HAVE_CUDA==1
+    CuDevice::RegisterDeviceOptions(&po);
+#endif
+
     po.Read(argc, argv);
 
     if (po.NumArgs() != 3) {