diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6931b40c667d9..31b0a90ef29f3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -89,6 +89,7 @@ tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
 tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
 tvm_option(USE_THRUST "Build with Thrust" OFF)
+tvm_option(USE_CURAND "Build with cuRAND" OFF)
 tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
 tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
 tvm_option(USE_SORT "Build with sort support" ON)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 212b565f25fbe..b9a3aaef7d7e8 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -296,6 +296,9 @@ set(USE_VTA_FPGA OFF)
 # Whether use Thrust
 set(USE_THRUST OFF)
 
+# Whether use cuRAND
+set(USE_CURAND OFF)
+
 # Whether to build the TensorFlow TVMDSOOp module
 set(USE_TF_TVMDSOOP OFF)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 056ed18d442e5..bbbf6b89ba2e3 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -69,6 +69,18 @@ if(USE_CUDA)
     list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
   endif(USE_THRUST)
 
+  if(USE_CURAND)
+    message(STATUS "Build with cuRAND support")
+    message(STATUS "${CUDA_CURAND_LIBRARY}")
+    cmake_minimum_required(VERSION 3.13) # to compile CUDA code
+    enable_language(CUDA)
+    tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CC src/runtime/contrib/curand/*.cc)
+    tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CU src/runtime/contrib/curand/*.cu)
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CURAND_LIBRARY})
+    list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CC})
+    list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CU})
+  endif(USE_CURAND)
+
   if(USE_GRAPH_EXECUTOR_CUDA_GRAPH)
     if(NOT USE_GRAPH_EXECUTOR)
       message(FATAL_ERROR "CUDA Graph is only supported by graph executor, please set USE_GRAPH_EXECUTOR=ON")
diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake
index 06c42494a3314..3b3d8a4bcc9aa 100644
--- a/cmake/modules/LibInfo.cmake
+++ b/cmake/modules/LibInfo.cmake
@@ -111,6 +111,7 @@ function(add_lib_info src_file)
     TVM_INFO_USE_TFLITE="${USE_TFLITE}"
     TVM_INFO_USE_THREADS="${USE_THREADS}"
     TVM_INFO_USE_THRUST="${USE_THRUST}"
+    TVM_INFO_USE_CURAND="${USE_CURAND}"
     TVM_INFO_USE_VITIS_AI="${USE_VITIS_AI}"
     TVM_INFO_USE_VULKAN="${USE_VULKAN}"
     TVM_INFO_USE_CLML="${USE_CLML}"
diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake
index 8f3f638309cd6..607f1761ae497 100644
--- a/cmake/utils/FindCUDA.cmake
+++ b/cmake/utils/FindCUDA.cmake
@@ -85,6 +85,10 @@ macro(find_cuda use_cuda use_cudnn)
                  PATHS ${CUDA_TOOLKIT_ROOT_DIR}
                  PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
                  NO_DEFAULT_PATH)
+    find_library(CUDA_CURAND_LIBRARY curand
+                 ${CUDA_TOOLKIT_ROOT_DIR}/lib64
+                 ${CUDA_TOOLKIT_ROOT_DIR}/lib
+                 NO_DEFAULT_PATH)
     find_library(CUDA_CUBLAS_LIBRARY cublas
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib64
                  ${CUDA_TOOLKIT_ROOT_DIR}/lib
@@ -134,6 +138,7 @@ macro(find_cuda use_cuda use_cudnn)
     message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
     message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
     message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
+    message(STATUS "Found CUDA_CURAND_LIBRARY=" ${CUDA_CURAND_LIBRARY})
     message(STATUS "Found CUDA_CUBLASLT_LIBRARY=" ${CUDA_CUBLASLT_LIBRARY})
   endif(CUDA_FOUND)
 endmacro(find_cuda)
diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py
index d76fe0b840a4a..c74ee99002afe 100644
--- a/python/tvm/meta_schedule/runner/local_runner.py
+++ b/python/tvm/meta_schedule/runner/local_runner.py
@@ -25,15 +25,14 @@
 from ...runtime import Device, Module
 from ..utils import derived_object, get_global_func_with_default_on_worker
 from .config import EvaluatorConfig
-from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture
+from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
 from .utils import (
-    T_ARGUMENT_LIST,
     T_ARG_INFO_JSON_OBJ_LIST,
+    T_ARGUMENT_LIST,
     alloc_argument_common,
     run_evaluator_common,
 )
-
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -313,9 +312,6 @@ def _check(
             get_global_func_with_default_on_worker(name=f_alloc_argument, default=None)
             get_global_func_with_default_on_worker(name=f_run_evaluator, default=None)
             get_global_func_with_default_on_worker(name=f_cleanup, default=None)
-            get_global_func_with_default_on_worker(
-                name="tvm.contrib.random.random_fill", default=None
-            )
 
         value = self.pool.submit(
             _check,
@@ -348,7 +344,7 @@ def default_alloc_argument(
         The allocation args
     """
     f_random_fill = get_global_func_with_default_on_worker(
-        name="tvm.contrib.random.random_fill", default=None
+        name="tvm.contrib.random.random_fill_for_measure", default=None
     )
     return alloc_argument_common(f_random_fill, device, args_info, alloc_repeat)
diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py
index 9ff2489f8eb1c..8d9d797f7c172 100644
--- a/python/tvm/meta_schedule/runner/rpc_runner.py
+++ b/python/tvm/meta_schedule/runner/rpc_runner.py
@@ -474,7 +474,7 @@ def default_alloc_argument(
     """
     f_random_fill = get_global_func_on_rpc_session(
         session,
-        "tvm.contrib.random.random_fill",
+        "tvm.contrib.random.random_fill_for_measure",
         "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.",
     )
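Note: the runner changes above switch default argument allocation from `tvm.contrib.random.random_fill` to the new `tvm.contrib.random.random_fill_for_measure` global, registered in src/runtime/contrib/random/random.cc below. A minimal smoke test of that function might look like the following sketch (hypothetical, not part of this patch; assumes a build with USE_RANDOM=ON and CUDA enabled):

```python
import numpy as np
import tvm

# Resolve the packed function registered in src/runtime/contrib/random/random.cc.
f_random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")

# With USE_CURAND=ON, float tensors on CUDA are filled by cuRAND (normal
# distribution, mean 0.0, stddev 5.0); otherwise the parallel CPU path
# fills a host buffer and copies it to the device.
arr = tvm.nd.empty((512, 512), dtype="float32", device=tvm.cuda(0))
f_random_fill(arr)
assert np.std(arr.numpy()) > 0.0  # the buffer is no longer uninitialized
```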
diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc
new file mode 100644
index 0000000000000..b515076da38c4
--- /dev/null
+++ b/src/runtime/contrib/curand/curand.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <curand.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+#include "./helper_cuda_kernels.h"
+
+namespace tvm {
+namespace runtime {
+namespace curand {
+
+#define TVM_CURAND_CALL(func)                                      \
+  {                                                                \
+    curandStatus_t e = (func);                                     \
+    ICHECK(e == CURAND_STATUS_SUCCESS) << "cuRAND error: " << e;   \
+  }
+
+class CURandGenerator {
+ public:
+  CURandGenerator() { TVM_CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); }
+  ~CURandGenerator() { TVM_CURAND_CALL(curandDestroyGenerator(gen)); }
+
+  void Generate32bit(void* ptr, int64_t n) {
+    TVM_CURAND_CALL(curandGenerateNormal(gen, static_cast<float*>(ptr), n, 0.0f, 5.0f));
+    cudaDeviceSynchronize();
+  }
+
+  void Generate64bit(void* ptr, int64_t n) {
+    TVM_CURAND_CALL(curandGenerateNormalDouble(gen, static_cast<double*>(ptr), n, 0.0f, 5.0f));
+  }
+
+  curandGenerator_t gen;
+};
+
+DeviceAPI* GetCUDADeviceAPI() {
+  const PackedFunc* get_cuda_api = runtime::Registry::Get("device_api.cuda");
+  ICHECK(get_cuda_api) << "ValueError: TVM is not built with USE_CUDA=ON";
+  void* ret = (*get_cuda_api)();
+  runtime::DeviceAPI* cuda_api = static_cast<runtime::DeviceAPI*>(ret);
+  return cuda_api;
+}
+
+int64_t GetTensorSize(DLTensor* tensor) {
+  int64_t tensor_size = 1;
+  for (int i = 0; i < tensor->ndim; ++i) {
+    tensor_size *= tensor->shape[i];
+  }
+  return tensor_size;
+}
+
+struct DeferredFunc {
+ public:
+  DeferredFunc(std::function<void()> func) : func_(func) {}
+  ~DeferredFunc() { func_(); }
+
+ private:
+  std::function<void()> func_;
+};
+
+void RandomFill(DLTensor* tensor) {
+  static DeviceAPI* cuda_api = GetCUDADeviceAPI();
+  CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA)
+      << "ValueError: cuRAND only works on CUDA devices";
+  if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) {
+    int64_t tensor_size = GetTensorSize(tensor);
+    void* data = cuda_api->AllocWorkspace(tensor->device, tensor_size * sizeof(float));
+    {
+      DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
+      CURandGenerator().Generate32bit(data, GetTensorSize(tensor));
+      ConvertFp32toFp16(/*src=*/data, /*dst=*/tensor->data, /*num=*/tensor_size);
+    }
+  } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 32) {
+    CURandGenerator().Generate32bit(tensor->data, GetTensorSize(tensor));
+  } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 64) {
+    CURandGenerator().Generate64bit(tensor->data, GetTensorSize(tensor));
+  } else {
+    LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype;
+  }
+  TVMSynchronize(tensor->device.device_type, tensor->device.device_id, nullptr);
+}
+
+TVM_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill);
+
+}  // namespace curand
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/curand/helper_cuda_kernels.cu b/src/runtime/contrib/curand/helper_cuda_kernels.cu
new file mode 100644
index 0000000000000..a08fc09441b40
--- /dev/null
+++ b/src/runtime/contrib/curand/helper_cuda_kernels.cu
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_fp16.h>
+
+#include "./helper_cuda_kernels.h"
+
+namespace tvm {
+namespace runtime {
+namespace curand {
+
+__global__ void KernelFp32ToFp16(const float* src, half* dst, int num) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  if (idx < num) {
+    dst[idx] = src[idx];
+  }
+}
+
+void ConvertFp32toFp16(const void* _src, void* _dst, int64_t num) {
+  const float* src = static_cast<const float*>(_src);
+  half* dst = static_cast<half*>(_dst);
+  KernelFp32ToFp16<<<(num + 255) / 256, 256>>>(src, dst, num);
+}
+
+}  // namespace curand
+}  // namespace runtime
+}  // namespace tvm
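cuRAND has no native fp16 generator, which is why the fp16 branch of `RandomFill` above draws fp32 normals into a scratch workspace and narrows them with `KernelFp32ToFp16`, launching `ceil(num / 256)` blocks of 256 threads. A rough NumPy analogue of that two-step logic, purely for illustration (the function name here is made up):

```python
import numpy as np

def random_fill_fp16_host(shape, mean=0.0, stddev=5.0):
    """Mirror of the fp16 path: draw fp32 normals, then narrow to fp16."""
    buf = np.random.normal(mean, stddev, size=shape).astype("float32")
    # KernelFp32ToFp16 performs this cast element-wise on the GPU.
    return buf.astype("float16")
```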
diff --git a/src/runtime/contrib/curand/helper_cuda_kernels.h b/src/runtime/contrib/curand/helper_cuda_kernels.h
new file mode 100644
index 0000000000000..094c755590aa5
--- /dev/null
+++ b/src/runtime/contrib/curand/helper_cuda_kernels.h
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cuda_fp16.h>
+#include <tvm/runtime/ndarray.h>
+
+namespace tvm {
+namespace runtime {
+namespace curand {
+
+/*!
+ * \brief An auxiliary function to convert an FP32 array to FP16.
+ * \param src The source FP32 array.
+ * \param dst The destination FP16 array.
+ * \param num The number of elements in the array.
+ */
+void ConvertFp32toFp16(const void* src, void* dst, int64_t num);
+
+}  // namespace curand
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc
index 161ae62220123..ac52594360059 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -21,13 +21,16 @@
  * \file random/mt_random_engine.cc
  * \brief mt19937 random engine
  */
+#include <tvm/runtime/c_backend_api.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/threading_backend.h>
 
 #include <algorithm>
 #include <ctime>
 #include <random>
+#include <thread>
 
 #include "../3rdparty/compiler-rt/builtin_fp16.h"
@@ -116,52 +119,112 @@ class RandomEngine {
   }
 
   void RandomFill(DLTensor* data) {
-    int64_t size = 1;
-    for (int i = 0; i < data->ndim; ++i) {
-      size *= data->shape[i];
+    if (data->device.device_type == kDLCPU) {
+      FillData(data);
+    } else {
+      runtime::NDArray local = runtime::NDArray::Empty(
+          std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
+      DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
+      FillData(tensor);
+      runtime::NDArray::CopyFromTo(tensor, data);
     }
+  }
 
+  void RandomFillForMeasure(DLTensor* data) {
     if (data->device.device_type == kDLCPU) {
-      FillData(data, size);
+      FillDataForMeasure(data);
     } else {
       runtime::NDArray local = runtime::NDArray::Empty(
           std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
       DLTensor* tensor = const_cast<DLTensor*>(local.operator->());
-      FillData(tensor, size);
+      FillDataForMeasure(tensor);
       runtime::NDArray::CopyFromTo(tensor, data);
     }
   }
 
  private:
-  void FillData(DLTensor* tensor, int64_t size) {
+  void FillDataImpl(void* data, int64_t st, int64_t ed, DLDataType dtype) {
     // Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy
     // quantized dtype (uint8 / int8) data non-empty requirement
     std::uniform_real_distribution<> dist(1.0, 10.0);
     // Use float representation could make us work well on float / int type too.
-    if (tensor->dtype.bits == 1) {
-      std::generate_n(static_cast<bool*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
-    } else if (tensor->dtype.bits == 4) {
+    if (dtype.bits == 1) {
+      std::generate_n(static_cast<bool*>(data) + st, ed - st, [&]() { return dist(rnd_engine_); });
+    } else if (dtype.bits == 4) {
       // For uint4/int4 we pack two values into a single byte.
       // Thus, to ensure both values are non-zero, we use a distribution of 17 - 30.
       std::uniform_real_distribution<> packed_dist(17.0, 30.0);
-      std::generate_n(reinterpret_cast<uint8_t*>(tensor->data), size,
+      std::generate_n(reinterpret_cast<uint8_t*>(data) + st, ed - st,
                       [&]() { return packed_dist(rnd_engine_); });
-    } else if (tensor->dtype.bits == 8) {
-      std::generate_n(static_cast<uint8_t*>(tensor->data), size,
+    } else if (dtype.bits == 8) {
+      std::generate_n(static_cast<uint8_t*>(data) + st, ed - st,
                       [&]() { return dist(rnd_engine_); });
-    } else if (tensor->dtype.bits == 16) {
-      std::generate_n(static_cast<uint16_t*>(tensor->data), size, [&]() {
+    } else if (dtype.bits == 16) {
+      std::generate_n(static_cast<uint16_t*>(data) + st, ed - st, [&]() {
         return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
             static_cast<float>(dist(rnd_engine_)));
       });
-    } else if (tensor->dtype.bits == 32) {
-      std::generate_n(static_cast<float*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
-    } else if (tensor->dtype.bits == 64) {
-      std::generate_n(static_cast<double*>(tensor->data), size,
+    } else if (dtype.bits == 32) {
+      std::generate_n(static_cast<float*>(data) + st, ed - st, [&]() { return dist(rnd_engine_); });
+    } else if (dtype.bits == 64) {
+      std::generate_n(static_cast<double*>(data) + st, ed - st,
                       [&]() { return dist(rnd_engine_); });
     } else {
-      LOG(FATAL) << "Doesn't support dtype code " << tensor->dtype.code << " dtype bits "
-                 << tensor->dtype.bits;
+      LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits;
+    }
+  }
+
+  void FillData(DLTensor* tensor) {
+    int64_t size = 1;
+    for (int i = 0; i < tensor->ndim; ++i) {
+      size *= tensor->shape[i];
+    }
+    DLDataType dtype = tensor->dtype;
+    if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
+        dtype.bits == 32 || dtype.bits == 64) {
+      FillDataImpl(tensor->data, 0, size, dtype);
+    } else {
+      LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits;
+    }
+  }
+
+  void FillDataForMeasure(DLTensor* tensor) {
+    struct ParallelTask {
+      static int RunTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
+        ParallelTask* task = static_cast<ParallelTask*>(cdata);
+        task->Run(task_id);
+        return 0;
+      }
+
+      void Run(int i) {
+        // Round the chunk size up so that the trailing remainder is still
+        // covered when `size` is not divisible by `num_threads`.
+        int64_t chunk_size = (size + num_threads - 1) / num_threads;
+        int64_t st = i * chunk_size;
+        int64_t ed = std::min(st + chunk_size, size);
+        self->FillDataImpl(data, st, ed, dtype);
+      }
+
+      RandomEngine* self;
+      void* data;
+      int num_threads;
+      int64_t size;
+      DLDataType dtype;
+    };
+
+    ParallelTask task;
+    task.self = this;
+    task.data = tensor->data;
+    DLDataType dtype = task.dtype = tensor->dtype;
+    int64_t& size = task.size = 1;
+    for (int i = 0; i < tensor->ndim; ++i) {
+      size *= tensor->shape[i];
+    }
+    if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
+        dtype.bits == 32 || dtype.bits == 64) {
+      int num_threads = task.num_threads = runtime::threading::MaxConcurrency();
+      int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, num_threads);
+      ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed";
+    } else {
+      LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits;
+    }
   }
diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc
index 2cb56b87fdf57..38c2de6555e90 100644
--- a/src/runtime/contrib/random/random.cc
+++ b/src/runtime/contrib/random/random.cc
@@ -24,6 +24,7 @@
 #include <dmlc/logging.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/threading_backend.h>
 
 #include <algorithm>
@@ -123,5 +124,19 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args,
   entry->random_engine.RandomFill(out);
 });
 
+TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure")
+    .set_body([](TVMArgs args, TVMRetValue* ret) -> void {
+      static const PackedFunc* curand = Registry::Get("runtime.contrib.curand.RandomFill");
+      DLTensor* out = args[0];
+      if (curand && out->device.device_type == DLDeviceType::kDLCUDA) {
+        if (out->dtype.code == DLDataTypeCode::kDLFloat) {
+          (*curand)(out);
+          return;
+        }
+      }
+      RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
+      entry->random_engine.RandomFillForMeasure(out);
+    });
+
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc
index be0cd9eb8f52c..6f0a6114f3d9d 100644
--- a/src/support/libinfo.cc
+++ b/src/support/libinfo.cc
@@ -163,6 +163,10 @@
 #define TVM_INFO_USE_THRUST "NOT-FOUND"
 #endif
 
+#ifndef TVM_INFO_USE_CURAND
+#define TVM_INFO_USE_CURAND "NOT-FOUND"
+#endif
+
 #ifndef TVM_INFO_USE_MIOPEN
 #define TVM_INFO_USE_MIOPEN "NOT-FOUND"
 #endif
@@ -308,6 +312,7 @@ TVM_DLL Map<String, String> GetLibInfo() {
     {"USE_TFLITE", TVM_INFO_USE_TFLITE},
     {"USE_THREADS", TVM_INFO_USE_THREADS},
     {"USE_THRUST", TVM_INFO_USE_THRUST},
+    {"USE_CURAND", TVM_INFO_USE_CURAND},
     {"USE_VITIS_AI", TVM_INFO_USE_VITIS_AI},
     {"USE_VULKAN", TVM_INFO_USE_VULKAN},
     {"USE_CLML", TVM_INFO_USE_CLML},
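With the pieces above in place, the new build option is reported alongside the other flags at runtime, which helps when debugging RPC measurement setups. A quick sanity check (assuming a build configured with `set(USE_CURAND ON)` in config.cmake):

```python
import tvm

# GetLibInfo() in src/support/libinfo.cc backs this Python API; the entry
# reads "ON"/"OFF", or "NOT-FOUND" for libraries built before this patch.
print(tvm.support.libinfo()["USE_CURAND"])
```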