From a5912013a6c76e6eda82c805c46354d85e308ace Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 31 Mar 2021 14:40:09 -0400 Subject: [PATCH] [PYTHON][RPC] Make rpc proxy jupyter friendly via PopenWorker. (#7757) * [PYTHON][RPC] Make rpc proxy jupyter friendly via PopenWorker. * Rework the contrib tests that were previously broken. --- python/tvm/contrib/popen_pool.py | 10 +- python/tvm/exec/popen_worker.py | 3 + python/tvm/exec/rpc_proxy.py | 24 +---- python/tvm/rpc/proxy.py | 125 +++++++++++++++++++------ tests/python/contrib/test_rpc_proxy.py | 17 +--- 5 files changed, 115 insertions(+), 64 deletions(-) diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index bca08622ac09..5a25484e9106 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -112,7 +112,10 @@ def kill(self): except IOError: pass # kill all child processes recurisvely - kill_child_processes(self._proc.pid) + try: + kill_child_processes(self._proc.pid) + except TypeError: + pass try: self._proc.kill() except OSError: @@ -149,6 +152,11 @@ def _start(self): self._reader = os.fdopen(main_read, "rb") self._writer = os.fdopen(main_write, "wb") + def join(self): + """Join the current process worker before it terminates""" + if self._proc: + self._proc.wait() + def send(self, fn, args=(), kwargs=None, timeout=None): """Send a new function task fn(*args, **kwargs) to the subprocess. 
diff --git a/python/tvm/exec/popen_worker.py b/python/tvm/exec/popen_worker.py index b62cca5cfce1..35571ac58de8 100644 --- a/python/tvm/exec/popen_worker.py +++ b/python/tvm/exec/popen_worker.py @@ -22,6 +22,7 @@ import threading import traceback import pickle +import logging import cloudpickle from tvm.contrib.popen_pool import StatusKind @@ -49,6 +50,8 @@ def main(): reader = os.fdopen(int(sys.argv[1]), "rb") writer = os.fdopen(int(sys.argv[2]), "wb") + logging.basicConfig(level=logging.INFO) + lock = threading.Lock() def _respond(ret_value): diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 26625fb14c46..bf315fdb9087 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -16,14 +16,10 @@ # under the License. # pylint: disable=redefined-outer-name, invalid-name """RPC web proxy, allows redirect to websocket based RPC servers(browsers)""" -from __future__ import absolute_import - import logging import argparse -import multiprocessing -import sys import os -from ..rpc.proxy import Proxy +from tvm.rpc.proxy import Proxy def find_example_resource(): @@ -82,24 +78,6 @@ def main(args): "--example-rpc", type=bool, default=False, help="Whether to switch on example rpc mode" ) parser.add_argument("--tracker", type=str, default="", help="Report to RPC tracker") - parser.add_argument( - "--no-fork", - dest="fork", - action="store_false", - help="Use spawn mode to avoid fork. This option \ - is able to avoid potential fork problems with Metal, OpenCL \ - and ROCM compilers.", - ) - parser.set_defaults(fork=True) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - if args.fork is False: - if sys.version_info[0] < 3: - raise RuntimeError("Python3 is required for spawn mode.") - multiprocessing.set_start_method("spawn") - else: - logging.info( - "If you are running ROCM/Metal, \ - fork with cause compiler internal error. 
Try to launch with arg ```--no-fork```" - ) main(args) diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index 2224d509e0d1..28117b09f280 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -22,12 +22,11 @@ the proxy server will forward the message between the client and server. """ # pylint: disable=unused-variable, unused-argument -from __future__ import absolute_import - import os +import asyncio import logging import socket -import multiprocessing +import threading import errno import struct import time @@ -43,6 +42,7 @@ "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg ) +from tvm.contrib.popen_pool import PopenWorker from . import _ffi_api from . import base from .base import TrackerCode @@ -261,6 +261,7 @@ def __init__( logging.info(pair) self.app = tornado.web.Application(handlers) self.app.listen(web_port) + self.sock = sock self.sock.setblocking(0) self.loop = ioloop.IOLoop.current() @@ -471,6 +472,7 @@ def _proxy_server( index_page, resource_files, ): + asyncio.set_event_loop(asyncio.new_event_loop()) handler = ProxyServerHandler( listen_sock, listen_port, @@ -484,6 +486,87 @@ def _proxy_server( handler.run() +class PopenProxyServerState(object): + """Internal PopenProxy State for Popen""" + + current = None + + def __init__( + self, + host, + port=9091, + port_end=9199, + web_port=0, + timeout_client=600, + timeout_server=600, + tracker_addr=None, + index_page=None, + resource_files=None, + ): + + sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) + self.port = None + for my_port in range(port, port_end): + try: + sock.bind((host, my_port)) + self.port = my_port + break + except socket.error as sock_err: + if sock_err.errno in [98, 48]: + continue + raise sock_err + if not self.port: + raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) + logging.info("RPCProxy: client port bind to %s:%d", host, self.port) + sock.listen(1) + self.thread = 
threading.Thread( + target=_proxy_server, + args=( + sock, + self.port, + web_port, + timeout_client, + timeout_server, + tracker_addr, + index_page, + resource_files, + ), + ) + # start the server in a different thread + # so we can return the port directly + self.thread.start() + + +def _popen_start_server( + host, + port=9091, + port_end=9199, + web_port=0, + timeout_client=600, + timeout_server=600, + tracker_addr=None, + index_page=None, + resource_files=None, +): + # This is a function that will be sent to the + # Popen worker to run on a separate process. + # Create and start the server in a different thread + state = PopenProxyServerState( + host, + port, + port_end, + web_port, + timeout_client, + timeout_server, + tracker_addr, + index_page, + resource_files, + ) + PopenProxyServerState.current = state + # returns the port so that the main can get the port number. + return state.port + + class Proxy(object): """Start RPC proxy server on a seperate process. @@ -532,43 +615,31 @@ def __init__( index_page=None, resource_files=None, ): - sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM) - self.port = None - for my_port in range(port, port_end): - try: - sock.bind((host, my_port)) - self.port = my_port - break - except socket.error as sock_err: - if sock_err.errno in [98, 48]: - continue - raise sock_err - if not self.port: - raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end)) - logging.info("RPCProxy: client port bind to %s:%d", host, self.port) - sock.listen(1) - self.proc = multiprocessing.Process( - target=_proxy_server, - args=( - sock, - self.port, + self.proc = PopenWorker() + # send the function + self.proc.send( + _popen_start_server, + [ + host, + port, + port_end, web_port, timeout_client, timeout_server, tracker_addr, index_page, resource_files, - ), + ], ) - self.proc.start() - sock.close() + # receive the port + self.port = self.proc.recv() self.host = host def terminate(self): """Terminate the 
server process""" if self.proc: logging.info("Terminating Proxy Server...") - self.proc.terminate() + self.proc.kill() self.proc = None def __del__(self): diff --git a/tests/python/contrib/test_rpc_proxy.py b/tests/python/contrib/test_rpc_proxy.py index 26d183185ae6..08da29b0af7b 100644 --- a/tests/python/contrib/test_rpc_proxy.py +++ b/tests/python/contrib/test_rpc_proxy.py @@ -38,20 +38,12 @@ def rpc_proxy_check(): from tvm.rpc import proxy web_port = 8888 - prox = proxy.Proxy("localhost", web_port=web_port) + prox = proxy.Proxy("127.0.0.1", web_port=web_port) def check(): if not tvm.runtime.enabled("rpc"): return - @tvm.register_func("rpc.test2.addone") - def addone(x): - return x + 1 - - @tvm.register_func("rpc.test2.strcat") - def addone(name, x): - return "%s:%d" % (name, x) - server = multiprocessing.Process( target=proxy.websocket_proxy_server, args=("ws://localhost:%d/ws" % web_port, "x1") ) @@ -60,10 +52,9 @@ def addone(name, x): server.deamon = True server.start() client = rpc.connect(prox.host, prox.port, key="x1") - f1 = client.get_function("rpc.test2.addone") - assert f1(10) == 11 - f2 = client.get_function("rpc.test2.strcat") - assert f2("abc", 11) == "abc:11" + f1 = client.get_function("testing.echo") + assert f1(10) == 10 + assert f1("xyz") == "xyz" check() except ImportError: