diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index 84dff157aa503..e77721496386e 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -560,7 +560,6 @@ def __init__( port=self.tracker.port, port_end=10000, key=device_key, - use_popen=True, silent=True, tracker_addr=(self.tracker.host, self.tracker.port), ) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index d212e5f26f20a..f328a06c079a2 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -404,7 +404,6 @@ def set_task(self, task): port=9000, port_end=10000, key=device_key, - use_popen=True, silent=True, tracker_addr=(tracker.host, tracker.port), ) diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index 5a25484e91061..ecda995c7162c 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -44,10 +44,11 @@ def kill_child_processes(pid): try: parent = psutil.Process(pid) + children = parent.children(recursive=True) except psutil.NoSuchProcess: return - for process in parent.children(recursive=True): + for process in children: try: process.kill() except psutil.NoSuchProcess: diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 9692b98fe22b0..6b3e93edd2239 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -16,11 +16,7 @@ # under the License. # pylint: disable=redefined-outer-name, invalid-name """Start an RPC server""" -from __future__ import absolute_import - import argparse -import multiprocessing -import sys import logging from .. import rpc @@ -51,6 +47,7 @@ def main(args): load_library=args.load_library, custom_addr=args.custom_addr, silent=args.silent, + no_fork=not args.fork, ) server.proc.join() @@ -85,14 +82,9 @@ def main(args): parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - if args.fork is False: - if sys.version_info[0] < 3: - raise RuntimeError("Python3 is required for spawn mode.") - multiprocessing.set_start_method("spawn") - else: - if not args.silent: - logging.info( - "If you are running ROCM/Metal, fork will cause " - "compiler internal error. Try to launch with arg ```--no-fork```" - ) + if not args.fork is False and not args.silent: + logging.info( + "If you are running ROCM/Metal, fork will cause " + "compiler internal error. Try to launch with arg ```--no-fork```" + ) main(args) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index 28117b09f280d..7e02bd77c491d 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -537,7 +537,7 @@ def __init__( self.thread.start() -def _popen_start_server( +def _popen_start_proxy_server( host, port=9091, port_end=9199, @@ -570,7 +570,7 @@ def _popen_start_server( class Proxy(object): """Start RPC proxy server on a seperate process. - Python implementation based on multi-processing. + Python implementation based on PopenWorker. Parameters ---------- @@ -618,7 +618,7 @@ def __init__( self.proc = PopenWorker() # send the function self.proc.send( - _popen_start_server, + _popen_start_proxy_server, [ host, port, diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 7861542531335..3fd6996034f72 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -31,20 +31,21 @@ import select import struct import logging +import threading import multiprocessing -import subprocess import time -import sys -import signal -import platform import tvm._ffi from tvm._ffi.base import py_str from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import utils +from tvm.contrib.popen_pool import PopenWorker from . import _ffi_api from . import base + +# pylint: disable=unused-import +from . import testing from .base import TrackerCode logger = logging.getLogger("RPCServer") @@ -296,13 +297,85 @@ def _connect_proxy_loop(addr, key, load_library): time.sleep(retry_period) -def _popen(cmd): - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=os.environ) - (out, _) = proc.communicate() - if proc.returncode != 0: - msg = "Server invoke error:\n" - msg += out - raise RuntimeError(msg) +class PopenRPCServerState(object): + """Internal PopenRPCServer State""" + + current = None + + def __init__( + self, + host, + port=9091, + port_end=9199, + is_proxy=False, + tracker_addr=None, + key="", + load_library=None, + custom_addr=None, + silent=False, + ): + + # start update + self.host = host + self.port = port + self.libs = [] + self.custom_addr = custom_addr + + if silent: + logger.setLevel(logging.ERROR) + + if not is_proxy: + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + if sock_err.errno in [98, 48]: + continue + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logger.info("bind to %s:%d", host, self.port) + sock.listen(1) + self.sock = sock + self.thread = threading.Thread( + target=_listen_loop, + args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr), + ) + self.thread.start() + else: + self.thread = threading.Thread( + target=_connect_proxy_loop, args=((host, port), key, load_library) + ) + self.thread.start() + + +def _popen_start_rpc_server( + host, + port=9091, + port_end=9199, + is_proxy=False, + tracker_addr=None, + key="", + load_library=None, + custom_addr=None, + silent=False, + no_fork=False, +): + if no_fork: + multiprocessing.set_start_method("spawn") + # This is a function that will be sent to the + # Popen worker to run on a separate process. + # Create and start the server in a different thread + state = PopenRPCServerState( + host, port, port_end, is_proxy, tracker_addr, key, load_library, custom_addr, silent + ) + PopenRPCServerState.current = state + # returns the port so that the main can get the port number. + return state.port class Server(object): @@ -328,11 +401,6 @@ class Server(object): If this is true, the host and port actually corresponds to the address of the proxy server. - use_popen : bool, optional - Whether to use Popen to start a fresh new process instead of fork. - This is recommended to switch on if we want to do local RPC demonstration - for GPU devices to avoid fork safety issues. - tracker_addr: Tuple (str, int) , optional The address of RPC Tracker in tuple(host, ip) format. If is not None, the server will register itself to the tracker. @@ -348,6 +416,9 @@ class Server(object): silent: bool, optional Whether run this server in silent mode. + + no_fork: bool, optional + Whether forbid fork in multiprocessing. """ def __init__( @@ -356,101 +427,44 @@ def __init__( port=9091, port_end=9199, is_proxy=False, - use_popen=False, tracker_addr=None, key="", load_library=None, custom_addr=None, silent=False, + no_fork=False, ): try: if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") + self.proc = PopenWorker() + # send the function + self.proc.send( + _popen_start_rpc_server, + [ + host, + port, + port_end, + is_proxy, + tracker_addr, + key, + load_library, + custom_addr, + silent, + no_fork, + ], + ) + # receive the port + self.port = self.proc.recv() self.host = host - self.port = port - self.libs = [] - self.custom_addr = custom_addr - self.use_popen = use_popen - - if silent: - logger.setLevel(logging.ERROR) - - if use_popen: - cmd = [ - sys.executable, - "-m", - "tvm.exec.rpc_server", - "--host=%s" % host, - "--port=%s" % port, - "--port-end=%s" % port_end, - ] - if tracker_addr: - assert key - cmd += ["--tracker=%s:%d" % tracker_addr, "--key=%s" % key] - if load_library: - cmd += ["--load-library", load_library] - if custom_addr: - cmd += ["--custom-addr", custom_addr] - if silent: - cmd += ["--silent"] - - # prexec_fn is not thread safe and may result in deadlock. - # python 3.2 introduced the start_new_session parameter as - # an alternative to the common use case of - # prexec_fn=os.setsid. Once the minimum version of python - # supported by TVM reaches python 3.2 this code can be - # rewritten in favour of start_new_session. In the - # interim, stop the pylint diagnostic. - # - # pylint: disable=subprocess-popen-preexec-fn - if platform.system() == "Windows": - self.proc = subprocess.Popen(cmd, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP) - else: - self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid) - time.sleep(0.5) - elif not is_proxy: - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - if sock_err.errno in [98, 48]: - continue - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logger.info("bind to %s:%d", host, self.port) - sock.listen(1) - self.sock = sock - self.proc = multiprocessing.Process( - target=_listen_loop, - args=(self.sock, self.port, key, tracker_addr, load_library, self.custom_addr), - ) - self.proc.start() - else: - self.proc = multiprocessing.Process( - target=_connect_proxy_loop, args=((host, port), key, load_library) - ) - self.proc.start() def terminate(self): """Terminate the server process""" - if self.use_popen: - if self.proc: - if platform.system() == "Windows": - os.kill(self.proc.pid, signal.CTRL_C_EVENT) - else: - os.killpg(self.proc.pid, signal.SIGTERM) - self.proc = None - else: - if self.proc: - self.proc.terminate() - self.proc = None + if self.proc: + self.proc.kill() + self.proc = None def __del__(self): self.terminate() diff --git a/python/tvm/rpc/testing.py b/python/tvm/rpc/testing.py new file mode 100644 index 0000000000000..b7acc74c413a3 --- /dev/null +++ b/python/tvm/rpc/testing.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name,unnecessary-comprehension +""" Testing functions for the RPC server.""" +import numpy as np +import tvm + + +# RPC test functions to be registered for unit-tests purposes +@tvm.register_func("rpc.test.addone") +def _addone(x): + return x + 1 + + +@tvm.register_func("rpc.test.strcat") +def _strcat(name, x): + return "%s:%d" % (name, x) + + +@tvm.register_func("rpc.test.except") +def _remotethrow(name): + raise ValueError("%s" % name) + + +@tvm.register_func("rpc.test.runtime_str_concat") +def _strcat(x, y): + return x + y + + +@tvm.register_func("rpc.test.remote_array_func") +def _remote_array_func(y): + x = np.ones((3, 4)) + np.testing.assert_equal(y.asnumpy(), x) + + +@tvm.register_func("rpc.test.add_to_lhs") +def _add_to_lhs(x): + return lambda y: x + y + + +@tvm.register_func("rpc.test.remote_return_nd") +def _my_module(name): + # Use closure to check the ref counter correctness + nd = tvm.nd.array(np.zeros(10).astype("float32")) + + if name == "get_arr": + return lambda: nd + if name == "ref_count": + return lambda: tvm.testing.object_use_count(nd) + if name == "get_elem": + return lambda idx: nd.asnumpy()[idx] + if name == "get_arr_elem": + return lambda arr, idx: arr.asnumpy()[idx] + raise RuntimeError("unknown name") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 58985832fb359..7e790494125ef 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -827,8 +827,7 @@ def test_vm_rpc(): # Use local rpc server for testing. # Server must use popen so it doesn't inherit the current process state. It # will crash otherwise. - server = rpc.Server("localhost", port=9120, use_popen=True) - time.sleep(2) + server = rpc.Server("localhost", port=9120) remote = rpc.connect(server.host, server.port, session_timeout=10) # Upload the serialized Executable. diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 766338de35580..9658ce1b2c1e7 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -262,7 +262,7 @@ def verify_rpc_gpu_export(obj_format): from tvm import rpc - server = rpc.Server("localhost", use_popen=True, port=9094) + server = rpc.Server("localhost", port=9094) remote = rpc.connect(server.host, server.port) remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 256fd33387bf7..a74f893065b8d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -85,21 +85,6 @@ def verify_rpc(remote, target, shape, dtype): verify_rpc(remote, target, (10,), dtype) -@tvm.register_func("rpc.test.addone") -def addone(x): - return x + 1 - - -@tvm.register_func("rpc.test.strcat") -def strcat(name, x): - return "%s:%d" % (name, x) - - -@tvm.register_func("rpc.test.except") -def remotethrow(name): - raise ValueError("%s" % name) - - @tvm.testing.requires_rpc def test_rpc_simple(): server = rpc.Server("localhost", key="x1") @@ -115,11 +100,6 @@ def test_rpc_simple(): assert f2("abc", 11) == "abc:11" -@tvm.register_func("rpc.test.runtime_str_concat") -def strcat(x, y): - return x + y - - @tvm.testing.requires_rpc def test_rpc_runtime_string(): server = rpc.Server("localhost", key="x1") @@ -130,12 +110,6 @@ def test_rpc_runtime_string(): assert str(func(x, y)) == "abcdef" -@tvm.register_func("rpc.test.remote_array_func") -def remote_array_func(y): - x = np.ones((3, 4)) - np.testing.assert_equal(y.asnumpy(), x) - - @tvm.testing.requires_rpc def test_rpc_array(): x = np.ones((3, 4)) @@ -342,16 +316,11 @@ def check_remote_link_cl(remote): check_minrpc() -@tvm.register_func("rpc.test.remote_func") -def addone(x): - return lambda y: x + y - - @tvm.testing.requires_rpc def test_rpc_return_func(): server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") - f1 = client.get_function("rpc.test.remote_func") + f1 = client.get_function("rpc.test.add_to_lhs") fadd = f1(10) assert fadd(12) == 22 @@ -393,21 +362,6 @@ def check_error_handling(): check_error_handling() -@tvm.register_func("rpc.test.remote_return_nd") -def my_module(name): - # Use closure to check the ref counter correctness - nd = tvm.nd.array(np.zeros(10).astype("float32")) - - if name == "get_arr": - return lambda: nd - elif name == "ref_count": - return lambda: tvm.testing.object_use_count(nd) - elif name == "get_elem": - return lambda idx: nd.asnumpy()[idx] - elif name == "get_arr_elem": - return lambda arr, idx: arr.asnumpy()[idx] - - @tvm.testing.requires_rpc def test_rpc_return_ndarray(): # start server @@ -428,15 +382,10 @@ def run_arr_test(): run_arr_test() -@tvm.register_func("rpc.test.remote_func2") -def addone(x): - return lambda y: x + y - - @tvm.testing.requires_rpc def test_local_func(): client = rpc.LocalSession() - f1 = client.get_function("rpc.test.remote_func2") + f1 = client.get_function("rpc.test.add_to_lhs") fadd = f1(10) assert fadd(12) == 22 @@ -458,8 +407,8 @@ def test_rpc_tracker_register(): key=device_key, tracker_addr=(tracker.host, tracker.port), ) - time.sleep(1) client = rpc.connect_tracker(tracker.host, tracker.port) + time.sleep(1) summary = client.summary() assert summary["queue_info"][device_key]["free"] == 1