[MetaSchedule][Runtime] Enhance Runner RandomFill
junrushao committed Jun 18, 2022
1 parent 4b15746 commit a8024f4
Showing 19 changed files with 371 additions and 108 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -89,6 +89,7 @@ tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_CURAND "Build with cuRAND" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" ON)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -296,6 +296,9 @@ set(USE_VTA_FPGA OFF)
# Whether use Thrust
set(USE_THRUST OFF)

# Whether use cuRAND
set(USE_CURAND OFF)

# Whether to build the TensorFlow TVMDSOOp module
set(USE_TF_TVMDSOOP OFF)

12 changes: 12 additions & 0 deletions cmake/modules/CUDA.cmake
@@ -69,6 +69,18 @@ if(USE_CUDA)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)

if(USE_CURAND)
message(STATUS "Build with cuRAND support")
message(STATUS "${CUDA_CURAND_LIBRARY}")
cmake_minimum_required(VERSION 3.13) # to compile CUDA code
enable_language(CUDA)
tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CC src/runtime/contrib/curand/*.cc)
tvm_file_glob(GLOB CONTRIB_CURAND_SRC_CU src/runtime/contrib/curand/*.cu)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CURAND_LIBRARY})
list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CC})
list(APPEND RUNTIME_SRCS ${CONTRIB_CURAND_SRC_CU})
endif(USE_CURAND)

if(USE_GRAPH_EXECUTOR_CUDA_GRAPH)
if(NOT USE_GRAPH_EXECUTOR)
message(FATAL_ERROR "CUDA Graph is only supported by graph executor, please set USE_GRAPH_EXECUTOR=ON")
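The new src/runtime/contrib/curand sources back the packed function tvm.contrib.random.random_fill_for_measure that the runner code below is switched to. A minimal sketch of exercising it, assuming TVM was built with USE_CURAND=ON and a CUDA device is present:

import numpy as np
import tvm

# Look up the packed function registered by the curand runtime sources.
f_random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")
dev = tvm.cuda(0)
# Allocate an uninitialized device buffer and fill it in place, on device.
arr = tvm.nd.empty((1024, 1024), dtype="float32", device=dev)
f_random_fill(arr)
assert np.abs(arr.numpy()).sum() > 0  # the buffer now holds random values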
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
@@ -111,6 +111,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_TFLITE="${USE_TFLITE}"
TVM_INFO_USE_THREADS="${USE_THREADS}"
TVM_INFO_USE_THRUST="${USE_THRUST}"
TVM_INFO_USE_CURAND="${USE_CURAND}"
TVM_INFO_USE_VITIS_AI="${USE_VITIS_AI}"
TVM_INFO_USE_VULKAN="${USE_VULKAN}"
TVM_INFO_USE_CLML="${USE_CLML}"
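With the LibInfo entry in place, the flag becomes queryable from Python. A quick check that a given build has cuRAND enabled, using the existing tvm.support.libinfo() API:

import tvm.support

# libinfo() exposes the compile-time flags baked into the runtime;
# the entry added above surfaces as the "USE_CURAND" key.
print(tvm.support.libinfo().get("USE_CURAND"))  # "ON" when built with cuRAND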
5 changes: 5 additions & 0 deletions cmake/utils/FindCUDA.cmake
@@ -85,6 +85,10 @@ macro(find_cuda use_cuda use_cudnn)
PATHS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
NO_DEFAULT_PATH)
find_library(CUDA_CURAND_LIBRARY curand
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
NO_DEFAULT_PATH)
find_library(CUDA_CUBLAS_LIBRARY cublas
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib
@@ -134,6 +138,7 @@ macro(find_cuda use_cuda use_cudnn)
message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
message(STATUS "Found CUDA_CURAND_LIBRARY=" ${CUDA_CURAND_LIBRARY})
message(STATUS "Found CUDA_CUBLASLT_LIBRARY=" ${CUDA_CUBLASLT_LIBRARY})
endif(CUDA_FOUND)
endmacro(find_cuda)
10 changes: 3 additions & 7 deletions python/tvm/auto_scheduler/testing/tune_onnx.py
@@ -26,6 +26,7 @@
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.utils import cpu_count
from tvm.relay.frontend import from_onnx
from tvm.support import describe

@@ -73,11 +74,6 @@ def _parse_args():
type=str,
required=True,
)
args.add_argument(
"--rpc-workers",
type=int,
required=True,
)
args.add_argument(
"--work-dir",
type=str,
@@ -100,7 +96,7 @@
)
args.add_argument(
"--cpu-flush",
type=bool,
type=int,
required=True,
)
parsed = args.parse_args()
@@ -125,7 +121,7 @@ def main():
key=ARGS.rpc_key,
host=ARGS.rpc_host,
port=ARGS.rpc_port,
n_parallel=ARGS.rpc_workers,
n_parallel=cpu_count(logical=True),
number=ARGS.number,
repeat=ARGS.repeat,
min_repeat_ms=ARGS.min_repeat_ms,
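The --cpu-flush switch from type=bool to type=int (here and in the sibling tuning scripts below) sidesteps a classic argparse pitfall: argparse applies type to the raw string, and bool() of any non-empty string is True, so --cpu-flush False would silently enable flushing. A short illustration:

import argparse

print(bool("False"))  # True -- why type=bool cannot express "off"

parser = argparse.ArgumentParser()
parser.add_argument("--cpu-flush", type=int, required=True)
args = parser.parse_args(["--cpu-flush", "0"])
print(bool(args.cpu_flush))  # False, as intended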
10 changes: 3 additions & 7 deletions python/tvm/auto_scheduler/testing/tune_relay.py
@@ -26,6 +26,7 @@
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.meta_schedule.utils import cpu_count
from tvm.support import describe


@@ -66,11 +67,6 @@ def _parse_args():
type=str,
required=True,
)
args.add_argument(
"--rpc-workers",
type=int,
required=True,
)
args.add_argument(
"--work-dir",
type=str,
@@ -98,7 +94,7 @@
)
args.add_argument(
"--cpu-flush",
type=bool,
type=int,
required=True,
)
parsed = args.parse_args()
@@ -123,7 +119,7 @@ def main():
key=ARGS.rpc_key,
host=ARGS.rpc_host,
port=ARGS.rpc_port,
n_parallel=ARGS.rpc_workers,
n_parallel=cpu_count(logical=True),
number=ARGS.number,
repeat=ARGS.repeat,
min_repeat_ms=ARGS.min_repeat_ms,
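Replacing the required --rpc-workers flag with cpu_count(logical=True) sizes the runner pool from the host automatically rather than from a command-line argument. A sketch of the new default:

from tvm.meta_schedule.utils import cpu_count

# The runner's parallelism now defaults to the host's logical core count.
print("n_parallel =", cpu_count(logical=True))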
10 changes: 3 additions & 7 deletions python/tvm/auto_scheduler/testing/tune_te.py
@@ -21,6 +21,7 @@
import tvm
from tvm import auto_scheduler
from tvm.meta_schedule.testing.te_workload import CONFIGS
from tvm.meta_schedule.utils import cpu_count
from tvm.support import describe


@@ -56,11 +57,6 @@ def _parse_args():
type=str,
required=True,
)
args.add_argument(
"--rpc-workers",
type=int,
required=True,
)
args.add_argument(
"--work-dir",
type=str,
@@ -83,7 +79,7 @@
)
args.add_argument(
"--cpu-flush",
type=bool,
type=int,
required=True,
)
parsed = args.parse_args()
@@ -132,7 +128,7 @@ def main():
key=ARGS.rpc_key,
host=ARGS.rpc_host,
port=ARGS.rpc_port,
n_parallel=ARGS.rpc_workers,
n_parallel=cpu_count(logical=True),
number=ARGS.number,
repeat=ARGS.repeat,
min_repeat_ms=ARGS.min_repeat_ms,
48 changes: 24 additions & 24 deletions python/tvm/meta_schedule/runner/local_runner.py
@@ -23,17 +23,17 @@

from ...contrib.popen_pool import PopenPoolExecutor
from ...runtime import Device, Module
from ..profiler import Profiler
from ..utils import derived_object, get_global_func_with_default_on_worker
from .config import EvaluatorConfig
from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult, PyRunnerFuture
from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult
from .utils import (
T_ARGUMENT_LIST,
T_ARG_INFO_JSON_OBJ_LIST,
T_ARGUMENT_LIST,
alloc_argument_common,
run_evaluator_common,
)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@@ -137,26 +137,29 @@ def resource_handler():
yield
finally:
# Final step. Always clean up
f_cleanup()
with Profiler.timeit("LocalRunner/cleanup"):
f_cleanup()

with resource_handler():
# Step 1: create the local runtime module
rt_mod = tvm.runtime.load_module(artifact_path)
# Step 2: create the local device
device = tvm.runtime.device(dev_type=device_type, dev_id=0)
# Step 3: Allocate input arguments
repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
device,
args_info,
alloc_repeat,
)
# Step 4: Run time_evaluator
costs: List[float] = f_run_evaluator(
rt_mod,
device,
evaluator_config,
repeated_args,
)
with Profiler.timeit("LocalRunner/load_module"):
rt_mod = tvm.runtime.load_module(artifact_path)
# Step 2: Allocate input arguments
with Profiler.timeit("LocalRunner/alloc_argument"):
device = tvm.runtime.device(dev_type=device_type, dev_id=0)
repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
device,
args_info,
alloc_repeat,
)
# Step 3: Run time_evaluator
with Profiler.timeit("LocalRunner/run_evaluator"):
costs: List[float] = f_run_evaluator(
rt_mod,
device,
evaluator_config,
repeated_args,
)
return costs


@@ -313,9 +316,6 @@ def _check(
get_global_func_with_default_on_worker(name=f_alloc_argument, default=None)
get_global_func_with_default_on_worker(name=f_run_evaluator, default=None)
get_global_func_with_default_on_worker(name=f_cleanup, default=None)
get_global_func_with_default_on_worker(
name="tvm.contrib.random.random_fill", default=None
)

value = self.pool.submit(
_check,
@@ -348,7 +348,7 @@ def default_alloc_argument(
The allocation args
"""
f_random_fill = get_global_func_with_default_on_worker(
name="tvm.contrib.random.random_fill", default=None
name="tvm.contrib.random.random_fill_for_measure", default=None
)
return alloc_argument_common(f_random_fill, device, args_info, alloc_repeat)

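The Profiler.timeit scopes wrapped around each runner step feed the time breakdown that ms.Profiler reports at the end of tuning. A minimal sketch of the mechanism, with sleeps standing in for the real steps:

import time
from tvm import meta_schedule as ms

with ms.Profiler() as profiler:
    with ms.Profiler.timeit("LocalRunner/load_module"):
        time.sleep(0.05)  # stand-in for tvm.runtime.load_module(...)
    with ms.Profiler.timeit("LocalRunner/run_evaluator"):
        time.sleep(0.10)  # stand-in for the time_evaluator call
print(profiler.table())  # per-scope share of total wall time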
50 changes: 28 additions & 22 deletions python/tvm/meta_schedule/runner/rpc_runner.py
@@ -25,6 +25,7 @@
from tvm.rpc import RPCSession
from tvm.runtime import Device, Module

from ..profiler import Profiler
from ..utils import (
cpu_count,
derived_object,
@@ -243,7 +244,7 @@ def __init__(
f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] = None,
f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] = None,
f_cleanup: Union[T_CLEANUP, str, None] = None,
max_workers: Optional[int] = 1,
max_workers: Optional[int] = None,
initializer: Optional[Callable[[], None]] = None,
) -> None:
"""Constructor
@@ -284,7 +285,7 @@ def __init__(
self.f_run_evaluator = f_run_evaluator
self.f_cleanup = f_cleanup
if max_workers is None:
max_workers = cpu_count()
max_workers = cpu_count(logical=True)
logger.info("RPCRunner: max_workers = %d", max_workers)
self.pool = PopenPoolExecutor(
max_workers=max_workers,
@@ -378,31 +379,36 @@ def resource_handler():
yield
finally:
# Final step. Always clean up
f_cleanup(session, remote_path)
with Profiler.timeit("RPCRunner/cleanup"):
f_cleanup(session, remote_path)

with resource_handler():
# Step 1. Create session
session = f_create_session(rpc_config)
device = session.device(dev_type=device_type, dev_id=0)
with Profiler.timeit("RPCRunner/create_session"):
session = f_create_session(rpc_config)
device = session.device(dev_type=device_type, dev_id=0)
# Step 2. Upload the module
_, remote_path = osp.split(artifact_path)
local_path: str = artifact_path
rt_mod: Module = f_upload_module(session, local_path, remote_path)
with Profiler.timeit("RPCRunner/upload_module"):
_, remote_path = osp.split(artifact_path)
local_path: str = artifact_path
rt_mod: Module = f_upload_module(session, local_path, remote_path)
# Step 3: Allocate input arguments
repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
session,
device,
args_info,
alloc_repeat,
)
with Profiler.timeit("RPCRunner/alloc_argument"):
repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument(
session,
device,
args_info,
alloc_repeat,
)
# Step 4: Run time_evaluator
costs: List[float] = f_run_evaluator(
session,
rt_mod,
device,
evaluator_config,
repeated_args,
)
with Profiler.timeit("LocalRunner/run_evaluator"):
costs: List[float] = f_run_evaluator(
session,
rt_mod,
device,
evaluator_config,
repeated_args,
)
return costs


@@ -474,7 +480,7 @@ def default_alloc_argument(
"""
f_random_fill = get_global_func_on_rpc_session(
session,
"tvm.contrib.random.random_fill",
"tvm.contrib.random.random_fill_for_measure",
"Please make sure 'USE_RANDOM' is turned ON in the config.cmake on the RPC server.",
)

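default_alloc_argument delegates to alloc_argument_common with the cuRAND-backed fill function. An illustrative sketch, not the exact library code, of what that common helper does for each measurement repeat:

from typing import Any, Callable, List

import tvm

def alloc_argument_sketch(
    f_random_fill: Callable,
    device: tvm.runtime.Device,
    args_info: List[Any],  # per-argument records, e.g. ("TENSOR", dtype, shape)
    alloc_repeat: int,
) -> List[List[tvm.nd.NDArray]]:
    repeated_args = []
    for _ in range(alloc_repeat):
        args = []
        for _, dtype, shape in args_info:
            arg = tvm.nd.empty(shape=shape, dtype=dtype, device=device)
            f_random_fill(arg)  # random bits written on-device, no host copy
            args.append(arg)
        repeated_args.append(args)
    return repeated_args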
8 changes: 1 addition & 7 deletions python/tvm/meta_schedule/testing/tune_onnx.py
@@ -71,11 +71,6 @@ def _parse_args():
type=str,
required=True,
)
args.add_argument(
"--rpc-workers",
type=int,
required=True,
)
args.add_argument(
"--work-dir",
type=str,
@@ -98,7 +93,7 @@
)
args.add_argument(
"--cpu-flush",
type=bool,
type=int,
required=True,
)
parsed = args.parse_args()
@@ -140,7 +135,6 @@ def main():
enable_cpu_cache_flush=ARGS.cpu_flush,
),
alloc_repeat=1,
max_workers=ARGS.rpc_workers,
)
with ms.Profiler() as profiler:
lib = ms.tune_relay(
(The remaining changed files in this commit are not shown.)