Skip to content

Commit

Permalink
fix small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
kaylode committed Mar 9, 2022
1 parent 8f08c1d commit a031247
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions theseus/utilities/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
import torch
from typing import Any
from loggers.observer import LoggerObserver
from theseus.utilities.loggers.observer import LoggerObserver

LOGGER = LoggerObserver.getLogger('main')

Expand Down Expand Up @@ -37,7 +37,7 @@ def move_to(obj: Any, device: torch.device):
Returns:
type(obj) -- same object but moved to specified device
"""
if torch.is_tensor(obj):
if torch.is_tensor(obj) or isinstance(obj, torch.nn.Module):
return obj.to(device)
elif isinstance(obj, dict):
res = {k: move_to(v, device) for k, v in obj.items()}
Expand Down

0 comments on commit a031247

Please sign in to comment.