From 83c291dbc8bbc29c8666dc072ca1e1108cfb752c Mon Sep 17 00:00:00 2001 From: Shubham Agarwal Date: Mon, 9 Mar 2020 09:30:57 +0000 Subject: [PATCH] check if root gpu exists or available --- pytorch_lightning/trainer/distrib_parts.py | 3 ++- pytorch_lightning/trainer/evaluation_loop.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index c318df9a2863b..39cdf88b1f600 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -643,7 +643,8 @@ def determine_root_gpu_device(gpus): # set cuda device to root gpu # related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958 # Refer solution: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190 - root_device = torch.device("cuda", root_gpu) + # root_device = torch.device("cuda", root_gpu) + root_device = (torch.device("cuda", root_gpu) if root_gpu >= 0 else torch.device("cpu")) torch.cuda.set_device(root_device) return root_gpu diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index da0c16d0e4f52..91037753b7637 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -428,7 +428,8 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: # set cuda device to root gpu # related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958 # Refer: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190 - root_device = torch.device("cuda", root_gpu) + root_device = (torch.device("cuda", root_gpu) + if root_gpu >= 0 else torch.device("cpu")) torch.cuda.set_device(root_device) else: raise RuntimeError(