Skip to content

Commit

Permalink
check if root gpu exists or available
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamagarwal92 committed Mar 9, 2020
1 parent 54e9a5e commit 83c291d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,8 @@ def determine_root_gpu_device(gpus):
# set cuda device to root gpu
# related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958
# Refer solution: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190
root_device = torch.device("cuda", root_gpu)
# root_device = torch.device("cuda", root_gpu)
root_device = (torch.device("cuda", root_gpu) if root_gpu >= 0 else torch.device("cpu"))
torch.cuda.set_device(root_device)

return root_gpu
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
# set cuda device to root gpu
# related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958
# Refer: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190
root_device = torch.device("cuda", root_gpu)
root_device = (torch.device("cuda", root_gpu)
if root_gpu >= 0 else torch.device("cpu"))
torch.cuda.set_device(root_device)
else:
raise RuntimeError(
Expand Down

0 comments on commit 83c291d

Please sign in to comment.