diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 0b49b675d77d..52a7a898269a 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -365,9 +365,14 @@ def _popen_start_rpc_server( custom_addr=None, silent=False, no_fork=False, + server_init_callback=None, ): if no_fork: multiprocessing.set_start_method("spawn") + + if server_init_callback: + server_init_callback() + # This is a function that will be sent to the # Popen worker to run on a separate process. # Create and start the server in a different thread @@ -420,6 +425,25 @@ class Server(object): no_fork: bool, optional Whether forbid fork in multiprocessing. + + server_init_callback: Callable, optional + Additional initialization function when starting the server. + + Note + ---- + The RPC server only sees functions in the tvm namespace. + To bring additional custom functions to the server env, you can use server_init_callback. + + .. code:: python + + def server_init_callback(): + import tvm + # must import mypackage here + import mypackage + + tvm.register_func("function", mypackage.func) + + server = rpc.Server(host, server_init_callback=server_init_callback) """ def __init__( @@ -434,6 +458,7 @@ def __init__( custom_addr=None, silent=False, no_fork=False, + server_init_callback=None, ): try: if _ffi_api.ServerLoop is None: @@ -455,6 +480,7 @@ def __init__( custom_addr, silent, no_fork, + server_init_callback, ], ) # receive the port diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index b7a9c79392d2..dcf564dd0314 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -34,7 +34,6 @@ from ..libinfo import find_libvta -@tvm.register_func("tvm.rpc.server.start", override=True) def server_start(): """VTA RPC server extension.""" # pylint: disable=unused-variable @@ -148,8 +147,21 @@ def main(): else: tracker_addr = None + # register the initialization callback + def server_init_callback(): + # pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self + import tvm + import vta.exec.rpc_server + + tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True) + server = rpc.Server( - args.host, args.port, args.port_end, key=args.key, tracker_addr=tracker_addr + args.host, + args.port, + args.port_end, + key=args.key, + tracker_addr=tracker_addr, + server_init_callback=server_init_callback, ) server.proc.join()