Commit
[feat]: change matmul default setting: do not use tensor core, use new cublasGemmEx api (#10267)

Change the default setting of matrix multiplication: do not use the Tensor Core feature by default.

Add a Python API for the matmul allow_tf32 flag, matching PyTorch:
oneflow.backends.cuda.matmul.allow_tf32 = True

Add a Python API for the matmul allow_fp16_reduced_precision_reduction flag, matching PyTorch:
oneflow.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
lucky9-cyou authored Jun 7, 2023
1 parent 89b6916 commit 1cbe5f5
Showing 8 changed files with 240 additions and 15 deletions.
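In user terms, the commit enables the following (a minimal sketch of the new flags, assuming a CUDA build of OneFlow; shapes are illustrative and the printed default follows the commit message):

import oneflow as flow

# After this change, TF32 Tensor Core math is off by default for float32 matmul.
print(flow.backends.cuda.matmul.allow_tf32)  # expected: False

# Opt back in to TF32 (faster on Ampere and newer GPUs, slightly lower precision).
flow.backends.cuda.matmul.allow_tf32 = True

# FP16 matmul may use reduced-precision accumulation unless this flag is cleared.
flow.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

x = flow.randn(128, 128, device="cuda")
y = flow.randn(128, 128, device="cuda")
z = flow.matmul(x, y)  # runs with the modes selected above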
39 changes: 39 additions & 0 deletions oneflow/api/python/ep/cuda_matmul_mode.cpp
@@ -0,0 +1,39 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <memory>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/ep/cuda/cuda_matmul_mode.h"

namespace py = pybind11;

namespace oneflow {

namespace ep {

ONEFLOW_API_PYBIND11_MODULE("ep", m) {
  m.def("is_matmul_allow_tf32", &CudaMatmulMode::is_matmul_allow_tf32);
  m.def("set_matmul_allow_tf32", &CudaMatmulMode::set_matmul_allow_tf32);
  m.def("is_matmul_allow_fp16_reduced_precision_reduction",
        &CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction);
  m.def("set_matmul_allow_fp16_reduced_precision_reduction",
        &CudaMatmulMode::set_matmul_allow_fp16_reduced_precision_reduction);
}

} // namespace ep

} // namespace oneflow
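These bindings surface under oneflow._oneflow_internal.ep. Calling them directly looks like this (a sketch; the flow.backends.cuda.matmul wrapper added later in this commit is the intended public entry point):

import oneflow

# Low-level access to the same state the public wrapper uses:
oneflow._oneflow_internal.ep.set_matmul_allow_tf32(True)
print(oneflow._oneflow_internal.ep.is_matmul_allow_tf32())  # True
oneflow._oneflow_internal.ep.set_matmul_allow_fp16_reduced_precision_reduction(False)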
54 changes: 54 additions & 0 deletions oneflow/core/ep/cuda/cuda_matmul_mode.cpp
@@ -0,0 +1,54 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/ep/cuda/cuda_matmul_mode.h"

namespace oneflow {

namespace ep {

namespace {

bool* GetMatmulAllowTF32() {
  static bool matmul_allow_tf32 = false;
  return &matmul_allow_tf32;
}

bool* GetMatmulAllowFP16ReducedPrecisionReduction() {
  static bool matmul_allow_fp16_reduced_precision_reduction = true;
  return &matmul_allow_fp16_reduced_precision_reduction;
}

} // namespace

bool CudaMatmulMode::is_matmul_allow_tf32() { return *GetMatmulAllowTF32(); }

void CudaMatmulMode::set_matmul_allow_tf32(bool matmul_allow_tf32) {
  *GetMatmulAllowTF32() = matmul_allow_tf32;
}

bool CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction() {
  return *GetMatmulAllowFP16ReducedPrecisionReduction();
}

void CudaMatmulMode::set_matmul_allow_fp16_reduced_precision_reduction(
    bool matmul_allow_fp16_reduced_precision_reduction) {
  *GetMatmulAllowFP16ReducedPrecisionReduction() = matmul_allow_fp16_reduced_precision_reduction;
}

} // namespace ep

} // namespace oneflow
34 changes: 34 additions & 0 deletions oneflow/core/ep/cuda/cuda_matmul_mode.h
@@ -0,0 +1,34 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_
#define ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_

namespace oneflow {
namespace ep {

struct CudaMatmulMode {
  static bool is_matmul_allow_tf32();
  static void set_matmul_allow_tf32(bool matmul_allow_tf32);
  static bool is_matmul_allow_fp16_reduced_precision_reduction();
  static void set_matmul_allow_fp16_reduced_precision_reduction(
      bool matmul_allow_fp16_reduced_precision_reduction);
};

} // namespace ep
} // namespace oneflow

#endif // ONEFLOW_CORE_EP_CUDA_MATMUL_MODE_H_
55 changes: 40 additions & 15 deletions oneflow/core/ep/cuda/primitive/broadcast_matmul.cpp
@@ -19,8 +19,8 @@ limitations under the License.
 #include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
 #include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
 #include "oneflow/core/common/optional.h"
-#include "oneflow/core/device/cuda_util.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
+#include "oneflow/core/ep/cuda/cuda_matmul_mode.h"
 #include <cuda.h>
 
 namespace oneflow {
@@ -60,37 +60,54 @@ union CublasScalarParameter {
   half h;
 };
 
-CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cudaDataType_t compute_type) {
+CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cublasComputeType_t compute_type) {
   CublasScalarParameter sp{};
-  if (compute_type == CUDA_R_64F) {
+  if (compute_type == CUBLAS_COMPUTE_64F) {
     sp.d = scalar.Value<double>();
-  } else if (compute_type == CUDA_R_32F) {
+  } else if (compute_type == CUBLAS_COMPUTE_32F_PEDANTIC
+             || compute_type == CUBLAS_COMPUTE_32F_FAST_TF32
+             || compute_type == CUBLAS_COMPUTE_32F) {
     sp.s = scalar.Value<float>();
-  } else if (compute_type == CUDA_R_16F) {
+  } else if (compute_type == CUBLAS_COMPUTE_16F) {
     sp.h = static_cast<half>(scalar.Value<float>());
   } else {
     UNIMPLEMENTED();
   }
   return sp;
 }
 
-cudaDataType_t GetComputeType(DataType data_type) {
+cudaDataType_t GetCublasScalarType(DataType data_type) {
   switch (data_type) {
     case kFloat: return CUDA_R_32F;
     case kDouble: return CUDA_R_64F;
+    default: return CUDA_R_32F;
+  }
+}
+
+cublasComputeType_t GetComputeType(DataType data_type, CudaStream* cuda_stream) {
+  switch (data_type) {
+    case kFloat: {
+      if (CudaMatmulMode::is_matmul_allow_tf32()) {
+        return CUBLAS_COMPUTE_32F_FAST_TF32;
+      } else {
+        // Starting with cuBLAS version 11.0.0, the library will automatically make use of Tensor
+        // Core capabilities wherever possible, unless they are explicitly disabled by selecting
+        // pedantic compute modes in cuBLAS.
+        return CUBLAS_COMPUTE_32F_PEDANTIC;
+      }
+    }
+    case kDouble: return CUBLAS_COMPUTE_64F;
     case kFloat16: {
-      const bool allow_half_accumulation =
-          ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", false);
-      if (allow_half_accumulation) {
-        return CUDA_R_16F;
+      if (cuda_stream->device_properties().major >= 5) {
+        return CUBLAS_COMPUTE_32F;
       } else {
-        return CUDA_R_32F;
+        return CUBLAS_COMPUTE_16F;
       }
     }
 #if CUDA_VERSION >= 11000
-    case kBFloat16: return CUDA_R_32F;
+    case kBFloat16: return CUBLAS_COMPUTE_32F;
 #endif  // CUDA_VERSION >= 11000
-    default: UNIMPLEMENTED(); return CUDA_R_32F;
+    default: UNIMPLEMENTED(); return CUBLAS_COMPUTE_32F;
   }
 }
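For readers skimming the hunk above: the new selection logic reduces to the mapping below, written as an illustrative Python sketch of GetComputeType (not part of the commit):

def compute_type(dtype, allow_tf32, cc_major):
    """Mirrors GetComputeType in broadcast_matmul.cpp (illustrative only)."""
    if dtype == "float32":
        # TF32 tensor cores when allowed; otherwise pedantic FP32, because
        # cuBLAS >= 11.0 may use tensor cores unless pedantic mode is selected.
        return "CUBLAS_COMPUTE_32F_FAST_TF32" if allow_tf32 else "CUBLAS_COMPUTE_32F_PEDANTIC"
    if dtype == "float64":
        return "CUBLAS_COMPUTE_64F"
    if dtype == "float16":
        # Accumulate in FP32 on compute capability >= 5, in FP16 otherwise.
        return "CUBLAS_COMPUTE_32F" if cc_major >= 5 else "CUBLAS_COMPUTE_16F"
    if dtype == "bfloat16":  # CUDA >= 11.0 builds only
        return "CUBLAS_COMPUTE_32F"
    raise NotImplementedError(dtype)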

@@ -102,7 +119,7 @@ void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType
                            Scalar beta, void* c) {
   auto* cuda_stream = stream->As<CudaStream>();
   const auto cuda_data_type = GetCudaDataType(data_type);
-  const auto compute_type = GetComputeType(data_type);
+  const auto compute_type = GetComputeType(data_type, cuda_stream);
   const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);
   const auto GetCublasOperation = [](BlasTransposeType transpose_type) {
     if (transpose_type == BlasTransposeType::N) {
@@ -136,12 +153,19 @@ void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType
     UNIMPLEMENTED();
   }
   const int cublas_ldc = n;
+
   CublasMathModeGuard guard(cuda_stream->cublas_handle());
   if (data_type == DataType::kFloat16) {
 #if CUDA_VERSION < 11000
     guard.SetMathMode(CUBLAS_TENSOR_OP_MATH);
 #else
-    guard.SetMathMode(CUBLAS_DEFAULT_MATH);
+    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+    if (cuda_stream->device_properties().major >= 5
+        && !CudaMatmulMode::is_matmul_allow_fp16_reduced_precision_reduction()) {
+      cublas_flags = static_cast<cublasMath_t>(cublas_flags
+                                               | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+    }
+    guard.SetMathMode(cublas_flags);
 #endif  // CUDA_VERSION < 11000
   }
 #if CUDA_VERSION >= 11000
@@ -150,6 +174,7 @@ void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType
   cublasGemmAlgo_t algo =
       (data_type == DataType::kFloat16) ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
 #endif
+
   if (num_batch_dims == 1 && c_batch_dims[0] != 1) {
     const void* cublas_a = b;
     const void* cublas_b = a;
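The math-mode guard in the hunk above is what makes allow_fp16_reduced_precision_reduction observable from Python: when the flag is off, cuBLAS is asked to keep full-precision accumulation in FP16 GEMMs. A sketch comparing the two modes (the exact numeric delta depends on GPU, cuBLAS version, and shapes):

import numpy as np
import oneflow as flow

x = flow.tensor(np.random.rand(256, 256)).to(device="cuda", dtype=flow.float16)
y = flow.tensor(np.random.rand(256, 256)).to(device="cuda", dtype=flow.float16)

flow.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
strict = x.matmul(y).cpu().numpy()  # full-precision reductions enforced

flow.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
fast = x.matmul(y).cpu().numpy()  # reduced-precision reductions permitted

# The two results may differ in the last few bits; both are valid FP16 GEMMs.
print(np.abs(strict - fast).max())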
1 change: 1 addition & 0 deletions python/oneflow/backends/__init__.py
@@ -13,5 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+from . import cuda
 from . import cudnn
 from . import mps
40 changes: 40 additions & 0 deletions python/oneflow/backends/cuda/__init__.py
@@ -0,0 +1,40 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import oneflow._oneflow_internal


class cuMatmulMode:
    def __getattr__(self, name):
        if name == "allow_tf32":
            return oneflow._oneflow_internal.ep.is_matmul_allow_tf32()
        elif name == "allow_fp16_reduced_precision_reduction":
            return (
                oneflow._oneflow_internal.ep.is_matmul_allow_fp16_reduced_precision_reduction()
            )
        raise AssertionError("Unknown attribute " + name)

    def __setattr__(self, name, value):
        if name == "allow_tf32":
            return oneflow._oneflow_internal.ep.set_matmul_allow_tf32(value)
        elif name == "allow_fp16_reduced_precision_reduction":
            return oneflow._oneflow_internal.ep.set_matmul_allow_fp16_reduced_precision_reduction(
                value
            )
        raise AssertionError("Unknown attribute " + name)


matmul = cuMatmulMode()
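Because cuMatmulMode overrides both __getattr__ and __setattr__, the module-level matmul object keeps no Python state of its own: every read and write is forwarded to the C++ singleton, so assignments take effect process-wide. A quick round trip (sketch):

import oneflow as flow

flow.backends.cuda.matmul.allow_tf32 = True
assert flow.backends.cuda.matmul.allow_tf32
flow.backends.cuda.matmul.allow_tf32 = False
assert not flow.backends.cuda.matmul.allow_tf32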
3 changes: 3 additions & 0 deletions python/oneflow/test/expensive/test_conv3d.py
@@ -24,6 +24,7 @@
 class TestConv3DModule(flow.unittest.TestCase):
     @autotest(n=3)
     def test_nn_functional_conv3d(test_case):
+        flow.backends.cuda.matmul.allow_tf32 = True
         device = random_device()
         img = torch.ones((1, 3, 16, 16, 16), requires_grad=True).to(device)
         kernel = torch.ones((6, 3, 3, 3, 3), requires_grad=True).to(device)
@@ -32,6 +33,7 @@ def test_nn_functional_conv3d(test_case):
 
     @autotest(n=10, rtol=1e-3, atol=1e-4)
     def test_conv3d_with_random_data(test_case):
+        flow.backends.cuda.matmul.allow_tf32 = True
         channels = random(1, 6)
         m = torch.nn.Conv3d(
             in_channels=channels,
@@ -53,6 +55,7 @@ def test_conv3d_with_random_data(test_case):
     @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
     @autotest(n=5, check_allclose=False, rtol=1e-3)
     def test_conv3d_group_with_random_data(test_case):
+        flow.backends.cuda.matmul.allow_tf32 = True
         channels = 720  # lcm(1, 2, 3, 4, 5, 6)
         m = torch.nn.Conv3d(
             in_channels=channels,
29 changes: 29 additions & 0 deletions python/oneflow/test/modules/test_matmul.py
@@ -18,6 +18,7 @@
 import numpy as np
 import oneflow as flow
 import oneflow.unittest
+import torch as torch_original
 
 from oneflow.test_utils.automated_test_util import *
 
@@ -33,6 +34,19 @@ def test_flow_matmul_with_random_data(test_case):
         z = torch.matmul(x, y)
         return z
 
+    @autotest(check_graph=True, rtol=1e-2, atol=1e-4)
+    def test_flow_tensor_matmul_with_random_data_allow_tf32(test_case):
+        flow.backends.cuda.matmul.allow_tf32 = True
+        torch_original.backends.cuda.matmul.allow_tf32 = True
+        device = random_device()
+        k = random(1, 6)
+        x = random_tensor(ndim=2, dim1=k).to(device)
+        y = random_tensor(ndim=2, dim0=k).to(device)
+        ret = x.matmul(y)
+        flow.backends.cuda.matmul.allow_tf32 = False
+        torch_original.backends.cuda.matmul.allow_tf32 = False
+        return ret
+
     @autotest(check_graph=True, rtol=1e-2, atol=1e-4)
     def test_flow_tensor_matmul_with_random_data(test_case):
         device = random_device()
@@ -55,6 +69,21 @@ def test_flow_tensor_matmul_with_random_int_data(test_case):
             np.allclose(flow_output_numpy, torch_output_numpy, 1e-05, 1e-05)
         )
 
+    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+    @autotest(n=5, check_graph=False)
+    def test_flow_tensor_matmul_with_random_fp16_data(test_case):
+        x = np.random.rand(3, 5)
+        y = np.random.rand(5, 4)
+        torch_x = torch.from_numpy(x).to(device=gpu_device(), dtype=torch.float16)
+        torch_y = torch.from_numpy(y).to(device=gpu_device(), dtype=torch.float16)
+        torch_output_numpy = torch_x.matmul(torch_y).cpu().numpy()
+        flow_x = flow.tensor(x).to(device="cuda", dtype=flow.float16)
+        flow_y = flow.tensor(y).to(device="cuda", dtype=flow.float16)
+        flow_output_numpy = flow_x.matmul(flow_y).cpu().numpy()
+        test_case.assertTrue(
+            np.allclose(flow_output_numpy, torch_output_numpy, 1e-05, 1e-05)
+        )
+
     @autotest(n=5, check_graph=True, rtol=1e-2, atol=1e-3)
     def test_flow_tensor_broadcast_matmul_with_random_data(test_case):
         device = random_device()
