From f8b320f523b24fd8ddb8cf7026e61bbb4f4ea348 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 18 Jun 2022 02:11:55 -0700 Subject: [PATCH] [MetaSchedule][Runtime] Enhance Runner RandomFill (#11758) --- CMakeLists.txt | 1 + cmake/config.cmake | 3 + cmake/modules/CUDA.cmake | 12 ++ cmake/modules/LibInfo.cmake | 1 + cmake/utils/FindCUDA.cmake | 5 + docs/contribute/pull_request.rst | 1 + .../tvm/auto_scheduler/testing/tune_onnx.py | 10 +- .../tvm/auto_scheduler/testing/tune_relay.py | 10 +- python/tvm/auto_scheduler/testing/tune_te.py | 10 +- .../tvm/meta_schedule/runner/local_runner.py | 48 ++++---- python/tvm/meta_schedule/runner/rpc_runner.py | 50 +++++---- python/tvm/meta_schedule/testing/tune_onnx.py | 8 +- .../tvm/meta_schedule/testing/tune_relay.py | 8 +- python/tvm/meta_schedule/testing/tune_te.py | 8 +- src/runtime/contrib/curand/curand.cc | 104 ++++++++++++++++++ .../contrib/curand/helper_cuda_kernels.cu | 42 +++++++ .../contrib/curand/helper_cuda_kernels.h | 41 +++++++ .../contrib/random/mt_random_engine.cc | 103 +++++++++++++---- src/runtime/contrib/random/random.cc | 15 +++ src/support/libinfo.cc | 5 + 20 files changed, 377 insertions(+), 108 deletions(-) create mode 100644 src/runtime/contrib/curand/curand.cc create mode 100644 src/runtime/contrib/curand/helper_cuda_kernels.cu create mode 100644 src/runtime/contrib/curand/helper_cuda_kernels.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6931b40c667d..31b0a90ef29f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,7 @@ tvm_option(USE_CUDNN "Build with cuDNN" OFF) tvm_option(USE_CUBLAS "Build with cuBLAS" OFF) tvm_option(USE_CUTLASS "Build with CUTLASS" OFF) tvm_option(USE_THRUST "Build with Thrust" OFF) +tvm_option(USE_CURAND "Build with cuRAND" OFF) tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) tvm_option(USE_SORT "Build with sort support" ON) diff --git a/cmake/config.cmake b/cmake/config.cmake index 
212b565f25fb..b9a3aaef7d7e 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -296,6 +296,9 @@ set(USE_VTA_FPGA OFF) # Whether use Thrust set(USE_THRUST OFF) +# Whether use cuRAND +set(USE_CURAND OFF) + # Whether to build the TensorFlow TVMDSOOp module set(USE_TF_TVMDSOOP OFF) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 056ed18d442e..bbbf6b89ba2e 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -69,6 +69,18 @@ if(USE_CUDA) list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC}) endif(USE_THRUST) + if(USE_CURAND) + message(STATUS "Build with cuRAND support") + message(STATUS "${CUDA_CURAND_LIBRARY}") + cmake_minimum_required(VERSION 3.13) # to compile CUDA code + enable_language(CUDA) + tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CC src/runtime/contrib/curand/*.cc) + tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CU src/runtime/contrib/curand/*.cu) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CURAND_LIBRARY}) + list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CC}) + list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CU}) + endif(USE_CURAND) + if(USE_GRAPH_EXECUTOR_CUDA_GRAPH) if(NOT USE_GRAPH_EXECUTOR) message(FATAL_ERROR "CUDA Graph is only supported by graph executor, please set USE_GRAPH_EXECUTOR=ON") diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 06c42494a331..3b3d8a4bcc9a 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -111,6 +111,7 @@ function(add_lib_info src_file) TVM_INFO_USE_TFLITE="${USE_TFLITE}" TVM_INFO_USE_THREADS="${USE_THREADS}" TVM_INFO_USE_THRUST="${USE_THRUST}" + TVM_INFO_USE_CURAND="${USE_CURAND}" TVM_INFO_USE_VITIS_AI="${USE_VITIS_AI}" TVM_INFO_USE_VULKAN="${USE_VULKAN}" TVM_INFO_USE_CLML="${USE_CLML}" diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake index 8f3f638309cd..607f1761ae49 100644 --- a/cmake/utils/FindCUDA.cmake +++ b/cmake/utils/FindCUDA.cmake @@ -85,6 +85,10 @@ macro(find_cuda use_cuda use_cudnn) PATHS ${CUDA_TOOLKIT_ROOT_DIR} 
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu NO_DEFAULT_PATH) + find_library(CUDA_CURAND_LIBRARY curand + ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib + NO_DEFAULT_PATH) find_library(CUDA_CUBLAS_LIBRARY cublas ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib @@ -134,6 +138,7 @@ macro(find_cuda use_cuda use_cudnn) message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS}) message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY}) message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY}) + message(STATUS "Found CUDA_CURAND_LIBRARY=" ${CUDA_CURAND_LIBRARY}) message(STATUS "Found CUDA_CUBLASLT_LIBRARY=" ${CUDA_CUBLASLT_LIBRARY}) endif(CUDA_FOUND) endmacro(find_cuda) diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst index 26989fb8e6a3..81852a212610 100644 --- a/docs/contribute/pull_request.rst +++ b/docs/contribute/pull_request.rst @@ -118,6 +118,7 @@ space. You can remove stale images that aren't used in the presently checked-out other worktrees using the following command: .. code:: bash + docker/clear-stale-images.sh Consult the ``--help`` for more options. 
diff --git a/python/tvm/auto_scheduler/testing/tune_onnx.py b/python/tvm/auto_scheduler/testing/tune_onnx.py index 84ab1b48f8d2..5fbc875d1eda 100644 --- a/python/tvm/auto_scheduler/testing/tune_onnx.py +++ b/python/tvm/auto_scheduler/testing/tune_onnx.py @@ -26,6 +26,7 @@ from tvm import meta_schedule as ms from tvm import relay from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc +from tvm.meta_schedule.utils import cpu_count from tvm.relay.frontend import from_onnx from tvm.support import describe @@ -73,11 +74,6 @@ def _parse_args(): type=str, required=True, ) - args.add_argument( - "--rpc-workers", - type=int, - required=True, - ) args.add_argument( "--work-dir", type=str, @@ -100,7 +96,7 @@ def _parse_args(): ) args.add_argument( "--cpu-flush", - type=bool, + type=int, required=True, ) parsed = args.parse_args() @@ -125,7 +121,7 @@ def main(): key=ARGS.rpc_key, host=ARGS.rpc_host, port=ARGS.rpc_port, - n_parallel=ARGS.rpc_workers, + n_parallel=cpu_count(logical=True), number=ARGS.number, repeat=ARGS.repeat, min_repeat_ms=ARGS.min_repeat_ms, diff --git a/python/tvm/auto_scheduler/testing/tune_relay.py b/python/tvm/auto_scheduler/testing/tune_relay.py index 2bd78139993b..58ea327ec50b 100644 --- a/python/tvm/auto_scheduler/testing/tune_relay.py +++ b/python/tvm/auto_scheduler/testing/tune_relay.py @@ -26,6 +26,7 @@ from tvm import relay from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.relay_workload import get_network +from tvm.meta_schedule.utils import cpu_count from tvm.support import describe @@ -66,11 +67,6 @@ def _parse_args(): type=str, required=True, ) - args.add_argument( - "--rpc-workers", - type=int, - required=True, - ) args.add_argument( "--work-dir", type=str, @@ -98,7 +94,7 @@ def _parse_args(): ) args.add_argument( "--cpu-flush", - type=bool, + type=int, required=True, ) parsed = args.parse_args() @@ -123,7 +119,7 @@ def main(): key=ARGS.rpc_key, 
host=ARGS.rpc_host, port=ARGS.rpc_port, - n_parallel=ARGS.rpc_workers, + n_parallel=cpu_count(logical=True), number=ARGS.number, repeat=ARGS.repeat, min_repeat_ms=ARGS.min_repeat_ms, diff --git a/python/tvm/auto_scheduler/testing/tune_te.py b/python/tvm/auto_scheduler/testing/tune_te.py index 2eaddbbc081e..4a6874a53d34 100644 --- a/python/tvm/auto_scheduler/testing/tune_te.py +++ b/python/tvm/auto_scheduler/testing/tune_te.py @@ -21,6 +21,7 @@ import tvm from tvm import auto_scheduler from tvm.meta_schedule.testing.te_workload import CONFIGS +from tvm.meta_schedule.utils import cpu_count from tvm.support import describe @@ -56,11 +57,6 @@ def _parse_args(): type=str, required=True, ) - args.add_argument( - "--rpc-workers", - type=int, - required=True, - ) args.add_argument( "--work-dir", type=str, @@ -83,7 +79,7 @@ def _parse_args(): ) args.add_argument( "--cpu-flush", - type=bool, + type=int, required=True, ) parsed = args.parse_args() @@ -132,7 +128,7 @@ def main(): key=ARGS.rpc_key, host=ARGS.rpc_host, port=ARGS.rpc_port, - n_parallel=ARGS.rpc_workers, + n_parallel=cpu_count(logical=True), number=ARGS.number, repeat=ARGS.repeat, min_repeat_ms=ARGS.min_repeat_ms, diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index d76fe0b840a4..2d3214f53b6b 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -23,17 +23,17 @@ from ...contrib.popen_pool import PopenPoolExecutor from ...runtime import Device, Module +from ..profiler import Profiler from ..utils import derived_object, get_global_func_with_default_on_worker from .config import EvaluatorConfig -from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture +from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult from .utils import ( - T_ARGUMENT_LIST, T_ARG_INFO_JSON_OBJ_LIST, + T_ARGUMENT_LIST, alloc_argument_common, 
run_evaluator_common, ) - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -137,26 +137,29 @@ def resource_handler(): yield finally: # Final step. Always clean up - f_cleanup() + with Profiler.timeit("LocalRunner/cleanup"): + f_cleanup() with resource_handler(): # Step 1: create the local runtime module - rt_mod = tvm.runtime.load_module(artifact_path) - # Step 2: create the local device - device = tvm.runtime.device(dev_type=device_type, dev_id=0) - # Step 3: Allocate input arguments - repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( - device, - args_info, - alloc_repeat, - ) - # Step 4: Run time_evaluator - costs: List[float] = f_run_evaluator( - rt_mod, - device, - evaluator_config, - repeated_args, - ) + with Profiler.timeit("LocalRunner/load_module"): + rt_mod = tvm.runtime.load_module(artifact_path) + # Step 2: Allocate input arguments + with Profiler.timeit("LocalRunner/alloc_argument"): + device = tvm.runtime.device(dev_type=device_type, dev_id=0) + repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( + device, + args_info, + alloc_repeat, + ) + # Step 3: Run time_evaluator + with Profiler.timeit("LocalRunner/run_evaluator"): + costs: List[float] = f_run_evaluator( + rt_mod, + device, + evaluator_config, + repeated_args, + ) return costs @@ -313,9 +316,6 @@ def _check( get_global_func_with_default_on_worker(name=f_alloc_argument, default=None) get_global_func_with_default_on_worker(name=f_run_evaluator, default=None) get_global_func_with_default_on_worker(name=f_cleanup, default=None) - get_global_func_with_default_on_worker( - name="tvm.contrib.random.random_fill", default=None - ) value = self.pool.submit( _check, @@ -348,7 +348,7 @@ def default_alloc_argument( The allocation args """ f_random_fill = get_global_func_with_default_on_worker( - name="tvm.contrib.random.random_fill", default=None + name="tvm.contrib.random.random_fill_for_measure", default=None ) return alloc_argument_common(f_random_fill, device, args_info, 
alloc_repeat) diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index 9ff2489f8eb1..aa6f3daaac60 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -25,6 +25,7 @@ from tvm.rpc import RPCSession from tvm.runtime import Device, Module +from ..profiler import Profiler from ..utils import ( cpu_count, derived_object, @@ -243,7 +244,7 @@ def __init__( f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] = None, f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] = None, f_cleanup: Union[T_CLEANUP, str, None] = None, - max_workers: Optional[int] = 1, + max_workers: Optional[int] = None, initializer: Optional[Callable[[], None]] = None, ) -> None: """Constructor @@ -284,7 +285,7 @@ def __init__( self.f_run_evaluator = f_run_evaluator self.f_cleanup = f_cleanup if max_workers is None: - max_workers = cpu_count() + max_workers = cpu_count(logical=True) logger.info("RPCRunner: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( max_workers=max_workers, @@ -378,31 +379,36 @@ def resource_handler(): yield finally: # Final step. Always clean up - f_cleanup(session, remote_path) + with Profiler.timeit("RPCRunner/cleanup"): + f_cleanup(session, remote_path) with resource_handler(): # Step 1. Create session - session = f_create_session(rpc_config) - device = session.device(dev_type=device_type, dev_id=0) + with Profiler.timeit("RPCRunner/create_session"): + session = f_create_session(rpc_config) + device = session.device(dev_type=device_type, dev_id=0) # Step 2. 
Upload the module - _, remote_path = osp.split(artifact_path) - local_path: str = artifact_path - rt_mod: Module = f_upload_module(session, local_path, remote_path) + with Profiler.timeit("RPCRunner/upload_module"): + _, remote_path = osp.split(artifact_path) + local_path: str = artifact_path + rt_mod: Module = f_upload_module(session, local_path, remote_path) # Step 3: Allocate input arguments - repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( - session, - device, - args_info, - alloc_repeat, - ) + with Profiler.timeit("RPCRunner/alloc_argument"): + repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) # Step 4: Run time_evaluator - costs: List[float] = f_run_evaluator( - session, - rt_mod, - device, - evaluator_config, - repeated_args, - ) + with Profiler.timeit("RPCRunner/run_evaluator"): + costs: List[float] = f_run_evaluator( + session, + rt_mod, + device, + evaluator_config, + repeated_args, + ) return costs @@ -474,7 +480,7 @@ def default_alloc_argument( """ f_random_fill = get_global_func_on_rpc_session( session, - "tvm.contrib.random.random_fill", + "tvm.contrib.random.random_fill_for_measure", "Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.", ) diff --git a/python/tvm/meta_schedule/testing/tune_onnx.py b/python/tvm/meta_schedule/testing/tune_onnx.py index 1a51622b5cde..88cb360c0171 100644 --- a/python/tvm/meta_schedule/testing/tune_onnx.py +++ b/python/tvm/meta_schedule/testing/tune_onnx.py @@ -71,11 +71,6 @@ def _parse_args(): type=str, required=True, ) - args.add_argument( - "--rpc-workers", - type=int, - required=True, - ) args.add_argument( "--work-dir", type=str, @@ -98,7 +93,7 @@ def _parse_args(): ) args.add_argument( "--cpu-flush", - type=bool, + type=int, required=True, ) parsed = args.parse_args() @@ -140,7 +135,6 @@ def main(): enable_cpu_cache_flush=ARGS.cpu_flush, ), alloc_repeat=1, - max_workers=ARGS.rpc_workers, ) with ms.Profiler() as
profiler: lib = ms.tune_relay( diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py index 6188e124fde8..ce15c60c15e6 100644 --- a/python/tvm/meta_schedule/testing/tune_relay.py +++ b/python/tvm/meta_schedule/testing/tune_relay.py @@ -64,11 +64,6 @@ def _parse_args(): type=str, required=True, ) - args.add_argument( - "--rpc-workers", - type=int, - required=True, - ) args.add_argument( "--work-dir", type=str, @@ -96,7 +91,7 @@ def _parse_args(): ) args.add_argument( "--cpu-flush", - type=bool, + type=int, required=True, ) parsed = args.parse_args() @@ -141,7 +136,6 @@ def main(): enable_cpu_cache_flush=ARGS.cpu_flush, ), alloc_repeat=1, - max_workers=ARGS.rpc_workers, ) with ms.Profiler() as profiler: lib = ms.tune_relay( diff --git a/python/tvm/meta_schedule/testing/tune_te.py b/python/tvm/meta_schedule/testing/tune_te.py index cbc310f999ad..8740d7442478 100644 --- a/python/tvm/meta_schedule/testing/tune_te.py +++ b/python/tvm/meta_schedule/testing/tune_te.py @@ -59,11 +59,6 @@ def _parse_args(): type=str, required=True, ) - args.add_argument( - "--rpc-workers", - type=int, - required=True, - ) args.add_argument( "--work-dir", type=str, @@ -86,7 +81,7 @@ def _parse_args(): ) args.add_argument( "--cpu-flush", - type=bool, + type=int, required=True, ) parsed = args.parse_args() @@ -119,7 +114,6 @@ def main(): enable_cpu_cache_flush=ARGS.cpu_flush, ), alloc_repeat=1, - max_workers=ARGS.rpc_workers, ) with ms.Profiler() as profiler: sch: Optional[tir.Schedule] = ms.tune_tir( diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc new file mode 100644 index 000000000000..23282304f716 --- /dev/null +++ b/src/runtime/contrib/curand/curand.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "../../cuda/cuda_common.h" +#include "./helper_cuda_kernels.h" + +namespace tvm { +namespace runtime { +namespace curand { + +#define TVM_CURAND_CALL(func) \ + { \ + curandStatus_t e = (func); \ + ICHECK(e == CURAND_STATUS_SUCCESS) << "cuRAND error: " << e; \ + } + +class CURandGenerator { + public: + CURandGenerator() { TVM_CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); } + ~CURandGenerator() { TVM_CURAND_CALL(curandDestroyGenerator(gen)); } + + void Generate32bit(void* ptr, int64_t n) { + TVM_CURAND_CALL(curandGenerateNormal(gen, static_cast(ptr), n, 0.0f, 5.0f)); + cudaDeviceSynchronize(); + } + + void Generate64bit(void* ptr, int64_t n) { + TVM_CURAND_CALL(curandGenerateNormalDouble(gen, static_cast(ptr), n, 0.0f, 5.0f)); + } + + curandGenerator_t gen; +}; + +DeviceAPI* GetCUDADeviceAPI() { + const PackedFunc* get_cuda_api = runtime::Registry::Get("device_api.cuda"); + ICHECK(get_cuda_api) << "ValueError: TVM is not built with USE_CUDA=ON"; + void* ret = (*get_cuda_api)(); + runtime::DeviceAPI* cuda_api = static_cast(ret); + return cuda_api; +} + +int64_t GetTensorSize(DLTensor* tensor) { + int64_t tensor_size = 1; + for (int i = 0; i < tensor->ndim; ++i) { + tensor_size *= tensor->shape[i]; + } + return tensor_size; +} + +struct DeferredFunc { + public: + explicit DeferredFunc(std::function func) : func_(func) {} 
+ ~DeferredFunc() { func_(); } + + private: + std::function func_; +}; + +void RandomFill(DLTensor* tensor) { + static DeviceAPI* cuda_api = GetCUDADeviceAPI(); + CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA) + << "ValueError: cuRAND only works on CUDA devices"; + if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) { + int64_t tensor_size = GetTensorSize(tensor); + void* data = cuda_api->AllocWorkspace(tensor->device, tensor_size * sizeof(float)); + { + DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); }); + CURandGenerator().Generate32bit(data, tensor_size); + ConvertFp32toFp16(/*src=*/data, /*dst=*/tensor->data, /*num=*/tensor_size); + } + } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 32) { + CURandGenerator().Generate32bit(tensor->data, GetTensorSize(tensor)); + } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 64) { + CURandGenerator().Generate64bit(tensor->data, GetTensorSize(tensor)); + } else { + LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype; + } + TVMSynchronize(tensor->device.device_type, tensor->device.device_id, nullptr); +} + +TVM_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill); + +} // namespace curand +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/curand/helper_cuda_kernels.cu b/src/runtime/contrib/curand/helper_cuda_kernels.cu new file mode 100644 index 000000000000..a08fc09441b4 --- /dev/null +++ b/src/runtime/contrib/curand/helper_cuda_kernels.cu @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./helper_cuda_kernels.h" + +namespace tvm { +namespace runtime { +namespace curand { + +__global__ void KernelFp32ToFp16(const float* src, half* dst, int64_t num) { + int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x; + if (idx < num) { + dst[idx] = src[idx]; + } +} + +void ConvertFp32toFp16(const void* _src, void* _dst, int64_t num) { + const float* src = static_cast(_src); + half* dst = static_cast(_dst); + KernelFp32ToFp16<<<(num + 255) / 256, 256>>>(src, dst, num); +} + +} // namespace curand +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/curand/helper_cuda_kernels.h b/src/runtime/contrib/curand/helper_cuda_kernels.h new file mode 100644 index 000000000000..582162579a3a --- /dev/null +++ b/src/runtime/contrib/curand/helper_cuda_kernels.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_CONTRIB_CURAND_HELPER_CUDA_KERNELS_H_ +#define TVM_RUNTIME_CONTRIB_CURAND_HELPER_CUDA_KERNELS_H_ + +#include +#include + +namespace tvm { +namespace runtime { +namespace curand { + +/*! + * \brief An auxiliary function to convert an FP32 array to FP16. + * \param src The source FP32 array. + * \param dst The destination FP16 array. + * \param num The number of elements in the array. + */ +void ConvertFp32toFp16(const void* src, void* dst, int64_t num); + +} // namespace curand +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_CURAND_HELPER_CUDA_KERNELS_H_ diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 161ae6222012..ac5259436005 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -21,13 +21,16 @@ * \file random/mt_random_engine.cc * \brief mt19937 random engine */ +#include #include #include #include +#include #include #include #include +#include #include "../3rdparty/compiler-rt/builtin_fp16.h" @@ -116,52 +119,112 @@ class RandomEngine { } void RandomFill(DLTensor* data) { - int64_t size = 1; - for (int i = 0; i < data->ndim; ++i) { - size *= data->shape[i]; + if (data->device.device_type == kDLCPU) { + FillData(data); + } else { + runtime::NDArray local = runtime::NDArray::Empty( + std::vector{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0}); + DLTensor* tensor = const_cast(local.operator->()); + FillData(tensor); + 
runtime::NDArray::CopyFromTo(tensor, data); } + } + void RandomFillForMeasure(DLTensor* data) { if (data->device.device_type == kDLCPU) { - FillData(data, size); + FillDataForMeasure(data); } else { runtime::NDArray local = runtime::NDArray::Empty( std::vector{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0}); DLTensor* tensor = const_cast(local.operator->()); - FillData(tensor, size); + FillDataForMeasure(tensor); runtime::NDArray::CopyFromTo(tensor, data); } } private: - void FillData(DLTensor* tensor, int64_t size) { + void FillDataImpl(void* data, int64_t st, int64_t ed, DLDataType dtype) { // Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy // quantized dtype (uint8 / int8) data non-empty requirement std::uniform_real_distribution<> dist(1.0, 10.0); // Use float representation could make us work well on float / int type too. - if (tensor->dtype.bits == 1) { - std::generate_n(static_cast(tensor->data), size, [&]() { return dist(rnd_engine_); }); - } else if (tensor->dtype.bits == 4) { + if (dtype.bits == 1) { + std::generate_n(static_cast(data) + st, ed - st, [&]() { return dist(rnd_engine_); }); + } else if (dtype.bits == 4) { // For uint4/int4 we pack two values into a single byte. // Thus, to ensure both values are non-zero, we use a distribution of 17 - 30. 
std::uniform_real_distribution<> packed_dist(17.0, 30.0); - std::generate_n(reinterpret_cast(tensor->data), size, + std::generate_n(reinterpret_cast(data) + st, ed - st, [&]() { return packed_dist(rnd_engine_); }); - } else if (tensor->dtype.bits == 8) { - std::generate_n(static_cast(tensor->data), size, + } else if (dtype.bits == 8) { + std::generate_n(static_cast(data) + st, ed - st, [&]() { return dist(rnd_engine_); }); - } else if (tensor->dtype.bits == 16) { - std::generate_n(static_cast(tensor->data), size, [&]() { + } else if (dtype.bits == 16) { + std::generate_n(static_cast(data) + st, ed - st, [&]() { return __truncXfYf2__( static_cast(dist(rnd_engine_))); }); - } else if (tensor->dtype.bits == 32) { - std::generate_n(static_cast(tensor->data), size, [&]() { return dist(rnd_engine_); }); - } else if (tensor->dtype.bits == 64) { - std::generate_n(static_cast(tensor->data), size, + } else if (dtype.bits == 32) { + std::generate_n(static_cast(data) + st, ed - st, [&]() { return dist(rnd_engine_); }); + } else if (dtype.bits == 64) { + std::generate_n(static_cast(data) + st, ed - st, [&]() { return dist(rnd_engine_); }); } else { - LOG(FATAL) << "Doesn't support dtype code " << tensor->dtype.code << " dtype bits " - << tensor->dtype.bits; + LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; + } + } + + void FillData(DLTensor* tensor) { + int64_t size = 1; + for (int i = 0; i < tensor->ndim; ++i) { + size *= tensor->shape[i]; + } + DLDataType dtype = tensor->dtype; + if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 || + dtype.bits == 32 || dtype.bits == 64) { + FillDataImpl(tensor->data, 0, size, dtype); + } else { + LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; + } + } + + void FillDataForMeasure(DLTensor* tensor) { + struct ParallelTask { + static int RunTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) { + ParallelTask* task = 
static_cast(cdata); + task->Run(task_id); + return 0; + } + + void Run(int i) { + int64_t chunk_size = (size + num_threads - 1) / num_threads; + int64_t st = std::min(i * chunk_size, size); + int64_t ed = std::min(st + chunk_size, size); + self->FillDataImpl(data, st, ed, dtype); + } + + RandomEngine* self; + void* data; + int num_threads; + int64_t size; + DLDataType dtype; + }; + + ParallelTask task; + task.self = this; + task.data = tensor->data; + DLDataType dtype = task.dtype = tensor->dtype; + int64_t& size = task.size = 1; + for (int i = 0; i < tensor->ndim; ++i) { + size *= tensor->shape[i]; + } + if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 || + dtype.bits == 32 || dtype.bits == 64) { + int num_threads = task.num_threads = runtime::threading::MaxConcurrency(); + int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, num_threads); + ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed"; + } else { + LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; + } } diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 2cb56b87fdf5..38c2de6555e9 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include @@ -123,5 +124,19 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, entry->random_engine.RandomFill(out); }); +TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill_for_measure") + .set_body([](TVMArgs args, TVMRetValue* ret) -> void { + static const PackedFunc* curand = Registry::Get("runtime.contrib.curand.RandomFill"); + DLTensor* out = args[0]; + if (curand && out->device.device_type == DLDeviceType::kDLCUDA) { + if (out->dtype.code == DLDataTypeCode::kDLFloat) { + (*curand)(out); + return; + } + } + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + entry->random_engine.RandomFillForMeasure(out); + }); + } //
namespace contrib } // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index be0cd9eb8f52..6f0a6114f3d9 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -163,6 +163,10 @@ #define TVM_INFO_USE_THRUST "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_CURAND +#define TVM_INFO_USE_CURAND "NOT-FOUND" +#endif + #ifndef TVM_INFO_USE_MIOPEN #define TVM_INFO_USE_MIOPEN "NOT-FOUND" #endif @@ -308,6 +312,7 @@ TVM_DLL Map GetLibInfo() { {"USE_TFLITE", TVM_INFO_USE_TFLITE}, {"USE_THREADS", TVM_INFO_USE_THREADS}, {"USE_THRUST", TVM_INFO_USE_THRUST}, + {"USE_CURAND", TVM_INFO_USE_CURAND}, {"USE_VITIS_AI", TVM_INFO_USE_VITIS_AI}, {"USE_VULKAN", TVM_INFO_USE_VULKAN}, {"USE_CLML", TVM_INFO_USE_CLML},