Skip to content

Commit

Permalink
added 'weights_only' param in torch.load examples (pytorch#112860)
Browse files Browse the repository at this point in the history
Fixes pytorch#111876

`torch.load` without setting `weights_only=True` is unsafe. So updating examples of `torch.load` to use `weights_only=True` where possible and `weights_only=False` elsewhere with a warning of being unsafety.

Pull Request resolved: pytorch#112860
Approved by: https://github.com/kit1980
  • Loading branch information
Viditagarwal7479 authored and Skylion007 committed Nov 14, 2023
1 parent 0fc2970 commit 16b87aa
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,21 +951,23 @@ def load(
Example:
>>> # xdoctest: +SKIP("undefined filepaths")
>>> torch.load('tensors.pt')
>>> torch.load('tensors.pt', weights_only=True)
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
>>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'})
>>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open('tensor.pt', 'rb') as f:
... buffer = io.BytesIO(f.read())
>>> torch.load(buffer)
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load('module.pt', encoding='ascii', weights_only=False)
"""
torch._C._log_api_usage_once("torch.load")
UNSAFE_MESSAGE = (
Expand Down

0 comments on commit 16b87aa

Please sign in to comment.