
Bert gemms true fp16 #17466

Merged
merged 10 commits on Apr 7, 2020
4 changes: 4 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
@@ -353,6 +353,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Values: 0(false) or 1(true) ```(default=1)```
- This variable controls whether to use the MKL-DNN backend in the fused RNN operator for CPU context. There are two fusion implementations of the RNN operator in MXNet. The MKL-DNN implementation has better performance than the naive one, but the latter is currently more stable in the backward operation.

* MXNET_FC_TRUE_FP16
- Values: 0(false) or 1(true) ```(default=0)```
- If this variable is set to true, MXNet performs true FP16 computation in cuBLAS GEMM calls when the input data type is float16.

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
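As an editor's illustration (not part of the patch), a minimal C++ sketch of how the new `MXNET_FC_TRUE_FP16` flag could be toggled from a test process before any FP16 GEMM runs. It uses POSIX `setenv` and reads the flag back via `dmlc::GetEnv`, mirroring the call site added in `transformer.cu` below; the `<dmlc/parameter.h>` header path is an assumption.

```c++
// Hypothetical sketch: flip MXNET_FC_TRUE_FP16 at runtime and read it back the
// same way the patched GEMM code does. The flag must be set before the FP16
// operator runs, since it is queried at the cuBLAS call site.
#include <cstdlib>           // setenv (POSIX)
#include <dmlc/parameter.h>  // dmlc::GetEnv (header path assumed)

int main() {
  // Default 0: pseudo-FP16, i.e. FP16 inputs/outputs with FP32 accumulation.
  // 1: request true FP16 accumulation inside the cuBLAS GEMM.
  setenv("MXNET_FC_TRUE_FP16", "1", /*overwrite=*/1);
  const bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
  return use_true_fp16 ? 0 : 1;  // 0 if the flag was picked up
}
```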
30 changes: 25 additions & 5 deletions src/operator/contrib/transformer.cu
@@ -51,16 +51,36 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,

cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
auto err = CUBLAS_STATUS_SUCCESS;
// TODO(cfujitsang): handle computation_precision
using TrueFP16Type = DType;
using PseudoFP16Type = typename CublasType<DType>::ScaleType;
// Set up alpha and beta values in the possible formats needed (only different when dtype == half)
TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
const void *alpha_ptr;
const void *beta_ptr;
cudaDataType_t computeType;
bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
if (use_true_fp16) {
  alpha_ptr = &trueFP16_alpha;
  beta_ptr = &trueFP16_beta;
  computeType = CublasType<TrueFP16Type>::kCudaFlag;
} else {
  alpha_ptr = &pseudoFP16_alpha;
  beta_ptr = &pseudoFP16_beta;
  computeType = CublasType<PseudoFP16Type>::kCudaFlag;
}

err = cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
alpha_ptr,
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
beta_ptr,
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo);
static_cast<int>(batchCount), computeType, algo);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail.";
#else
LOG(FATAL) << "Not implemented with CUDA < 9.1";
@@ -77,7 +97,7 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount);
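To make the switch concrete, here is a self-contained sketch (editor's illustration, not from the PR) of what the two settings amount to at the cuBLAS level: with FP16 storage for A, B, and C, the only differences are the compute type and the type in which alpha/beta are stored. The matrix sizes, column-major layout, missing error checks, and the CUDA 10.x-era `cublasGemmEx` signature (`computeType` as `cudaDataType_t`, host-side `__half` conversions under nvcc, CUDA >= 9.2) are assumptions of this illustration.

```c++
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
  const int m = 64, n = 64, k = 64;
  __half *A, *B, *C;
  cudaMalloc(reinterpret_cast<void**>(&A), m * k * sizeof(__half));
  cudaMalloc(reinterpret_cast<void**>(&B), k * n * sizeof(__half));
  cudaMalloc(reinterpret_cast<void**>(&C), m * n * sizeof(__half));

  cublasHandle_t handle;
  cublasCreate(&handle);
  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);  // allow Tensor Cores

  const bool use_true_fp16 = false;  // mirrors MXNET_FC_TRUE_FP16
  // alpha/beta must be stored in the compute type, exactly as in the patch.
  const __half h_alpha = __float2half(1.0f), h_beta = __float2half(0.0f);
  const float  f_alpha = 1.0f,               f_beta = 0.0f;
  const void *alpha = use_true_fp16 ? static_cast<const void*>(&h_alpha)
                                    : static_cast<const void*>(&f_alpha);
  const void *beta  = use_true_fp16 ? static_cast<const void*>(&h_beta)
                                    : static_cast<const void*>(&f_beta);
  const cudaDataType_t computeType = use_true_fp16 ? CUDA_R_16F : CUDA_R_32F;

  // C(m x n) = A(m x k) * B(k x n), column-major, FP16 storage throughout;
  // accumulation happens in computeType (FP32 by default, FP16 if requested).
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
               alpha,
               A, CUDA_R_16F, m,
               B, CUDA_R_16F, k,
               beta,
               C, CUDA_R_16F, m,
               computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

  cublasDestroy(handle);
  cudaFree(A); cudaFree(B); cudaFree(C);
  return 0;
}
```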
47 changes: 33 additions & 14 deletions src/operator/linalg_impl.h
@@ -249,6 +249,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
mshadow::half::half_t beta,
bool tA, bool tB, Stream<gpu> *s) {
using namespace mxnet;
using namespace mxnet::common::cuda;
using mshadow::gpu;
CHECK_NOTNULL(s);
check_gemm(A, B, C, alpha, beta, tA, tB);
@@ -261,25 +262,43 @@
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
#endif

// pseudo-fp16 (fp32 math with fp16 I/O)
float alpha_f = float(alpha); // NOLINT(*)
float beta_f = float(beta); // NOLINT(*)

// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
#if CUDA_VERSION >= 8000
cudaDataType_t half_datatype = CUDA_R_16F;
#else
cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
#endif
CUBLAS_CALL(cublasSgemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
&alpha_f,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
&beta_f,
C.dptr_, half_datatype, C.stride_));
auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
using TrueFP16Type = mshadow::half::half_t;
using PseudoFP16Type = typename CublasType<mshadow::half::half_t>::ScaleType;
TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
const void *alpha_ptr;
const void *beta_ptr;
cudaDataType_t computeType;
bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
if (use_true_fp16) {
  alpha_ptr = &trueFP16_alpha;
  beta_ptr = &trueFP16_beta;
  computeType = CublasType<TrueFP16Type>::kCudaFlag;
} else {
  alpha_ptr = &pseudoFP16_alpha;
  beta_ptr = &pseudoFP16_beta;
  computeType = CublasType<PseudoFP16Type>::kCudaFlag;
}

CUBLAS_CALL(cublasGemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
alpha_ptr,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
beta_ptr,
C.dptr_, half_datatype, C.stride_,
computeType, algo));
#if CUDA_VERSION >= 9000
SetCublasMathMode(blas_handle, previous_math_mode);
#endif
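For the `linalg_impl.h` change above, a hypothetical side-by-side (editor's sketch, not part of the patch) of the removed `cublasSgemmEx` path and the new default `cublasGemmEx` path: both take FP16 A/B/C with float alpha/beta and accumulate in FP32, which is why the `MXNET_FC_TRUE_FP16=0` default preserves the previous numerical behaviour. The helper name and the fixed transpose/leading-dimension choices are illustrative only.

```c++
// Hypothetical helper contrasting the old and new default paths for an FP16 GEMM,
// C = alpha * A * B + beta * C (column-major, no transposes, for brevity).
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t fp16_gemm_fp32_accum(cublasHandle_t handle, bool use_gemm_ex,
                                    int m, int n, int k,
                                    const __half* A, const __half* B, __half* C) {
  const float alpha = 1.0f, beta = 0.0f;
  if (!use_gemm_ex) {
    // Old path: cublasSgemmEx, FP16 I/O with implicit FP32 compute.
    return cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                         &alpha, A, CUDA_R_16F, m,
                                 B, CUDA_R_16F, k,
                         &beta,  C, CUDA_R_16F, m);
  }
  // New default path: cublasGemmEx with an explicit CUDA_R_32F compute type.
  // Switching this to CUDA_R_16F (with __half alpha/beta) gives true FP16.
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha, A, CUDA_R_16F, m,
                              B, CUDA_R_16F, k,
                      &beta,  C, CUDA_R_16F, m,
                      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```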