From 5dd326d63090eb97dde0871cefed3de131d06a37 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Tue, 28 Jan 2020 13:22:12 -0800
Subject: [PATCH 1/7] Temporary solution for fp16 accumulation in Bert gemms

---
 src/operator/contrib/transformer.cu | 39 +++++++++++++++++++++++++----
 src/operator/linalg_impl.h          | 22 ++++++++++++++--
 2 files changed, 54 insertions(+), 7 deletions(-)

diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index e152669478dd..d8e710ae874f 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -51,16 +51,45 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
   cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
   auto err = CUBLAS_STATUS_SUCCESS;
-  // TODO(cfujitsang): handle computation_precision
+  auto computeType = CUDA_R_32F;
+  const void *alpha_ptr;
+  const void *beta_ptr;
+  if (dmlc::GetEnv("MXNET_FC_TRUE_FP16", false)){
+    computeType = CUDA_R_16F;
+    alpha_ptr = &CublasType<DType>::one;
+    beta_ptr = &CublasType<DType>::zero;
+    __half alpha_h = __float2half(alpha);
+    __half beta_h = __float2half(beta);
+    alpha_ptr = static_cast<__half*>(&alpha_h);
+    beta_ptr = static_cast<__half*>(&beta_h);
+  }else{
+    alpha_ptr = &CublasType<DType>::one;
+    beta_ptr = &CublasType<DType>::zero;
+    alpha_ptr = &alpha;
+    beta_ptr = &beta;
+  }
+  /*if (dmlc::GetEnv("MXNET_FC_TRUE_FP16", false)){
+    computeType = CUDA_R_16F;
+    printf("R_16F compute\n");
+    __half alpha_h = __float2half(alpha);
+    __half beta_h = __float2half(beta);
+    alpha_ptr = static_cast<__half*>(&alpha_h);
+    beta_ptr = static_cast<__half*>(&beta_h);
+  }else{
+    alpha_ptr = static_cast(&alpha);
+    beta_ptr = static_cast(&beta);
+  }*/
+  std::cout << typeid(alpha_ptr).name() << '\n';
+
   err = cublasGemmStridedBatchedEx(
       blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
       static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
-      reinterpret_cast<void*>(&alpha),
+      alpha_ptr,
       a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
       b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
-      reinterpret_cast<void*>(&beta),
+      beta_ptr,
       c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
-      static_cast<int>(batchCount), CUDA_R_32F, algo);
+      static_cast<int>(batchCount), computeType, algo);
   CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail.";
 #else
   LOG(FATAL) << "Not implemented with CUDA < 9.1";
@@ -77,7 +106,7 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
   if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
-                             strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
+                             strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
   } else {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
                              strideB, beta, c, ldc, strideC, batchCount);
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index d83eb0d08815..e4cc595679f6 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -258,6 +258,7 @@ void linalg_gemm(const Tensor<gpu, 2, mshadow::half::half_t>& A,
 #if CUDA_VERSION >= 9000
   auto cublas_math_mode = GetEnvAllowTensorCore() ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
+  //printf("ALLOW TENSORS? %i\n", GetEnvAllowTensorCore());
   auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
 #endif
@@ -271,7 +272,24 @@ void linalg_gemm(const Tensor<gpu, 2, mshadow::half::half_t>& A,
 #if CUDA_VERSION >= 9000
   SetCublasMathMode(blas_handle, previous_math_mode);
 #endif
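To make the trade-off behind this patch concrete: by default MXNet runs fp16 gemms in "pseudo-fp16" (fp16 inputs and outputs with fp32 accumulation), while MXNET_FC_TRUE_FP16 switches the accumulator to fp16 as well. The following NumPy sketch (illustration only, not part of the patch) shows why the accumulator precision matters for a 512-element dot product like the ones in these layers:

```python
import numpy as np

np.random.seed(0)
k = 512  # inner dimension, matching the 512-unit layer used in the test later in this series
a = np.random.uniform(size=k).astype(np.float16)
b = np.random.uniform(size=k).astype(np.float16)

# Reference result, accumulated in float64.
ref = np.dot(a.astype(np.float64), b.astype(np.float64))

# Pseudo-fp16: fp16 inputs, products and running sum kept in float32.
acc32 = np.float32(0)
for x, y in zip(a, b):
    acc32 += np.float32(x) * np.float32(y)

# True fp16: products and running sum kept in float16 throughout.
acc16 = np.float16(0)
for x, y in zip(a, b):
    acc16 = np.float16(acc16 + np.float16(x) * np.float16(y))

print("relative error, fp32 accumulation:", abs(acc32 - ref) / ref)
print("relative error, fp16 accumulation:", abs(acc16 - ref) / ref)  # noticeably larger
```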
%i\n", GetEnvAllowTensorCore()); auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode); #endif @@ -271,7 +272,24 @@ void linalg_gemm(const Tensor(const Tensor= 9000 SetCublasMathMode(blas_handle, previous_math_mode); #endif From b5e4cd559b0e1ef989193235eaa83dfc36f25a79 Mon Sep 17 00:00:00 2001 From: moisesh Date: Tue, 28 Jan 2020 18:00:05 -0800 Subject: [PATCH 2/7] Resolve alpha/beta type issue --- src/operator/contrib/transformer.cu | 43 ++++++++++-------------- src/operator/linalg_impl.h | 51 +++++++++++++++-------------- 2 files changed, 43 insertions(+), 51 deletions(-) diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu index d8e710ae874f..b240b15a4485 100644 --- a/src/operator/contrib/transformer.cu +++ b/src/operator/contrib/transformer.cu @@ -51,35 +51,26 @@ void CublasStridedBatchedGemm(mshadow::Stream* s, bool transA, bool transB, cublasHandle_t blas_handle = mshadow::Stream::GetBlasHandle(s); auto err = CUBLAS_STATUS_SUCCESS; - auto computeType = CUDA_R_32F; + using TrueFP16Type = DType; + using PseudoFP16Type = typename CublasType::ScaleType; + // Set up alpha and beta values in the possible formats needed (only different when dtype == half) + TrueFP16Type trueFP16_alpha = static_cast(alpha); + TrueFP16Type trueFP16_beta = static_cast(beta); + PseudoFP16Type pseudoFP16_alpha = static_cast(alpha); + PseudoFP16Type pseudoFP16_beta = static_cast(beta); const void *alpha_ptr; const void *beta_ptr; - if (dmlc::GetEnv("MXNET_FC_TRUE_FP16", false)){ - computeType = CUDA_R_16F; - alpha_ptr = &CublasType::one; - beta_ptr = &CublasType::zero; - __half alpha_h = __float2half(alpha); - __half beta_h = __float2half(beta); - alpha_ptr = static_cast<__half*>(&alpha_h); - beta_ptr = static_cast<__half*>(&beta_h); - }else{ - alpha_ptr = &CublasType::one; - beta_ptr = &CublasType::zero; - alpha_ptr = α - beta_ptr = β + cudaDataType_t computeType; + bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false); + if (use_true_fp16) { + alpha_ptr = &trueFP16_alpha; + beta_ptr = &trueFP16_beta; + computeType = CublasType::kCudaFlag; + } else { + alpha_ptr = &pseudoFP16_alpha; + beta_ptr = &pseudoFP16_beta; + computeType = CublasType::kCudaFlag; } - /*if (dmlc::GetEnv("MXNET_FC_TRUE_FP16", false)){ - computeType = CUDA_R_16F; - printf("R_16F compute\n"); - __half alpha_h = __float2half(alpha); - __half beta_h = __float2half(beta); - alpha_ptr = static_cast<__half*>(&alpha_h); - beta_ptr = static_cast<__half*>(&beta_h); - }else{ - alpha_ptr = static_cast(&alpha); - beta_ptr = static_cast(&beta); - }*/ - std::cout << typeid(alpha_ptr).name() << '\n'; err = cublasGemmStridedBatchedEx( blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB), diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h index e4cc595679f6..6ed26cffc0c2 100644 --- a/src/operator/linalg_impl.h +++ b/src/operator/linalg_impl.h @@ -249,6 +249,7 @@ void linalg_gemm(const Tensor *s) { using namespace mxnet; + using namespace mxnet::common::cuda; using mshadow::gpu; CHECK_NOTNULL(s); check_gemm(A, B, C, alpha, beta, tA, tB); @@ -258,46 +259,46 @@ void linalg_gemm(const Tensor= 9000 auto cublas_math_mode = GetEnvAllowTensorCore() ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; - //printf("ALLOW TENSORS? 
%i\n", GetEnvAllowTensorCore()); auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode); #endif - // pseudo-fp16 (fp32 math with fp16 I/O) - float alpha_f = float(alpha); // NOLINT(*) - float beta_f = float(beta); // NOLINT(*) - - // As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype. +// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype. #if CUDA_VERSION >= 8000 cudaDataType_t half_datatype = CUDA_R_16F; #else cublasDataType_t half_datatype = CUBLAS_DATA_HALF; #endif - //if (dmlc::GetEnv("MXNET_FC_TRUE_FP16", false)){ - printf("CUBLAS HAHAH HALF\n"); - auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - __half alpha_h = __float2half(alpha); - __half beta_h = __float2half(beta); - //} + auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + using TrueFP16Type = mshadow::half::half_t; + using PseudoFP16Type = typename CublasType::ScaleType; + TrueFP16Type trueFP16_alpha = static_cast(alpha); + TrueFP16Type trueFP16_beta = static_cast(beta); + PseudoFP16Type pseudoFP16_alpha = static_cast(alpha); + PseudoFP16Type pseudoFP16_beta = static_cast(beta); + const void *alpha_ptr; + const void *beta_ptr; + cudaDataType_t computeType; + bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false); + if (use_true_fp16) { + alpha_ptr = &trueFP16_alpha; + beta_ptr = &trueFP16_beta; + computeType = CublasType::kCudaFlag; + } else { + alpha_ptr = &pseudoFP16_alpha; + beta_ptr = &pseudoFP16_beta; + computeType = CublasType::kCudaFlag; + } + CUBLAS_CALL(cublasGemmEx(blas_handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N), (tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)), - &alpha_h, + alpha_ptr, B.dptr_, half_datatype, B.stride_, A.dptr_, half_datatype, A.stride_, - &beta_h, + beta_ptr, C.dptr_, half_datatype, C.stride_, - CUDA_R_16F, algo)); - //print("NORMAL\n"); - /*CUBLAS_CALL(cublasSgemmEx(blas_handle, - (tB ? CUBLAS_OP_T : CUBLAS_OP_N), - (tA ? CUBLAS_OP_T : CUBLAS_OP_N), - C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)), - &alpha_f, - B.dptr_, half_datatype, B.stride_, - A.dptr_, half_datatype, A.stride_, - &beta_f, - C.dptr_, half_datatype, C.stride_));*/ + computeType, algo)); #if CUDA_VERSION >= 9000 SetCublasMathMode(blas_handle, previous_math_mode); #endif From 308033d6c59141e82a4f0e1d401f7125e5c2d3da Mon Sep 17 00:00:00 2001 From: moisesh Date: Wed, 29 Jan 2020 12:42:27 -0800 Subject: [PATCH 3/7] add documentation for env variable MXNET_FC_TRUE_FP16 --- docs/static_site/src/pages/api/faq/env_var.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index 57ab27630a8f..f1b93c2b3dd3 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -353,6 +353,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`. - Values: 0(false) or 1(true) ```(default=1)``` - This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently. +* MXNET_FC_TRUE_FP16 + - Values: 0(false) or 1(true) ```(default=0)``` + - If this variable is set to true, MXNet will performs true FP16 computation in CUBLAS gemms when input datatype is float16. 
From ac350c6c84ead4872dcd8a46aa39435f8a157b32 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Wed, 29 Jan 2020 15:18:38 -0800
Subject: [PATCH 4/7] Improve description of env variable

---
 docs/static_site/src/pages/api/faq/env_var.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index f1b93c2b3dd3..5dceaa235ed0 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -355,7 +355,7 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
 
 * MXNET_FC_TRUE_FP16
   - Values: 0(false) or 1(true) ```(default=0)```
-  - If this variable is set to true, MXNet will performs true FP16 computation in CUBLAS gemms when input datatype is float16.
+  - If this variable is set to true, MXNet will perform fp16 accumulation when using cuBLAS and input datatype is set to float16. This could increase the speed of the computation, but might result in loss of accuracy. This makes this setting useful mainly for inference usecases.
 
 Settings for Minimum Memory Usage
 ---------------------------------
 - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```

From e6547226e36d0b9927ad13be361029d9c5a3474c Mon Sep 17 00:00:00 2001
From: moisesh
Date: Fri, 28 Feb 2020 15:10:43 -0800
Subject: [PATCH 5/7] Add unit test checking environment variable

---
 tests/python/gpu/test_gluon_gpu.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index aa56eee33dc4..def0d945e9a2 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -615,6 +615,26 @@ def test_symbol_block_symbolic_bn_fp16_cast():
     y1 = net(x)
     assert np.dtype(y1.dtype).name == 'float16'
 
+@with_seed()
+def test_gemms_true_fp16():
+    ctx = mx.gpu(0)
+    input = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
+    weights = mx.nd.random.uniform(shape=(128, 512), ctx=ctx)
+
+    net = nn.Dense(128, in_units=512, use_bias=False)
+    net.cast('float16')
+    net.initialize(ctx=ctx)
+    net.weight.set_data(weights)
+    ref_results = net(input)
+
+    os.environ["MXNET_FC_TRUE_FP16"] = "1"
+    results_trueFP16 = net(input)
+    atol = 1e-2
+    rtol = 1e-2
+    assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
+                        atol=atol, rtol=rtol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
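One note on the test above: it flips the environment variable globally, and a later patch in this series resets it at the end of the test. A context-manager pattern like the sketch below (hypothetical helper, not part of the patch) restores the previous value even if the assertion fails, which avoids leaking the setting into other tests:

```python
import os
from contextlib import contextmanager

@contextmanager
def env_var(name, value):
    """Temporarily set an environment variable, restoring the old value on exit."""
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

# Usage inside the test:
# with env_var("MXNET_FC_TRUE_FP16", "1"):
#     results_trueFP16 = net(input)
```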
From 91aca7bc358f8347e0f526f0f909bdbadf0563a6 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Tue, 10 Mar 2020 18:04:13 -0700
Subject: [PATCH 6/7] keep pseudo-fp16 if architecture does not support Float16Compute

---
 src/operator/linalg_impl.h         | 38 +++++++++++++++++++++---------
 tests/python/gpu/test_gluon_gpu.py |  1 +
 2 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 6ed26cffc0c2..75d2182990cd 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -288,17 +288,33 @@ void linalg_gemm(const Tensor<gpu, 2, mshadow::half::half_t>& A,
     computeType = CublasType<PseudoFP16Type>::kCudaFlag;
   }
-
-  CUBLAS_CALL(cublasGemmEx(blas_handle,
-                           (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
-                           (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
-                           C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
-                           alpha_ptr,
-                           B.dptr_, half_datatype, B.stride_,
-                           A.dptr_, half_datatype, A.stride_,
-                           beta_ptr,
-                           C.dptr_, half_datatype, C.stride_,
-                           computeType, algo));
+  if (SupportsFloat16Compute(s->dev_id)) {
+    CUBLAS_CALL(cublasGemmEx(blas_handle,
+                             (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+                             (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
+                             C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
+                             alpha_ptr,
+                             B.dptr_, half_datatype, B.stride_,
+                             A.dptr_, half_datatype, A.stride_,
+                             beta_ptr,
+                             C.dptr_, half_datatype, C.stride_,
+                             computeType, algo));
+  } else {
+    // pseudo-fp16 (fp32 math with fp16 I/O)
+    if (use_true_fp16)
+      common::LogOnce("MXNET_FC_TRUE_FP16 was set but this architecture does not support it.");
+    float alpha_f = float(alpha);
+    float beta_f = float(beta);
+    CUBLAS_CALL(cublasSgemmEx(blas_handle,
+                              (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+                              (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
+                              C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
+                              &alpha_f,
+                              B.dptr_, half_datatype, B.stride_,
+                              A.dptr_, half_datatype, A.stride_,
+                              &beta_f,
+                              C.dptr_, half_datatype, C.stride_));
+  }
 #if CUDA_VERSION >= 9000
   SetCublasMathMode(blas_handle, previous_math_mode);
 #endif
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index def0d945e9a2..42a2424c7d9b 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -633,6 +633,7 @@ def test_gemms_true_fp16():
     rtol = 1e-2
     assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
                         atol=atol, rtol=rtol)
+    os.environ["MXNET_FC_TRUE_FP16"] = "0"
 
 
 if __name__ == '__main__':

From 60fadda6e265b958e26ef585e1bf792c96bb89e3 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Wed, 11 Mar 2020 10:00:00 -0700
Subject: [PATCH 7/7] Fix cpplint

---
 src/operator/linalg_impl.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 75d2182990cd..fd6800d184e4 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -303,8 +303,8 @@ void linalg_gemm(const Tensor<gpu, 2, mshadow::half::half_t>& A,
-    float alpha_f = float(alpha);
-    float beta_f = float(beta);
+    float alpha_f = static_cast<float>(alpha);
+    float beta_f = static_cast<float>(beta);
     CUBLAS_CALL(cublasSgemmEx(blas_handle,
                               (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
                               (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
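A quick way to see which path a given build and GPU take after this series (sketch only; assumes a CUDA build of MXNet with these patches): on architectures without fp16 compute support the fallback keeps fp32 accumulation for both settings, so the two results below should match exactly, while on newer GPUs the true-fp16 run may differ slightly from the reference.

```python
import os
import mxnet as mx

ctx = mx.gpu(0)
x = mx.nd.random.uniform(shape=(64, 512), dtype='float16', ctx=ctx)
w = mx.nd.random.uniform(shape=(128, 512), dtype='float16', ctx=ctx)

os.environ["MXNET_FC_TRUE_FP16"] = "0"
ref = mx.nd.FullyConnected(data=x, weight=w, no_bias=True, num_hidden=128).asnumpy()

os.environ["MXNET_FC_TRUE_FP16"] = "1"
out = mx.nd.FullyConnected(data=x, weight=w, no_bias=True, num_hidden=128).asnumpy()
os.environ["MXNET_FC_TRUE_FP16"] = "0"  # restore the default

print("max abs difference:", float(abs(ref - out).max()))
```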