Skip to content

Commit

Permalink
test_model_gpu: Use TF memory pool if available, feature-gate test (#…
Browse files Browse the repository at this point in the history
…688)

* `test_model_gpu`: Use TF memory pool if available, feature-gate test

* Fix typo

* `test_predict_extensive`: Disable test time monitoring

* Fix imports, use `has_cupy_gpu` for forward-compat

* `conftest`: Use `pytest_sessionstart` to enable TF GPU memory growth
  • Loading branch information
shadeMe authored Jun 8, 2022
1 parent 46334b5 commit 1c6e9f4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
18 changes: 18 additions & 0 deletions thinc/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
import pytest


def pytest_sessionstart(session):
# If Tensorflow is installed, attempt to enable memory growth
# to prevent it from allocating all of the GPU's free memory
# to its internal memory pool(s).
try:
import tensorflow as tf

physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
try:
tf.config.experimental.set_memory_growth(device, True)
except:
# Invalid device or cannot modify virtual devices once initialized.
print(f"failed to enable Tensorflow memory growth on {device}")
except ImportError:
pass


def pytest_addoption(parser):
try:
parser.addoption("--slow", action="store_true", help="include slow tests")
Expand Down
1 change: 1 addition & 0 deletions thinc/tests/layers/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_predict_small(W_b_input):


@given(arrays_OI_O_BI(max_batch=20, max_out=30, max_in=30))
@settings(deadline=None)
def test_predict_extensive(W_b_input):
W, b, input_ = W_b_input
nr_out, nr_in = W.shape
Expand Down
7 changes: 2 additions & 5 deletions thinc/tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,12 @@ def get_model_id(id_list, index):
assert len(list_of_ids) == len(list(set(list_of_ids)))


@pytest.mark.skipif(not has_cupy_gpu, reason="needs CuPy GPU")
def test_model_gpu():
pytest.importorskip("ml_datasets")
import ml_datasets

ops = "cpu"
if has_cupy_gpu:
ops = "cupy"

with use_ops(ops):
with use_ops("cupy"):
n_hidden = 32
dropout = 0.2
(train_X, train_Y), (dev_X, dev_Y) = ml_datasets.mnist()
Expand Down

0 comments on commit 1c6e9f4

Please sign in to comment.