Add torch.serialization.skip_data context manager (pytorch#134504)
## Semantics

The semantics are as follows:

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storage bytes, while still reserving space for them in the checkpoint.

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
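# The storages were skipped at save time: the loaded tensors have the expected
# shape/dtype/device, but their values are zero-filled (the reserved bytes read
# back as zeros, per the tests added in this PR).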
```

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler treats it as materialized: space is reserved in the checkpoint for the associated storage bytes, and on load the result is a `Tensor` rather than a `FakeTensor`.

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```
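Storage sharing between views is preserved when saving under the context manager (see `test_skip_data_serialization_preserves_views` in the diff below); a minimal sketch of that behavior, using a placeholder path `views.pt`:

```python
import torch

t = torch.randn(2, 3)
sd = {'t': t, 't_view': t.view(-1)}
with torch.serialization.skip_data():
    torch.save(sd, 'views.pt')

loaded = torch.load('views.pt', weights_only=True)
# The view and its base still share a single untyped storage after loading,
# mirroring the check in test_skip_data_serialization_preserves_views.
assert id(loaded['t'].untyped_storage()) == id(loaded['t_view'].untyped_storage())
```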

## Follow Ups

- [ ] `torch.load` semantics under the `skip_data` context manager (current behavior is sketched after this list)
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
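
For context on the first follow-up, the tests added here pin the current behavior: calling `torch.load` while still inside the `skip_data` context manager raises. A minimal sketch, with `tmp.pt` as a placeholder path:

```python
import torch

with torch.serialization.skip_data():
    torch.save(torch.randn(2, 3), 'tmp.pt')
    # Loading inside the context manager is not yet supported and raises:
    # RuntimeError: Please call torch.load outside the skip_data context manager
    torch.load('tmp.pt', weights_only=True)
```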

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: pytorch#134504
Approved by: https://github.com/albanD
mikaylagawarecki authored and tolleybot committed Sep 14, 2024
1 parent 2c4dcea commit 472e8e1
Showing 7 changed files with 270 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/source/notes/serialization.rst
@@ -398,3 +398,4 @@ The following utility functions are related to serialization:
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
.. autoclass:: skip_data
12 changes: 11 additions & 1 deletion test/test_cpp_extensions_open_device_registration.py
@@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self):
# call _fused_adamw_ with undefined tensor.
self.module.fallback_with_undefined_tensor()

def test_open_device_numpy_serialization_map_location(self):
def test_open_device_numpy_serialization(self):
torch.utils.rename_privateuse1_backend("foo")
device = self.module.custom_device()
default_protocol = torch.serialization.DEFAULT_PROTOCOL
@@ -553,6 +553,7 @@ def test_open_device_numpy_serialization_map_location(self):
self.assertTrue(
rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
)
# Test map_location
with TemporaryFileName() as f:
torch.save(sd, f)
with safe_globals(
@@ -569,6 +570,15 @@ def test_open_device_numpy_serialization_map_location(self):
sd_loaded = torch.load(f, map_location="cpu")
self.assertTrue(sd_loaded["x"].is_cpu)

# Test metadata_only
with TemporaryFileName() as f:
with self.assertRaisesRegex(
RuntimeError,
"Cannot serialize tensors on backends with no storage under skip_data context manager",
):
with torch.serialization.skip_data():
torch.save(sd, f)


if __name__ == "__main__":
common.run_tests()
88 changes: 88 additions & 0 deletions test/test_serialization.py
@@ -1,5 +1,6 @@
# Owner(s): ["module: serialization"]

import contextlib
import copy
import gc
import gzip
@@ -19,6 +20,7 @@
from pathlib import Path

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
from torch._utils import _rebuild_tensor
from torch._utils_internal import get_file_path_2
from torch.serialization import (
@@ -27,6 +29,7 @@
LoadEndianness,
safe_globals,
set_default_load_endianness,
skip_data,
SourceChangeWarning,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -4212,6 +4215,91 @@ def test_filewriter_metadata_writing(self, filename):
sd_loaded_ref = torch.load(f)
self.assertEqual(sd_loaded, sd_loaded_ref)

@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization(self, materialize_fake):
# Create one tensor that uses each of the paths in __reduce_ex__ that should work
t_device = "cuda" if torch.cuda.is_available() else "cpu"
t_v2 = torch.randn(2, 3, device=t_device)
t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
i = torch.tensor([[0, 1, 1],
[2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
if not materialize_fake:
# FakeTensorConverter messes up sizes of i and v for the sparse tensor
st = torch.sparse_coo_tensor(i, v, (2, 4))
tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))

mode, converter = FakeTensorMode(), FakeTensorConverter()

def fn(t):
return converter.from_real_tensor(mode, t) if materialize_fake else t

sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
sd_expected = {
't_v2': torch.zeros(2, 3, device=t_device),
't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
}

if not materialize_fake:
sd['st'] = st
sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))

with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
with safe_globals([TwoTensor]):
sd_loaded = torch.load(f, weights_only=True)
self.assertEqual(sd_loaded, sd_expected, exact_device=True)
self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))

# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
if not materialize_fake:
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)

@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization_preserves_views(self, materialize_fake):
ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
with ctx():
t = torch.randn(2, 3)
t_view = t.view(-1)
t_slice = t[1]
sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
sd_loaded = torch.load(f, weights_only=True)
self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))

def test_skip_data_serialization_error_cases(self):
def _save_load(t):
with BytesIOContext() as f:
with skip_data():
torch.save(t, f)
f.seek(0)
torch.load(f, weights_only=True)

nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
t = torch.randn(2, 3, device="meta")
with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
_save_load(nt)

with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
_save_load(t)

with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
with skip_data(), BytesIOContext() as f:
torch.save(torch.randn(2, 3), f)
f.seek(0)
torch.load(f, weights_only=True)

def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)
75 changes: 67 additions & 8 deletions torch/_tensor.py
@@ -209,8 +209,19 @@ def __deepcopy__(self, memo):
return new_tensor

def __reduce_ex__(self, proto):
materialize_fake_tensors = (
torch.serialization._serialization_tls.materialize_fake_tensors
)
state = torch._utils._get_obj_state(self)
if type(self) is Tensor and not state:
# Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
# some state that cannot be pickled
if (
# TODO: remove hasattr, it's a hack to support versions of torch that
# don't have _subclasses
hasattr(torch, "_subclasses")
and type(self) is torch._subclasses.fake_tensor.FakeTensor
and materialize_fake_tensors
) or (type(self) is Tensor and not state):
# Fast path for regular tensor without Python state.
return self._reduce_ex_internal(proto)
if has_torch_function_unary(self):
@@ -251,6 +262,12 @@ def _reduce_ex_internal(self, proto):
# See Note [Don't serialize hooks]
warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()

skip_data = torch.serialization._serialization_tls.skip_data
materialize_fake_tensors = (
torch.serialization._serialization_tls.materialize_fake_tensors
)

# Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
# We considered a few options:
# 1. CPU tensor can't be used here.
@@ -268,6 +285,10 @@
# Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
# support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
# this would reconstruct the BFloat16 tensor from numpy.
if skip_data:
raise RuntimeError(
"Cannot serialize tensors on backends with no storage under skip_data context manager"
)
numpy_tensor = (
self.cpu().numpy()
if self.dtype != torch.bfloat16
@@ -280,6 +301,10 @@
if self.device.type == "meta":
# NB: This implementation BREAKS storage sharing. Current
# hypothesis is that no one cares for meta tensors.
if skip_data:
warnings.warn(
"Serializing tensors on the meta device under skip_data context manager is a no-op"
)
arg_meta = (
self.dtype,
tuple(self.size()),
@@ -288,6 +313,10 @@
)
return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
if self.is_quantized:
if skip_data:
raise RuntimeError(
"Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
)
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[
Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
@@ -369,6 +398,10 @@ def _reduce_ex_internal(self, proto):
)
return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
elif self.is_nested:
if skip_data:
raise RuntimeError(
"Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
)
args_nested = (
# NB: values() currently returns the storage as a buffer in an unsafe way.
# Ideally, we'd use a private API for this instead. TODO: Switch to this if
@@ -383,14 +416,30 @@
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(
self,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
or (
not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and self.data_ptr() == 0
)
or self.data_ptr() == 0
)
):
arg_wrapper_subclass = (
type(self),
self.dtype,
tuple(self.size()),
self.stride(),
self.storage_offset(),
self.layout,
self.device,
self.requires_grad,
)
return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
elif (
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and not (skip_data and materialize_fake_tensors)
)
):
arg_wrapper_subclass = (
Expand Down Expand Up @@ -418,6 +467,16 @@ def _reduce_ex_internal(self, proto):
dtype=self.dtype,
_internal=True,
) # type: ignore[assignment]

# TODO: remove hasattr, it's a hack to support versions of torch that
# don't have _subclasses
if (
hasattr(torch, "_subclasses")
and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and skip_data
):
storage._fake_device = self.device

args = (
storage,
self.storage_offset(),
6 changes: 1 addition & 5 deletions torch/_utils.py
@@ -3,7 +3,6 @@
import functools
import logging
import sys
import threading
import traceback
import warnings
from collections import defaultdict
@@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
return kwargs["async"]


_thread_local_state = threading.local()


def _get_restore_location(device):
"""Return the map_location location.
Used for rebuild functions where the tensor device is distinct from the storage
"""

map_location = getattr(_thread_local_state, "map_location", None)
map_location = torch.serialization._serialization_tls.map_location
if map_location is None:
return device
else: