diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst
index 5541d28bdcafa..c05dc028a471c 100644
--- a/docs/source/notes/serialization.rst
+++ b/docs/source/notes/serialization.rst
@@ -398,3 +398,4 @@ The following utility functions are related to serialization:
 .. autofunction:: clear_safe_globals
 .. autofunction:: get_safe_globals
 .. autoclass:: safe_globals
+.. autoclass:: skip_data
diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py
index 616a6e0f4b551..4e86ed458b078 100644
--- a/test/test_cpp_extensions_open_device_registration.py
+++ b/test/test_cpp_extensions_open_device_registration.py
@@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self):
         # call _fused_adamw_ with undefined tensor.
         self.module.fallback_with_undefined_tensor()

-    def test_open_device_numpy_serialization_map_location(self):
+    def test_open_device_numpy_serialization(self):
         torch.utils.rename_privateuse1_backend("foo")
         device = self.module.custom_device()
         default_protocol = torch.serialization.DEFAULT_PROTOCOL
@@ -553,6 +553,7 @@ def test_open_device_numpy_serialization_map_location(self):
         self.assertTrue(
             rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
         )
+        # Test map_location
         with TemporaryFileName() as f:
             torch.save(sd, f)
             with safe_globals(
@@ -569,6 +570,15 @@ def test_open_device_numpy_serialization_map_location(self):
             sd_loaded = torch.load(f, map_location="cpu")
         self.assertTrue(sd_loaded["x"].is_cpu)

+        # Test metadata_only
+        with TemporaryFileName() as f:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Cannot serialize tensors on backends with no storage under skip_data context manager",
+            ):
+                with torch.serialization.skip_data():
+                    torch.save(sd, f)
+

 if __name__ == "__main__":
     common.run_tests()
diff --git a/test/test_serialization.py b/test/test_serialization.py
index a041473d195b1..3ba96b80541d8 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: serialization"]

+import contextlib
 import copy
 import gc
 import gzip
@@ -19,6 +20,7 @@
 from pathlib import Path

 import torch
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
 from torch._utils import _rebuild_tensor
 from torch._utils_internal import get_file_path_2
 from torch.serialization import (
@@ -27,6 +29,7 @@
     LoadEndianness,
     safe_globals,
     set_default_load_endianness,
+    skip_data,
     SourceChangeWarning,
 )
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -4212,6 +4215,91 @@ def test_filewriter_metadata_writing(self, filename):
             sd_loaded_ref = torch.load(f)
         self.assertEqual(sd_loaded, sd_loaded_ref)

+    @parametrize("materialize_fake", (True, False))
+    def test_skip_data_serialization(self, materialize_fake):
+        # Create one tensor that uses each of the paths in __reduce_ex__ that should work
+        t_device = "cuda" if torch.cuda.is_available() else "cpu"
+        t_v2 = torch.randn(2, 3, device=t_device)
+        t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
+        i = torch.tensor([[0, 1, 1],
+                          [2, 0, 2]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        if not materialize_fake:
+            # FakeTensorConverter messes up sizes of i and v for the sparse tensor
+            st = torch.sparse_coo_tensor(i, v, (2, 4))
+        tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
+
+        mode, converter = FakeTensorMode(), FakeTensorConverter()
+
+        def fn(t):
+            return converter.from_real_tensor(mode, t) if materialize_fake else t
+
+        sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
+        sd_expected = {
+            't_v2': torch.zeros(2, 3, device=t_device),
+            't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
+            'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
+        }
+
+        if not materialize_fake:
+            sd['st'] = st
+            sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))
+
+        with BytesIOContext() as f:
+            with skip_data(materialize_fake_tensors=materialize_fake):
+                torch.save(sd, f)
+            f.seek(0)
+            with safe_globals([TwoTensor]):
+                sd_loaded = torch.load(f, weights_only=True)
+            self.assertEqual(sd_loaded, sd_expected, exact_device=True)
+            self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
+            self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))
+
+        # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
+        if not materialize_fake:
+            ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
+            with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
+                with skip_data(), BytesIOContext() as f:
+                    torch.save(ft, f)
+
+    @parametrize("materialize_fake", (True, False))
+    def test_skip_data_serialization_preserves_views(self, materialize_fake):
+        ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
+        with ctx():
+            t = torch.randn(2, 3)
+            t_view = t.view(-1)
+            t_slice = t[1]
+        sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
+        with BytesIOContext() as f:
+            with skip_data(materialize_fake_tensors=materialize_fake):
+                torch.save(sd, f)
+            f.seek(0)
+            sd_loaded = torch.load(f, weights_only=True)
+            self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
+            self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
+
+    def test_skip_data_serialization_error_cases(self):
+        def _save_load(t):
+            with BytesIOContext() as f:
+                with skip_data():
+                    torch.save(t, f)
+                f.seek(0)
+                torch.load(f, weights_only=True)
+
+        nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
+        t = torch.randn(2, 3, device="meta")
+        with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
+            _save_load(nt)
+
+        with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
+            _save_load(t)
+
+        with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
+            with skip_data(), BytesIOContext() as f:
+                torch.save(torch.randn(2, 3), f)
+                f.seek(0)
+                torch.load(f, weights_only=True)
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super().run(*args, **kwargs)
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 98563aebae9aa..7d8010081abc7 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -209,8 +209,19 @@ def __deepcopy__(self, memo):
         return new_tensor

     def __reduce_ex__(self, proto):
+        materialize_fake_tensors = (
+            torch.serialization._serialization_tls.materialize_fake_tensors
+        )
         state = torch._utils._get_obj_state(self)
-        if type(self) is Tensor and not state:
+        # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
+        # some state that cannot be pickled
+        if (
+            # TODO: remove hasattr, it's a hack to support versions of torch that
+            # don't have _subclasses
+            hasattr(torch, "_subclasses")
+            and type(self) is torch._subclasses.fake_tensor.FakeTensor
+            and materialize_fake_tensors
+        ) or (type(self) is Tensor and not state):
             # Fast path for regular tensor without Python state.
             return self._reduce_ex_internal(proto)
         if has_torch_function_unary(self):
@@ -251,6 +262,12 @@ def _reduce_ex_internal(self, proto):
         # See Note [Don't serialize hooks]
         warn_if_has_hooks(self)
         backward_hooks: Dict[Any, Any] = OrderedDict()
+
+        skip_data = torch.serialization._serialization_tls.skip_data
+        materialize_fake_tensors = (
+            torch.serialization._serialization_tls.materialize_fake_tensors
+        )
+
         # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
@@ -268,6 +285,10 @@
             # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't
             # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
             # this would reconstruct the BFloat16 tensor from numpy.
+            if skip_data:
+                raise RuntimeError(
+                    "Cannot serialize tensors on backends with no storage under skip_data context manager"
+                )
             numpy_tensor = (
                 self.cpu().numpy()
                 if self.dtype != torch.bfloat16
@@ -280,6 +301,10 @@
         if self.device.type == "meta":
             # NB: This implementation BREAKS storage sharing. Current
             # hypothesis is that no one cares for meta tensors.
+            if skip_data:
+                warnings.warn(
+                    "Serializing tensors on the meta device under skip_data context manager is a no-op"
+                )
             arg_meta = (
                 self.dtype,
                 tuple(self.size()),
@@ -288,6 +313,10 @@
             )
             return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
         if self.is_quantized:
+            if skip_data:
+                raise RuntimeError(
+                    "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
+                )
             # quantizer_params can be different type based on torch attribute
             quantizer_params: Union[
                 Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
@@ -369,6 +398,10 @@
             )
             return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
         elif self.is_nested:
+            if skip_data:
+                raise RuntimeError(
+                    "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
+                )
             args_nested = (
                 # NB: values() currently returns the storage as a buffer in an unsafe way.
                 # Ideally, we'd use a private API for this instead. TODO: Switch to this if
@@ -383,14 +416,30 @@ def _reduce_ex_internal(self, proto):
             type(self) is not torch.Tensor
             and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
             and (
-                isinstance(
-                    self,
-                    (
-                        torch._subclasses.fake_tensor.FakeTensor,
-                        torch._subclasses.functional_tensor.FunctionalTensor,
-                    ),
+                isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
+                or (
+                    not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
+                    and self.data_ptr() == 0
                 )
-                or self.data_ptr() == 0
+            )
+        ):
+            arg_wrapper_subclass = (
+                type(self),
+                self.dtype,
+                tuple(self.size()),
+                self.stride(),
+                self.storage_offset(),
+                self.layout,
+                self.device,
+                self.requires_grad,
+            )
+            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
+        elif (
+            type(self) is not torch.Tensor
+            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
+            and (
+                isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
+                and not (skip_data and materialize_fake_tensors)
             )
         ):
             arg_wrapper_subclass = (
@@ -418,6 +467,16 @@ def _reduce_ex_internal(self, proto):
                 dtype=self.dtype,
                 _internal=True,
             )  # type: ignore[assignment]
+
+            # TODO: remove hasattr, it's a hack to support versions of torch that
+            # don't have _subclasses
+            if (
+                hasattr(torch, "_subclasses")
+                and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
+                and skip_data
+            ):
+                storage._fake_device = self.device
+
             args = (
                 storage,
                 self.storage_offset(),
diff --git a/torch/_utils.py b/torch/_utils.py
index 938392fa97159..f0d38daa81149 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -3,7 +3,6 @@
 import functools
 import logging
 import sys
-import threading
 import traceback
 import warnings
 from collections import defaultdict
@@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
     return kwargs["async"]


-_thread_local_state = threading.local()
-
-
 def _get_restore_location(device):
     """Return the map_location location.

     Used for rebuild functions where the tensor device is distinct from the storage
     """
-    map_location = getattr(_thread_local_state, "map_location", None)
+    map_location = torch.serialization._serialization_tls.map_location
     if map_location is None:
         return device
     else:
diff --git a/torch/serialization.py b/torch/serialization.py
index 2ac36ec371fe5..d936d31d6f520 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -11,6 +11,7 @@
 import sys
 import tarfile
 import tempfile
+import threading
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
@@ -60,6 +61,7 @@
     "get_safe_globals",
     "add_safe_globals",
     "safe_globals",
+    "skip_data",
 ]


@@ -87,6 +89,22 @@
     MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]


+# _serialization_tls is used to store thread local state specific to serialization
+# that needs to be propagated to other files, in particular we use this for
+# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
+# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
+# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
+class _SerializationLocal(threading.local):
+    def __init__(self):
+        super().__init__()
+        self.map_location: Optional[MAP_LOCATION] = None
+        self.skip_data: bool = False
+        self.materialize_fake_tensors: bool = False
+
+
+_serialization_tls = _SerializationLocal()
+
+
 class SourceChangeWarning(Warning):
     pass

@@ -268,6 +286,47 @@ class safe_globals(_weights_only_unpickler._safe_globals):
     """


+class skip_data:
+    """
+    Context-manager that skips writing storage bytes for ``torch.save`` calls.
+
+    Storages will still be saved, but the space that their bytes would usually be written to
+    will be empty space. The storage bytes can then be populated in a separate pass.
+
+    .. warning::
+        The ``skip_data`` context manager is an early prototype and is subject to change.
+
+    Args:
+        materialize_fake_tensors: Whether to materialize FakeTensors.
+
+    Example:
+        >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
+        >>> import tempfile
+        >>> t = torch.randn(2, 3)
+        >>> with tempfile.NamedTemporaryFile() as f:
+        ...     with torch.serialization.skip_data():
+        ...         torch.save(t, f.name)
+        ...     torch.load(f.name, weights_only=True)
+        tensor([[0., 0., 0.],
+                [0., 0., 0.]])
+    """
+
+    def __init__(self, materialize_fake_tensors: bool = False):
+        self.materialize_fake_tensors = materialize_fake_tensors
+
+    def __enter__(self):
+        global _serialization_tls
+        self._old_skip_data = _serialization_tls.skip_data
+        self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors
+        _serialization_tls.skip_data = True
+        _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors
+
+    def __exit__(self, type, value, tb):
+        global _serialization_tls
+        _serialization_tls.skip_data = self._old_skip_data
+        _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors
+
+
 def _is_zipfile(f) -> bool:
     # This is a stricter implementation than zipfile.is_zipfile().
     # zipfile.is_zipfile() is True if the magic number appears anywhere in the
@@ -797,6 +856,11 @@ def save(
             )
             return
     else:
+        global _serialization_tls
+        if _serialization_tls.skip_data:
+            raise RuntimeError(
+                "Cannot use skip_data=True with _use_new_zipfile_serialization=False"
+            )
         with _open_file_like(f, "wb") as opened_file:
             _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

@@ -955,7 +1019,13 @@ def persistent_id(obj: Any) -> Optional[Tuple]:
     )


-def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
+def _save(
+    obj,
+    zip_file,
+    pickle_module,
+    pickle_protocol,
+    _disable_byteorder_record,
+):
     serialized_storages = {}
     id_map: Dict[int, str] = {}

@@ -990,7 +1060,7 @@ def persistent_id(obj):
             # If storage is allocated, ensure that any other saved storages
             # pointing to the same data all have the same dtype. If storage is
             # not allocated, don't perform this check
-            if storage.data_ptr() != 0:
+            if str(storage.device) != "meta" and storage.data_ptr() != 0:
                 if storage.data_ptr() in storage_dtypes:
                     if storage_dtype != storage_dtypes[storage.data_ptr()]:
                         raise RuntimeError(
@@ -1001,7 +1071,10 @@
                     storage_dtypes[storage.data_ptr()] = storage_dtype

             storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
-            location = location_tag(storage)
+            if hasattr(obj, "_fake_device") and obj._fake_device is not None:
+                location = str(obj._fake_device)
+            else:
+                location = location_tag(storage)
             serialized_storages[storage_key] = storage

             return ("storage", storage_type, storage_key, location, storage_numel)
@@ -1027,14 +1100,18 @@ def persistent_id(obj):
     for key in sorted(serialized_storages.keys()):
         name = f"data/{key}"
         storage = serialized_storages[key]
-        # given that we copy things around anyway, we might use storage.cpu()
-        # this means to that to get tensors serialized, you need to implement
-        # .cpu() on the underlying Storage
-        if storage.device.type != "cpu":
-            storage = storage.cpu()
-        # Now that it is on the CPU we can directly copy it into the zip file
         num_bytes = storage.nbytes()
-        zip_file.write_record(name, storage, num_bytes)
+        global _serialization_tls
+        if _serialization_tls.skip_data:
+            zip_file.write_record_metadata(name, num_bytes)
+        else:
+            # given that we copy things around anyway, we might use storage.cpu()
+            # this means to that to get tensors serialized, you need to implement
+            # .cpu() on the underlying Storage
+            if storage.device.type != "cpu":
+                storage = storage.cpu()
+            # Now that it is on the CPU we can directly copy it into the zip file
+            zip_file.write_record(name, storage, num_bytes)


 def load(
@@ -1184,6 +1261,14 @@ def _get_wo_message(message: str) -> str:
             updated_message += message
         return updated_message + DOCS_MESSAGE

+    global _serialization_tls
+    skip_data = _serialization_tls.skip_data
+    if skip_data:
+        raise RuntimeError(
+            "`torch.load` called within a torch.serialization.skip_data context manager "
+            "is not supported yet. Please call torch.load outside the skip_data context manager."
+        )
+
     if weights_only is None:
         weights_only, warn_weights_only = False, True
     else:
@@ -1758,9 +1843,10 @@ def find_class(self, mod_name, name):
     unpickler.persistent_load = persistent_load
     # Needed for tensors where storage device and rebuild tensor device are
     # not connected (wrapper subclasses and tensors rebuilt using numpy)
-    torch._utils._thread_local_state.map_location = map_location
+    global _serialization_tls
+    _serialization_tls.map_location = map_location
     result = unpickler.load()
-    del torch._utils._thread_local_state.map_location
+    _serialization_tls.map_location = None

     torch._utils._validate_loaded_sparse_tensors()
     torch._C._log_api_usage_metadata(
diff --git a/torch/storage.py b/torch/storage.py
index b6ba608c16e5c..8848649905f93 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -39,6 +39,8 @@ class _StorageBase:
     is_sparse: _bool = False
     is_sparse_csr: _bool = False
     device: torch.device
+    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
+    _fake_device: _Optional[torch.device] = None

     def __init__(self, *args, **kwargs):
         pass
@@ -649,6 +651,8 @@ def _get_device_from_module(module: str):

 class TypedStorage:
     is_sparse: _bool = False
+    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
+    _fake_device: _Optional[torch.device] = None

     dtype: torch.dtype
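Usage sketch (not part of the patch above): the snippet below exercises the new context manager using only APIs that appear in this diff (torch.serialization.skip_data, torch.save, torch.load with weights_only=True). The final copy_-based fill-in is just one assumed way of doing the "separate pass" that the docstring mentions; the patch itself does not prescribe how the skipped bytes get populated later.

import tempfile

import torch

sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

with tempfile.NamedTemporaryFile() as f:  # reopening by name does not work on Windows
    # Write only the checkpoint layout; storage bytes are left as empty space.
    with torch.serialization.skip_data():
        torch.save(sd, f.name)

    # torch.load must run outside the context manager (it raises inside it).
    # Tensors come back with correct metadata but zero-filled data, as in the
    # docstring example.
    sd_loaded = torch.load(f.name, weights_only=True)
    assert sd_loaded["weight"].shape == sd["weight"].shape
    assert bool(sd_loaded["weight"].eq(0).all())

    # One possible "separate pass": copy the real values into the loaded tensors.
    for name, tensor in sd_loaded.items():
        tensor.copy_(sd[name])
    assert torch.equal(sd_loaded["weight"], sd["weight"])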