Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix of issue 600 #625

Merged
merged 1 commit into from
Dec 15, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ def val_dataloader(self):
return None

@classmethod
def load_from_metrics(cls, weights_path, tags_csv):
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
"""Primary way of loading model from csv weights path.

:param str weights_path: Path to a PyTorch checkpoint
Expand Down Expand Up @@ -975,9 +975,10 @@ def load_from_metrics(cls, weights_path, tags_csv):
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)

# load on CPU only to avoid OOM issues
# then its up to user to put back on GPUs
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
if map_location is not None:
checkpoint = torch.load(weights_path, map_location=map_location)
else:
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

# load the state_dict on the model automatically
model = cls(hparams)
Expand All @@ -989,17 +990,19 @@ def load_from_metrics(cls, weights_path, tags_csv):
return model

@classmethod
def load_from_checkpoint(cls, checkpoint_path):
def load_from_checkpoint(cls, checkpoint_path, map_location=None):
"""
Primary way of loading model from a checkpoint
:param checkpoint_path:
:param map_location: dic for mapping storage {'cuda:1':'cuda:0'}
:return:
"""

# load on CPU only to avoid OOM issues
# then its up to user to put back on GPUs
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
if map_location is not None:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

try:
ckpt_hparams = checkpoint['hparams']
except KeyError:
Expand Down