[MetaSchedule][Runtime] Enhance Runner RandomFill
...
junrushao committed Jun 17, 2022
1 parent 1b8f3b5 commit 7b73c1f
Showing 13 changed files with 311 additions and 28 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -89,6 +89,7 @@ tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_CURAND "Build with cuRAND" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" ON)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -296,6 +296,9 @@ set(USE_VTA_FPGA OFF)
# Whether use Thrust
set(USE_THRUST OFF)

# Whether use cuRAND
set(USE_CURAND OFF)

# Whether to build the TensorFlow TVMDSOOp module
set(USE_TF_TVMDSOOP OFF)

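For context, opting in means setting USE_CURAND to ON here (alongside USE_CUDA) and rebuilding. A quick way to confirm the resulting runtime actually carries the cuRAND-backed fill is to look up the global function registered by curand.cc below; a minimal sketch using TVM's standard allow_missing lookup:

```python
import tvm

# Returns None instead of raising when the build did not include cuRAND.
f = tvm.get_global_func("runtime.contrib.curand.RandomFill", allow_missing=True)
print("cuRAND RandomFill available:", f is not None)
```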
12 changes: 12 additions & 0 deletions cmake/modules/CUDA.cmake
@@ -69,6 +69,18 @@ if(USE_CUDA)
    list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
  endif(USE_THRUST)

  if(USE_CURAND)
    message(STATUS "Build with cuRAND support")
    message(STATUS "${CUDA_CURAND_LIBRARY}")
    cmake_minimum_required(VERSION 3.13) # to compile CUDA code
    enable_language(CUDA)
    tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CC src/runtime/contrib/curand/*.cc)
    tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CU src/runtime/contrib/curand/*.cu)
    list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CURAND_LIBRARY})
    list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CC})
    list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CU})
  endif(USE_CURAND)

  if(USE_GRAPH_EXECUTOR_CUDA_GRAPH)
    if(NOT USE_GRAPH_EXECUTOR)
      message(FATAL_ERROR "CUDA Graph is only supported by graph executor, please set USE_GRAPH_EXECUTOR=ON")
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -111,6 +111,7 @@ function(add_lib_info src_file)
    TVM_INFO_USE_TFLITE="${USE_TFLITE}"
    TVM_INFO_USE_THREADS="${USE_THREADS}"
    TVM_INFO_USE_THRUST="${USE_THRUST}"
    TVM_INFO_USE_CURAND="${USE_CURAND}"
    TVM_INFO_USE_VITIS_AI="${USE_VITIS_AI}"
    TVM_INFO_USE_VULKAN="${USE_VULKAN}"
    TVM_INFO_USE_CLML="${USE_CLML}"
5 changes: 5 additions & 0 deletions cmake/utils/FindCUDA.cmake
@@ -85,6 +85,10 @@ macro(find_cuda use_cuda use_cudnn)
                 PATHS ${CUDA_TOOLKIT_ROOT_DIR}
                 PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
                 NO_DEFAULT_PATH)
    find_library(CUDA_CURAND_LIBRARY curand
                 ${CUDA_TOOLKIT_ROOT_DIR}/lib64
                 ${CUDA_TOOLKIT_ROOT_DIR}/lib
                 NO_DEFAULT_PATH)
    find_library(CUDA_CUBLAS_LIBRARY cublas
                 ${CUDA_TOOLKIT_ROOT_DIR}/lib64
                 ${CUDA_TOOLKIT_ROOT_DIR}/lib
@@ -134,6 +138,7 @@ macro(find_cuda use_cuda use_cudnn)
message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
message(STATUS "Found CUDA_CURAND_LIBRARY=" ${CUDA_CURAND_LIBRARY})
message(STATUS "Found CUDA_CUBLASLT_LIBRARY=" ${CUDA_CUBLASLT_LIBRARY})
endif(CUDA_FOUND)
endmacro(find_cuda)
10 changes: 3 additions & 7 deletions python/tvm/meta_schedule/runner/local_runner.py
@@ -25,15 +25,14 @@
from ...runtime import Device, Module
from ..utils import derived_object, get_global_func_with_default_on_worker
from .config import EvaluatorConfig
from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture
from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
from .utils import (
    T_ARGUMENT_LIST,
    T_ARG_INFO_JSON_OBJ_LIST,
    T_ARGUMENT_LIST,
    alloc_argument_common,
    run_evaluator_common,
)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -313,9 +312,6 @@ def _check(
            get_global_func_with_default_on_worker(name=f_alloc_argument, default=None)
            get_global_func_with_default_on_worker(name=f_run_evaluator, default=None)
            get_global_func_with_default_on_worker(name=f_cleanup, default=None)
            get_global_func_with_default_on_worker(
                name="tvm.contrib.random.random_fill", default=None
            )

        value = self.pool.submit(
            _check,
@@ -348,7 +344,7 @@ def default_alloc_argument(
        The allocation args
    """
    f_random_fill = get_global_func_with_default_on_worker(
        name="tvm.contrib.random.random_fill", default=None
        name="tvm.contrib.random.random_fill_for_measure", default=None
    )
    return alloc_argument_common(f_random_fill, device, args_info, alloc_repeat)

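alloc_argument_common itself lives in runner/utils.py and is not part of this diff. As a rough sketch of the loop it is assumed to implement: allocate one argument list per measurement repeat and fill every tensor in place with whichever f_random_fill the worker resolved. The ["TENSOR", dtype, shape] arg-info layout below is an assumption for illustration:

```python
import tvm

def alloc_argument_sketch(f_random_fill, device, args_info, alloc_repeat):
    """Hypothetical stand-in for alloc_argument_common (illustration only)."""
    repeated_args = []
    for _ in range(alloc_repeat):
        args = []
        for arg_info in args_info:
            _, dtype, shape = arg_info  # assumed layout: ["TENSOR", dtype, shape]
            arg = tvm.nd.empty(shape=shape, dtype=dtype, device=device)
            f_random_fill(arg)  # fill on-device, e.g. via the cuRAND-backed fill
            args.append(arg)
        repeated_args.append(args)
    return repeated_args
```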
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/runner/rpc_runner.py
@@ -474,7 +474,7 @@ def default_alloc_argument(
"""
f_random_fill = get_global_func_on_rpc_session(
session,
"tvm.contrib.random.random_fill",
"tvm.contrib.random.random_fill_for_measure",
"Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.",
)

104 changes: 104 additions & 0 deletions src/runtime/contrib/curand/curand.cc
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <curand.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>

#include "../../cuda/cuda_common.h"
#include "./helper_cuda_kernels.h"

namespace tvm {
namespace runtime {
namespace curand {

#define TVM_CURAND_CALL(func)                                      \
  {                                                                \
    curandStatus_t e = (func);                                     \
    ICHECK(e == CURAND_STATUS_SUCCESS) << "cuRAND error: " << e;   \
  }

class CURandGenerator {
 public:
  CURandGenerator() { TVM_CURAND_CALL(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); }
  ~CURandGenerator() { TVM_CURAND_CALL(curandDestroyGenerator(gen)); }

  void Generate32bit(void* ptr, int64_t n) {
    // Fill with fp32 samples drawn from a normal distribution, mean 0 and stddev 5.
    TVM_CURAND_CALL(curandGenerateNormal(gen, static_cast<float*>(ptr), n, 0.0f, 5.0f));
    cudaDeviceSynchronize();
  }

  void Generate64bit(void* ptr, int64_t n) {
    // Fill with fp64 samples from the same N(0, 5) distribution.
    TVM_CURAND_CALL(curandGenerateNormalDouble(gen, static_cast<double*>(ptr), n, 0.0, 5.0));
  }

  curandGenerator_t gen;
};

DeviceAPI* GetCUDADeviceAPI() {
  const PackedFunc* get_cuda_api = runtime::Registry::Get("device_api.cuda");
  ICHECK(get_cuda_api) << "ValueError: TVM is not built with USE_CUDA=ON";
  void* ret = (*get_cuda_api)();
  runtime::DeviceAPI* cuda_api = static_cast<runtime::DeviceAPI*>(ret);
  return cuda_api;
}

int64_t GetTensorSize(DLTensor* tensor) {
  int64_t tensor_size = 1;
  for (int i = 0; i < tensor->ndim; ++i) {
    tensor_size *= tensor->shape[i];
  }
  return tensor_size;
}

struct DeferredFunc {
 public:
  // RAII helper: runs the given callback when this object leaves scope.
  explicit DeferredFunc(std::function<void()> func) : func_(func) {}
  ~DeferredFunc() { func_(); }

 private:
  std::function<void()> func_;
};

void RandomFill(DLTensor* tensor) {
  static DeviceAPI* cuda_api = GetCUDADeviceAPI();
  CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA)
      << "ValueError: cuRAND only works on CUDA devices";
  if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) {
    // cuRAND has no fp16 generator: generate fp32 into a scratch workspace,
    // then downcast into the destination tensor.
    int64_t tensor_size = GetTensorSize(tensor);
    void* data = cuda_api->AllocWorkspace(tensor->device, tensor_size * sizeof(float));
    {
      DeferredFunc defer([data, tensor]() { cuda_api->FreeWorkspace(tensor->device, data); });
      CURandGenerator().Generate32bit(data, tensor_size);
      ConvertFp32toFp16(/*src=*/data, /*dst=*/tensor->data, /*num=*/tensor_size);
    }
  } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 32) {
    CURandGenerator().Generate32bit(tensor->data, GetTensorSize(tensor));
  } else if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 64) {
    CURandGenerator().Generate64bit(tensor->data, GetTensorSize(tensor));
  } else {
    LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype;
  }
  // Wait on the device before returning; the second argument is the device id.
  TVMSynchronize(tensor->device.device_type, tensor->device.device_id, nullptr);
}

TVM_REGISTER_GLOBAL("runtime.contrib.curand.RandomFill").set_body_typed(RandomFill);

} // namespace curand
} // namespace runtime
} // namespace tvm
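Once built, the registered packed function can be exercised directly from Python. A minimal sketch, assuming a build with USE_CUDA=ON and USE_CURAND=ON and an available CUDA device (the shape and dtype here are arbitrary):

```python
import tvm

# Look up the packed function registered above by curand.cc.
f_random_fill = tvm.get_global_func("runtime.contrib.curand.RandomFill")

# Allocate an uninitialized fp16 tensor on the first CUDA device and fill it
# in place with samples from N(0, 5); fp32 and fp64 tensors work the same way.
a = tvm.nd.empty((512, 512), dtype="float16", device=tvm.cuda(0))
f_random_fill(a)
print(float(a.numpy().astype("float32").mean()))  # close to 0 for large tensors
```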
42 changes: 42 additions & 0 deletions src/runtime/contrib/curand/helper_cuda_kernels.cu
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <cuda_fp16.h>

#include "./helper_cuda_kernels.h"

namespace tvm {
namespace runtime {
namespace curand {

__global__ void KernelFp32ToFp16(const float* src, half* dst, int num) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < num) {
    dst[idx] = src[idx];  // per-element implicit float -> half conversion
  }
}

void ConvertFp32toFp16(const void* _src, void* _dst, int64_t num) {
  const float* src = static_cast<const float*>(_src);
  half* dst = static_cast<half*>(_dst);
  // One thread per element; (num + 255) / 256 blocks of 256 threads rounds up.
  KernelFp32ToFp16<<<(num + 255) / 256, 256>>>(src, dst, num);
}

} // namespace curand
} // namespace runtime
} // namespace tvm
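For intuition, the kernel's per-element work is just an fp32-to-fp16 cast. A host-side NumPy analogue of what ConvertFp32toFp16 computes (the size and distribution parameters below simply mirror the generator above):

```python
import numpy as np

# Each fp32 sample is cast to fp16 independently, one kernel thread per element.
src = np.random.normal(loc=0.0, scale=5.0, size=1024).astype("float32")
dst = src.astype("float16")
assert dst.shape == src.shape and dst.dtype == np.float16
```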
36 changes: 36 additions & 0 deletions src/runtime/contrib/curand/helper_cuda_kernels.h
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <curand.h>
#include <tvm/runtime/registry.h>

namespace tvm {
namespace runtime {
namespace curand {

/*!
* \brief An auxiliary function to convert an FP32 array to FP16.
* \param src The source FP32 array.
* \param dst The destination FP16 array.
* \param num The number of elements in the array.
*/
void ConvertFp32toFp16(const void* src, void* dst, int64_t num);

} // namespace curand
} // namespace runtime
} // namespace tvm