Skip to content

Commit

Permalink
SA: setting torch cuda device
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamagarwal92 committed Mar 8, 2020
1 parent 2f17b2f commit 6d89505
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,12 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:

if isinstance(self.data_parallel_device_ids, list):
root_gpu = self.data_parallel_device_ids[0]

# set cuda device to root gpu
# related to https://github.com/PyTorchLightning/pytorch-lightning/issues/958
# Refer solution: https://github.com/pytorch/pytorch/issues/9871#issuecomment-408304190
root_device = torch.device("cuda", root_gpu)
torch.cuda.set_device(root_device)
else:
raise RuntimeError(
'Expected `data_parallel_device_ids` as a list, cannot determine root gpu.'
Expand Down

0 comments on commit 6d89505

Please sign in to comment.