diff --git a/run_nerf.py b/run_nerf.py index bc270be86..51300f91c 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -873,6 +873,9 @@ def train(): if __name__=='__main__': - torch.set_default_tensor_type('torch.cuda.FloatTensor') + if torch.cuda.is_available(): + torch.set_default_tensor_type('torch.cuda.FloatTensor') + else: + torch.set_default_tensor_type('torch.FloatTensor') train()