[Tuning] Allow multiprocessing spawn to work (on macOS llvm at least) (#8363)

* go to callable class

* add some documentation and naming

* extend comment

* manually do logic to avoid bug with pointer comparison

* revert changes to light change, correct comment

* more principled change, but also kind of hacky

* test other tuning methods

* remove check

* jostle CI
AndrewZhaoLuo authored Jul 1, 2021
1 parent c989e4a commit 578f617
Showing 5 changed files with 109 additions and 93 deletions.
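
The crux of the fix: under the `spawn` (and `forkserver`) start methods, `multiprocessing` must pickle any callable it ships to worker processes, and Python cannot pickle a closure returned by a factory function — while an instance of a module-level callable class pickles fine. A minimal standalone sketch of that difference (illustrative names, not code from this commit):

```python
import pickle


def make_closure_loader(pre_load_function=None):
    """Pre-refactor shape: a factory returning a nested function closure."""

    def loader(remote_kwargs, build_result):
        return pre_load_function

    return loader


class CallableLoader:
    """Post-refactor shape: a module-level class whose instances pickle cleanly."""

    def __init__(self, pre_load_function=None):
        self.pre_load_function = pre_load_function

    def __call__(self, remote_kwargs, build_result):
        return self.pre_load_function


try:
    pickle.dumps(make_closure_loader())
except (pickle.PicklingError, AttributeError) as err:
    # Local (nested) functions cannot be pickled, so spawn-based
    # multiprocessing cannot send them to worker processes.
    print("closure failed to pickle:", err)

pickle.dumps(CallableLoader())  # succeeds: pickled as class reference + __dict__
```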
63 changes: 36 additions & 27 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -24,29 +24,29 @@
 
 import contextlib
 import logging
-import shutil
 import os
+import shutil
+import tempfile
 import threading
 import time
 import typing
-from random import getrandbits
 from collections import namedtuple
-import tempfile
+from random import getrandbits
 
 import tvm._ffi
 import tvm.ir.transform
-from tvm import nd, rpc as _rpc
-from tvm.error import TVMError
+from tvm import nd
+from tvm import rpc as _rpc
+from tvm.contrib import ndk, nvcc, stackvm, tar
 from tvm.driver import build
-from tvm.contrib import nvcc, ndk, tar, stackvm
+from tvm.error import TVMError
 from tvm.target import Target
 
-from ..utils import get_const_tuple
 from ..env import AutotvmGlobalScope
 from ..task.space import InstantiationError
-
-from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
+from ..utils import get_const_tuple
 from .local_executor import LocalExecutor
+from .measure import Builder, MeasureErrorNo, MeasureResult, Runner
 
 logger = logging.getLogger("autotvm")
 
@@ -393,8 +393,8 @@ def __init__(
 
     def set_task(self, task):
        # pylint: disable=import-outside-toplevel
-        from ...rpc.tracker import Tracker
         from ...rpc.server import Server
+        from ...rpc.tracker import Tracker
 
         self.task = task
         tracker = Tracker(port=9000, port_end=10000, silent=True)
@@ -605,26 +605,17 @@ def run_through_rpc(
     return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
 
 
-def default_module_loader(pre_load_function=None):
-    """Returns a default function that can be passed as module_loader to run_through_rpc.
-    Parameters
-    ----------
-    pre_load_function : Optional[Function[tvm.rpc.Session, tvm.runtime.Module]]
-        Invoked after a session is established and before the default code-loading RPC calls are
-        issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment.
-    Returns
-    -------
-    ModuleLoader :
-        A function that can be passed as module_loader to run_through_rpc.
-    """
+class DefaultModuleLoader:
+    """See default_module_loader(). A pickleable emulation of the original function closure."""
+
+    def __init__(self, pre_load_function=None) -> None:
+        self.pre_load_function = pre_load_function
 
     @contextlib.contextmanager
-    def default_module_loader_mgr(remote_kwargs, build_result):
+    def __call__(self, remote_kwargs, build_result):
         remote = request_remote(**remote_kwargs)
-        if pre_load_function is not None:
-            pre_load_function(remote, build_result)
+        if self.pre_load_function is not None:
+            self.pre_load_function(remote, build_result)
 
         remote.upload(build_result.filename)
         try:
@@ -636,7 +627,25 @@ def default_module_loader_mgr(remote_kwargs, build_result):
             remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
             remote.remove("")
 
-    return default_module_loader_mgr
+
+def default_module_loader(pre_load_function=None):
+    """Returns a default function that can be passed as module_loader to run_through_rpc.
+    Parameters
+    ----------
+    pre_load_function : Optional[Function[tvm.rpc.Session, tvm.runtime.Module]]
+        Invoked after a session is established and before the default code-loading RPC calls are
+        issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment.
+    Returns
+    -------
+    DefaultModuleLoader :
+        A callable that can be passed as module_loader to run_through_rpc.
+    """
+
+    # This was a function with a closure before but that couldn't be pickled!
+    # We need pickle to work for using python's multiprocessing on some platforms.
+    return DefaultModuleLoader(pre_load_function)
 
 
 def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
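
A quick sanity sketch of the payoff (assumes a TVM checkout containing this commit is importable; not part of the diff): the returned loader now survives a pickle round-trip, which is exactly what the `spawn` start method needs.

```python
import pickle

from tvm.autotvm.measure.measure_methods import default_module_loader

# The refactor's point: DefaultModuleLoader is a module-level class, so its
# instances can be pickled and shipped to spawn-based worker processes.
loader = default_module_loader()
restored = pickle.loads(pickle.dumps(loader))
assert restored.pre_load_function is None
```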
7 changes: 4 additions & 3 deletions python/tvm/target/target.py
@@ -15,14 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 """Target data structure."""
+import json
 import os
 import re
-import json
 import warnings
-import tvm._ffi
 
-from tvm.runtime import Object
+import tvm._ffi
 from tvm._ffi import register_func as _register_func
+from tvm.runtime import Object
+
 from . import _ffi_api
 
 
3 changes: 0 additions & 3 deletions src/target/target.cc
@@ -418,9 +418,6 @@ Target::Target(const Map<String, ObjectRef>& config) {
 
 Target::Target(Target target, Target host) {
   ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
-  CHECK(!n->host.defined() || n->host.same_as(host))
-      << "ValueError: Adding a host to a target whose host field has been defined";
-  // add target host into host field
   n->host = std::move(host);
   data_ = std::move(n);
 }
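
Removing this CHECK changes observable Python-level behavior: re-specifying an identical host is now a no-op merge rather than an error, as the updated test_target_host_merge_2 further down exercises. Roughly:

```python
import tvm

# Previously this raised "Adding a host to a target whose host field has
# been defined"; after this commit the duplicate host is simply merged.
tgt = tvm.target.Target(tvm.target.Target("cuda --host llvm"), tvm.target.Target("llvm"))
assert tgt.kind.name == "cuda"
assert tgt.host.kind.name == "llvm"
```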
107 changes: 63 additions & 44 deletions tests/python/integration/test_tuning.py
@@ -18,22 +18,19 @@
 Test the tuner
 """
 import logging
+import multiprocessing as mp
 import sys
 import textwrap
 import time
 
 import pytest
 
 import tvm
 import tvm.relay
-from tvm import te
-
-from tvm import autotvm
+import tvm.testing
+from tvm import autotvm, te
 from tvm.autotvm.tuner import RandomTuner
 from tvm.target import Target
 
-import tvm.testing
-
 
 
 def setup_module():
@@ -140,62 +137,84 @@ def get_sample_task(target=tvm.target.cuda(), target_host=None):
     return task, target
 
 
+def run_test_with_all_multiprocessing(func, *args, **kwargs):
+    """Check all multiprocessing methods work for the tuning test.
+    In the past fork() had the most support at detriment to spawn() and forkserver().
+    As fork() is unavailable or unsafe on some platforms it is good to check all
+    available methods.
+    """
+    for multiprocessing_method in mp.get_all_start_methods():
+        old_start_method = mp.get_start_method()
+        try:
+            mp.set_start_method(multiprocessing_method, force=True)
+            func(*args, **kwargs)
+        finally:
+            mp.set_start_method(old_start_method, force=True)
+
+
 @tvm.testing.parametrize_targets("cuda", "opencl")
 def test_tuning_gpu(target, dev):
-    # init task
-    task, target = get_sample_task(target, None)
-    logging.info("task config space: %s", task.config_space)
+    def runner(target, dev):
+        # init task
+        task, target = get_sample_task(target, None)
+        logging.info("task config space: %s", task.config_space)
 
-    measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
+        measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
 
-    results = []
+        results = []
 
-    tuner = RandomTuner(task)
-    tuner.tune(
-        n_trial=20,
-        measure_option=measure_option,
-        callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
-    )
+        tuner = RandomTuner(task)
+        tuner.tune(
+            n_trial=20,
+            measure_option=measure_option,
+            callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+        )
 
-    assert len(results) == 20
+        assert len(results) == 20
 
-    successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
-    assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+        successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
+        assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+
+    run_test_with_all_multiprocessing(runner, target, dev)
 
 
 def test_tuning_cpu():
-    ir_mod = tvm.parser.fromtext(
-        textwrap.dedent(
-            """
-            #[version = "0.0.5"]
-            def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : Tensor[(3, 3, 5, 5), float32]) {
-                nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW")
-            }
-            """
-        )
-    )
-    tasks = autotvm.task.relay_integration.extract_from_program(
-        ir_mod, {}, tvm.target.create("llvm")
-    )
-    assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}"
+    def runner():
+        ir_mod = tvm.parser.fromtext(
+            textwrap.dedent(
+                """
+                #[version = "0.0.5"]
+                def @main(%a : Tensor[(1, 3, 32, 32), float32], %b : Tensor[(3, 3, 5, 5), float32]) {
+                    nn.conv2d(%a, %b, data_layout="NCHW", kernel_layout="OIHW")
+                }
+                """
+            )
+        )
+        tasks = autotvm.task.relay_integration.extract_from_program(
+            ir_mod, {}, tvm.target.create("llvm")
+        )
+        assert len(tasks) == 1, f"Extracted != 1 task from program: {tasks!r}"
 
-    task = tasks[0]
+        task = tasks[0]
 
-    measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
+        measure_option = autotvm.measure_option(autotvm.LocalBuilder(), autotvm.LocalRunner())
 
-    results = []
+        results = []
 
-    tuner = RandomTuner(task)
-    tuner.tune(
-        n_trial=20,
-        measure_option=measure_option,
-        callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
-    )
+        tuner = RandomTuner(task)
+        tuner.tune(
+            n_trial=20,
+            measure_option=measure_option,
+            callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+        )
 
-    assert len(results) == 20
+        assert len(results) == 20
 
-    successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
-    assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+        successful_results = [r for r in results if r.error_no == autotvm.MeasureErrorNo.NO_ERROR]
+        assert len(successful_results) > 0, f"No successful tuning runs: {results!r}"
+
+    run_test_with_all_multiprocessing(runner)
 
 
 if __name__ == "__main__":
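
Why the new helper iterates at all: the set of available start methods, and the default, vary by platform. A standalone check, independent of this commit:

```python
import multiprocessing as mp

# Linux typically reports ["fork", "spawn", "forkserver"]; macOS lists all
# three but has defaulted to "spawn" since Python 3.8 (fork is unsafe with
# threads and some system frameworks there); Windows supports only "spawn".
print(mp.get_all_start_methods())
print(mp.get_start_method())
```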
22 changes: 6 additions & 16 deletions tests/python/unittest/test_target_target.py
@@ -16,9 +16,10 @@
 # under the License.
 import json
 import sys
+
 import pytest
 import tvm
-from tvm.target import cuda, rocm, mali, intel_graphics, arm_cpu, vta, bifrost, Target
+from tvm.target import Target, arm_cpu, bifrost, cuda, intel_graphics, mali, rocm, vta
 
 
 @tvm.target.generic_func
@@ -210,17 +211,6 @@ def test_target_host_single_string_with_tag():
     assert tgt.host.attrs["registers_per_block"] == 32768
 
 
-def test_target_host_warning():
-    """
-    Confirm that constructing a target with invalid
-    attributes fails as expected.
-    """
-    with pytest.raises(
-        ValueError, match="Adding a host to a target whose host field has been defined"
-    ):
-        tvm.target.Target("cuda --host nvidia/jetson-nano", "llvm")
-
-
 def test_target_host_merge_0():
     tgt = tvm.target.Target(tvm.target.Target("cuda --host nvidia/jetson-nano"), None)
     assert tgt.kind.name == "cuda"
@@ -240,10 +230,10 @@ def test_target_host_merge_1():
 
 
 def test_target_host_merge_2():
-    with pytest.raises(
-        ValueError, match="Adding a host to a target whose host field has been defined"
-    ):
-        tvm.target.Target(tvm.target.Target("cuda --host llvm"), tvm.target.Target("llvm"))
+    """Test picking the same host is ok."""
+    tgt = tvm.target.Target(tvm.target.Target("cuda --host llvm"), tvm.target.Target("llvm"))
+    assert tgt.kind.name == "cuda"
+    assert tgt.host.kind.name == "llvm"
 
 
 @pytest.mark.skip(reason="Causing infinite loop because of pytest and handle issue")
