From d0d4ed62b591241319c8e6f74e84cc261a00536f Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sat, 1 Jul 2023 01:03:31 +0100 Subject: [PATCH] :white_check_mark: Reduce run time for Mapde test (#627) - Reduce run time for Mapde test. --- tests/models/test_arch_mapde.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 49a144343..e4b8d30a7 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -2,17 +2,20 @@ import numpy as np import torch -from tiatoolbox import utils from tiatoolbox.models import MapDe from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import WSIReader +ON_GPU = toolbox_env.has_gpu() + def _load_mapde(tmp_path, name): """Loads MapDe model with specified weights.""" model = MapDe() fetch_pretrained_weights(name, f"{tmp_path}/weights.pth") - map_location = utils.misc.select_device(utils.env_detection.has_gpu()) + map_location = select_device(ON_GPU) pretrained = torch.load(f"{tmp_path}/weights.pth", map_location=map_location) model.load_state_dict(pretrained) @@ -34,14 +37,10 @@ def test_functionality(remote_sample, tmp_path): (0, 0, 252, 252), resolution=0.50, units="mpp", coord_space="resolution" ) - model = _load_mapde(tmp_path=tmp_path, name="mapde-crchisto") + model = _load_mapde(tmp_path=tmp_path, name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - output = model.infer_batch(model, batch, on_gpu=False) - output = model.postproc(output[0]) - assert np.all(output[0:2] == [[99, 178], [64, 218]]) - - model = _load_mapde(tmp_path=tmp_path, name="mapde-conic") - output = model.infer_batch(model, batch, on_gpu=False) + model = model.to(select_device(ON_GPU)) + output = model.infer_batch(model, batch, on_gpu=ON_GPU) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]])