Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Bert gemms true fp16 #17466

Merged
merged 10 commits into from
Apr 7, 2020
4 changes: 4 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Values: 0(false) or 1(true) ```(default=1)```
- This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently.

* MXNET_FC_TRUE_FP16
- Values: 0(false) or 1(true) ```(default=0)```
- If this variable is set to true, MXNet will perform fp16 accumulation when using cuBLAS and input datatype is set to float16. This could increase the speed of the computation, but might result in loss of accuracy. This makes this setting useful mainly for inference usecases.

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
Expand Down
30 changes: 26 additions & 4 deletions src/operator/contrib/transformer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
<< "Must init CuBLAS handle in stream";

cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
auto err = CUBLAS_STATUS_SUCCESS;
using TrueFP16Type = DType;
using PseudoFP16Type = typename CublasType<DType>::ScaleType;
// Set up alpha and beta values in the possible formats needed (only different when dtype == half)
TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
const void *alpha_ptr;
const void *beta_ptr;
cudaDataType_t computeType;
bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
if (use_true_fp16) {
alpha_ptr = &trueFP16_alpha;
beta_ptr = &trueFP16_beta;
computeType = CublasType<TrueFP16Type>::kCudaFlag;
} else {
alpha_ptr = &pseudoFP16_alpha;
beta_ptr = &pseudoFP16_beta;
computeType = CublasType<PseudoFP16Type>::kCudaFlag;
}

// cublasGemmStridedBatchedEx is only supported for GPU with architecture
// capabilities equal or greater than 5.0. Fall back to
// cublasSgemmStridedBatched, which doesn't support implicit conversion
Expand All @@ -59,12 +81,12 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
CUBLAS_CALL(cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
alpha_ptr,
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
beta_ptr,
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo));
static_cast<int>(batchCount), computeType, algo));
} else {
if (std::is_same<DType, float>::value) {
CUBLAS_CALL(cublasSgemmStridedBatched(
Expand Down Expand Up @@ -124,7 +146,7 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount);
Expand Down
47 changes: 33 additions & 14 deletions src/operator/linalg_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
mshadow::half::half_t beta,
bool tA, bool tB, Stream<gpu> *s) {
using namespace mxnet;
using namespace mxnet::common::cuda;
using mshadow::gpu;
CHECK_NOTNULL(s);
check_gemm(A, B, C, alpha, beta, tA, tB);
Expand All @@ -261,25 +262,43 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
#endif

// pseudo-fp16 (fp32 math with fp16 I/O)
float alpha_f = float(alpha); // NOLINT(*)
float beta_f = float(beta); // NOLINT(*)

// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
#if CUDA_VERSION >= 8000
cudaDataType_t half_datatype = CUDA_R_16F;
#else
cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
#endif
CUBLAS_CALL(cublasSgemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
&alpha_f,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
&beta_f,
C.dptr_, half_datatype, C.stride_));
auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
using TrueFP16Type = mshadow::half::half_t;
using PseudoFP16Type = typename CublasType<mshadow::half::half_t>::ScaleType;
TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
const void *alpha_ptr;
const void *beta_ptr;
cudaDataType_t computeType;
bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
if (use_true_fp16) {
alpha_ptr = &trueFP16_alpha;
beta_ptr = &trueFP16_beta;
computeType = CublasType<TrueFP16Type>::kCudaFlag;
} else {
alpha_ptr = &pseudoFP16_alpha;
beta_ptr = &pseudoFP16_beta;
computeType = CublasType<PseudoFP16Type>::kCudaFlag;
}

CUBLAS_CALL(cublasGemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
alpha_ptr,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
beta_ptr,
C.dptr_, half_datatype, C.stride_,
computeType, algo));
#if CUDA_VERSION >= 9000
SetCublasMathMode(blas_handle, previous_math_mode);
#endif
Expand Down