diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f..ffbe508816403 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -16,6 +16,8 @@ import torch.distributed as dist import torch.multiprocessing as mp from torchmetrics import Metric +import numpy +import os import tests.helpers.utils as tutils from pytorch_lightning.core.step_result import Result @@ -96,6 +98,7 @@ def test_result_reduce_ddp(): tutils.set_random_master_port() worldsize = 2 + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize)