Skip to content

Commit

Permalink
GPU gemms true fp16 (apache#17466)
Browse files Browse the repository at this point in the history
* Temporal solution for fp16 accumulation in Bert gemms

* Resolve alpha/beta type issue

* add documentation for env variable MXNET_FC_TRUE_FP16

* Improve description of env variable

* Add unitest checking environment variable

* keep pseudo-fp16 if architecture does not support Float16Compute

* Fix cpplint
  • Loading branch information
MoisesHer authored and Vladimir Cherepanov committed Apr 7, 2020
1 parent a5a100d commit 44486fc
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 13 deletions.
4 changes: 4 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
- Values: 0(false) or 1(true) ```(default=1)```
- This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently.

* MXNET_FC_TRUE_FP16
- Values: 0(false) or 1(true) ```(default=0)```
- If this variable is set to true, MXNet will perform fp16 accumulation when using cuBLAS and input datatype is set to float16. This could increase the speed of the computation, but might result in loss of accuracy. This makes this setting useful mainly for inference usecases.

Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
Expand Down
30 changes: 26 additions & 4 deletions src/operator/contrib/transformer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
<< "Must init CuBLAS handle in stream";

cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
auto err = CUBLAS_STATUS_SUCCESS;
using TrueFP16Type = DType;
using PseudoFP16Type = typename CublasType<DType>::ScaleType;
// Set up alpha and beta values in the possible formats needed (only different when dtype == half)
TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
const void *alpha_ptr;
const void *beta_ptr;
cudaDataType_t computeType;
bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
if (use_true_fp16) {
alpha_ptr = &trueFP16_alpha;
beta_ptr = &trueFP16_beta;
computeType = CublasType<TrueFP16Type>::kCudaFlag;
} else {
alpha_ptr = &pseudoFP16_alpha;
beta_ptr = &pseudoFP16_beta;
computeType = CublasType<PseudoFP16Type>::kCudaFlag;
}

// cublasGemmStridedBatchedEx is only supported for GPU with architecture
// capabilities equal or greater than 5.0. Fall back to
// cublasSgemmStridedBatched, which doesn't support implicit conversion
Expand All @@ -59,12 +81,12 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
CUBLAS_CALL(cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
alpha_ptr,
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
beta_ptr,
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo));
static_cast<int>(batchCount), computeType, algo));
} else {
if (std::is_same<DType, float>::value) {
CUBLAS_CALL(cublasSgemmStridedBatched(
Expand Down Expand Up @@ -124,7 +146,7 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount);
Expand Down
53 changes: 44 additions & 9 deletions src/operator/linalg_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
mshadow::half::half_t beta,
bool tA, bool tB, Stream<gpu> *s) {
using namespace mxnet;
using namespace mxnet::common::cuda;
using mshadow::gpu;
CHECK_NOTNULL(s);
check_gemm(A, B, C, alpha, beta, tA, tB);
Expand All @@ -268,25 +269,59 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
#endif

// pseudo-fp16 (fp32 math with fp16 I/O)
float alpha_f = float(alpha); // NOLINT(*)
float beta_f = float(beta); // NOLINT(*)

// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
#if CUDA_VERSION >= 8000
cudaDataType_t half_datatype = CUDA_R_16F;
#else
cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
#endif
CUBLAS_CALL(cublasSgemmEx(blas_handle,
auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
using TrueFP16Type = mshadow::half::half_t;
using PseudoFP16Type = typename CublasType<mshadow::half::half_t>::ScaleType;
TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
const void *alpha_ptr;
const void *beta_ptr;
cudaDataType_t computeType;
bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
if (use_true_fp16) {
alpha_ptr = &trueFP16_alpha;
beta_ptr = &trueFP16_beta;
computeType = CublasType<TrueFP16Type>::kCudaFlag;
} else {
alpha_ptr = &pseudoFP16_alpha;
beta_ptr = &pseudoFP16_beta;
computeType = CublasType<PseudoFP16Type>::kCudaFlag;
}
if (SupportsFloat16Compute(s->dev_id)) {
CUBLAS_CALL(cublasGemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
&alpha_f,
alpha_ptr,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
&beta_f,
C.dptr_, half_datatype, C.stride_));
beta_ptr,
C.dptr_, half_datatype, C.stride_,
computeType, algo));
} else {
// pseudo-fp16 (fp32 math with fp16 I/O)
if (use_true_fp16)
common::LogOnce("MXNET_FC_TRUE_FP16 was set but this architecture does not support it.");
float alpha_f = static_cast<float>(alpha);
float beta_f = static_cast<float>(beta);
CUBLAS_CALL(cublasSgemmEx(blas_handle,
(tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N),
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
&alpha_f,
B.dptr_, half_datatype, B.stride_,
A.dptr_, half_datatype, A.stride_,
&beta_f,
C.dptr_, half_datatype, C.stride_));
}
#if CUDA_VERSION >= 9000
SetCublasMathMode(blas_handle, previous_math_mode);
#endif
Expand Down
21 changes: 21 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,27 @@ def test_symbol_block_symbolic_bn_fp16_cast():
y1 = net(x)
assert np.dtype(y1.dtype).name == 'float16'

@with_seed()
def test_gemms_true_fp16():
ctx = mx.gpu(0)
input = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
weights = mx.nd.random.uniform(shape=(128, 512), ctx=ctx)

net = nn.Dense(128, in_units=512, use_bias=False)
net.cast('float16')
net.initialize(ctx=ctx)
net.weight.set_data(weights)
ref_results = net(input)

os.environ["MXNET_FC_TRUE_FP16"] = "1"
results_trueFP16 = net(input)
atol = 1e-2
rtol = 1e-2
assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
atol=atol, rtol=rtol)
os.environ["MXNET_FC_TRUE_FP16"] = "0"


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 44486fc

Please sign in to comment.