Skip to content

Commit

Permalink
[CONTRIB] Allow customized initializer in PopenPool (#8789)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuanjing Shi authored Aug 20, 2021
1 parent e691c7f commit d722c10
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 5 deletions.
34 changes: 31 additions & 3 deletions python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,22 @@ class PopenWorker:
PopenWorker provides a low-level
API to interact with a separate process via Popen.
Parameters
----------
initializer: callable or None
A callable initializer, or None
initargs: Tuple[object]
A tuple of args for the initializer
"""

def __init__(self):
def __init__(self, initializer=None, initargs=()):
self._proc = None
self._initializer = initializer
self._initargs = initargs
if self._initializer is not None and not callable(self._initializer):
raise TypeError("initializer must be callable for PopenWorker")

def __del__(self):
try:
Expand Down Expand Up @@ -203,6 +215,10 @@ def send(self, fn, args=(), kwargs=None, timeout=None):

if self._proc is None:
self._start()
# init
if self._initializer is not None:
self.send(self._initializer, self._initargs)
self.recv()
kwargs = {} if not kwargs else kwargs
data = cloudpickle.dumps((fn, args, kwargs, timeout), protocol=pickle.HIGHEST_PROTOCOL)
try:
Expand Down Expand Up @@ -269,21 +285,33 @@ class PopenPoolExecutor:
timeout : float
Timeout value for each function submit.
initializer: callable or None
A callable initializer, or None
initargs: Tuple[object]
A tuple of args for the initializer
Note
----
If max_workers is NONE then the number returned by
os.cpu_count() is used. This method aligns with the
behavior of multiprocessing.pool().
"""

def __init__(self, max_workers=None, timeout=None):
def __init__(self, max_workers=None, timeout=None, initializer=None, initargs=()):
if max_workers is None:
max_workers = os.cpu_count()
# Use an internal thread pool to send to popen workers
self._threadpool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self._timeout = timeout
self._worker_map = {}
self._lock = threading.Lock()
self._initializer = initializer
self._initargs = initargs

if self._initializer is not None and not callable(self._initializer):
raise TypeError("initializer must be callable for PopenPoolExecutor")

def __del__(self):
self._lock.acquire()
Expand All @@ -300,7 +328,7 @@ def _worker_run(self, fn, args, kwargs):
self._lock.acquire()
tid = threading.get_ident()
if tid not in self._worker_map:
proc = PopenWorker()
proc = PopenWorker(self._initializer, self._initargs)
self._worker_map[tid] = proc
else:
proc = self._worker_map[tid]
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count
from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback
from ._ffi_api import ErrorTest, FrontendTestModule
from ._ffi_api import ErrorTest, FrontendTestModule, identity_cpp

from .popen_pool import initializer, after_initializer, register_ffi, call_cpp_ffi
from .popen_pool import call_py_ffi, call_cpp_py_ffi

from . import auto_scheduler
59 changes: 59 additions & 0 deletions python/tvm/testing/popen_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, missing-function-docstring
"""Common functions for popen_pool test cases"""
import tvm

TEST_GLOBAL_STATE_1 = 0
TEST_GLOBAL_STATE_2 = 0
TEST_GLOBAL_STATE_3 = 0


def initializer(test_global_state_1, test_global_state_2, test_global_state_3):
global TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3
TEST_GLOBAL_STATE_1 = test_global_state_1
TEST_GLOBAL_STATE_2 = test_global_state_2
TEST_GLOBAL_STATE_3 = test_global_state_3


def after_initializer():
global TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3
return TEST_GLOBAL_STATE_1, TEST_GLOBAL_STATE_2, TEST_GLOBAL_STATE_3


@tvm._ffi.register_func("testing.identity_py")
def identity_py(arg):
return arg


def register_ffi():
@tvm._ffi.register_func("testing.nested_identity_py")
def _identity_py(arg): # pylint: disable=unused-variable
return arg


def call_py_ffi(arg):
_identity_py = tvm._ffi.get_global_func("testing.nested_identity_py")
return _identity_py(arg)


def call_cpp_ffi(arg):
return tvm.testing.echo(arg)


def call_cpp_py_ffi(arg):
return tvm.testing.identity_cpp(arg)
8 changes: 8 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ TVM_REGISTER_GLOBAL("testing.run_check_signal").set_body_typed([](int nsec) {
LOG(INFO) << "Function finished without catching signal";
});

TVM_REGISTER_GLOBAL("testing.identity_cpp").set_body([](TVMArgs args, TVMRetValue* ret) {
const auto* identity_func = tvm::runtime::Registry::Get("testing.identity_py");
ICHECK(identity_func != nullptr)
<< "AttributeError: \"testing.identity_py\" is not registered. Please check "
"if the python module is properly loaded";
*ret = (*identity_func)(args[0]);
});

// in src/api_test.cc
void ErrorTest(int x, int y) {
// raise ValueError
Expand Down
42 changes: 41 additions & 1 deletion tests/python/contrib/test_popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,16 @@
import pytest
import time
from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor
from tvm.testing import identity_after, terminate_self
from tvm.testing import (
identity_after,
terminate_self,
initializer,
after_initializer,
register_ffi,
call_py_ffi,
call_cpp_ffi,
call_cpp_py_ffi,
)


def test_popen_worker():
Expand Down Expand Up @@ -66,6 +75,37 @@ def test_popen_pool_executor():
assert val.value == idx


def test_popen_initializer():
initargs = [1, 2, 3]
proc = PopenWorker(initializer=initializer, initargs=initargs)
proc.send(after_initializer)
test_global_state_1, test_global_state_2, test_global_state_3 = proc.recv()
assert test_global_state_1 == initargs[0]
assert test_global_state_2 == initargs[1]
assert test_global_state_3 == initargs[2]


def test_popen_ffi():
proc = PopenWorker(register_ffi)

# call python function via ffi
initargs = [0]
proc.send(call_py_ffi, initargs)
assert proc.recv() == initargs[0]

# call cpp function via ffi
initargs = [1]
proc.send(call_cpp_ffi, initargs)
assert proc.recv() == initargs[0]

# call python function from cpp function via ffi
initargs = [2]
proc.send(call_cpp_py_ffi, initargs)
assert proc.recv() == initargs[0]


if __name__ == "__main__":
test_popen_worker()
test_popen_pool_executor()
test_popen_initializer()
test_popen_ffi()

0 comments on commit d722c10

Please sign in to comment.