diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index c318df9a2863b..39cdf88b1f600 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -643,7 +643,8 @@ def determine_root_gpu_device(gpus):
     # set cuda device to root gpu
     # related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958
     # Refer solution: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190
-    root_device = torch.device("cuda", root_gpu)
+    # root_device = torch.device("cuda", root_gpu)
+    root_device = (torch.device("cuda", root_gpu) if root_gpu >= 0 else torch.device("cpu"))
     torch.cuda.set_device(root_device)
 
     return root_gpu
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index da0c16d0e4f52..91037753b7637 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -428,7 +428,8 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
             # set cuda device to root gpu
             # related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958
             # Refer: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190
-            root_device = torch.device("cuda", root_gpu)
+            root_device = (torch.device("cuda", root_gpu)
+                           if root_gpu >= 0 else torch.device("cpu"))
            torch.cuda.set_device(root_device)
        else:
            raise RuntimeError(
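
For reference, below is a minimal standalone sketch of the device-selection fallback both hunks introduce, assuming the convention that a negative root_gpu (e.g. -1) means "no usable GPU, run on CPU". The helper name select_root_device is hypothetical and not part of the Lightning API; it only illustrates the pattern.

    import torch

    def select_root_device(root_gpu: int) -> torch.device:
        # Hypothetical helper mirroring the patch: fall back to CPU when no
        # valid GPU index is given, otherwise pin this process to the
        # requested CUDA device (the set_device call mirrors the workaround
        # referenced in pytorch/pytorch#9871).
        if root_gpu >= 0:
            device = torch.device("cuda", root_gpu)
            torch.cuda.set_device(device)  # requires a CUDA-enabled build
        else:
            device = torch.device("cpu")
        return device

    # Usage: select_root_device(0) pins CUDA device 0;
    # select_root_device(-1) returns the CPU device.

Note that the sketch calls torch.cuda.set_device only on the GPU branch, so the CPU fallback never touches the CUDA API.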