From a2cd158851f1abc90c56b237918125d94091934a Mon Sep 17 00:00:00 2001 From: Eta Date: Thu, 30 Nov 2023 00:06:36 -0600 Subject: [PATCH] feat(crypt): Add fast encryption --- .github/workflows/test.yml | 3 + CHANGELOG.md | 20 + examples/encryption.py | 92 +++ pyproject.toml | 2 +- requirements.txt | 3 +- tensorizer/_NumpyTensor.py | 85 ++- tensorizer/_crypt_info.py | 345 +++++++++ tensorizer/_internal_utils.py | 71 ++ tensorizer/_version.py | 2 +- tensorizer/serialization.py | 1304 +++++++++++++++++++++++++-------- tensorizer/stream_io.py | 100 +-- tests/__init__.py | 0 tests/test_serialization.py | 171 +++-- 13 files changed, 1705 insertions(+), 493 deletions(-) create mode 100644 examples/encryption.py create mode 100644 tensorizer/_crypt_info.py create mode 100644 tensorizer/_internal_utils.py create mode 100644 tests/__init__.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 59efb661..de9616ed 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,5 +22,8 @@ jobs: - name: Install Redis run: sudo apt-get install -y redis-server + - name: Install libsodium + run: sudo apt-get install -y libsodium23 + - name: Run tests run: python -m unittest discover tests/ --verbose diff --git a/CHANGELOG.md b/CHANGELOG.md index 130997b9..9444199b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- Tensor encryption + - Encrypts all tensor weights in a file with minimal overhead + - Doesn't encrypt tensor metadata, such as: + - Tensor name + - Tensor `dtype` + - Tensor shape & size + - Requires an up-to-date version of `libsodium` + - Use `apt-get install libsodium23` on Ubuntu or Debian + - On other platforms, follow the + [installation instructions from the libsodium documentation](https://doc.libsodium.org/installation) + - Takes up less than 500 KiB once installed + - Uses a parallelized version of XSalsa20-Poly1305 as its encryption algorithm + - Splits each tensor's weights into ≤ 2 MiB chunks, encrypted separately + - Example usage: see [examples/encryption.py](examples/encryption.py) + ## [2.6.0] - 2023-10-30 ### Added @@ -220,6 +239,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `get_gpu_name` - `no_init_or_tensor` +[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.6.0...HEAD [2.6.0]: https://github.com/coreweave/tensorizer/compare/v2.5.1...v2.6.0 [2.5.1]: https://github.com/coreweave/tensorizer/compare/v2.5.0...v2.5.1 [2.5.0]: https://github.com/coreweave/tensorizer/compare/v2.4.0...v2.5.0 diff --git a/examples/encryption.py b/examples/encryption.py new file mode 100644 index 00000000..9f552893 --- /dev/null +++ b/examples/encryption.py @@ -0,0 +1,92 @@ +import os +import tempfile +import time + +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from tensorizer import ( + DecryptionParams, + EncryptionParams, + TensorDeserializer, + TensorSerializer, +) +from tensorizer.utils import no_init_or_tensor + +model_ref = "EleutherAI/gpt-neo-2.7B" + + +def original_model(ref) -> torch.nn.Module: + return AutoModelForCausalLM.from_pretrained(ref) + + +def empty_model(ref) -> torch.nn.Module: + config = AutoConfig.from_pretrained(ref) + with no_init_or_tensor(): + return 
AutoModelForCausalLM.from_config(config) + + +# Set a strong string or bytes passphrase here +passphrase: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD", "") or input( + "Passphrase to use for encryption: " +) + +fd, path = tempfile.mkstemp(prefix="encrypted-tensors") + +try: + # Encrypt a model during serialization + encryption_params = EncryptionParams.from_passphrase_fast(passphrase) + + model = original_model(model_ref) + serialization_start = time.monotonic() + + serializer = TensorSerializer(path, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + serialization_end = time.monotonic() + del model + + # Then decrypt it again during deserialization + decryption_params = DecryptionParams.from_passphrase(passphrase) + + model = empty_model(model_ref) + deserialization_start = time.monotonic() + + deserializer = TensorDeserializer( + path, encryption=decryption_params, plaid_mode=True + ) + deserializer.load_into_module(model) + deserializer.close() + + deserialization_end = time.monotonic() + del model +finally: + os.close(fd) + os.unlink(path) + + +def print_speed(prefix, start, end, size): + mebibyte = 1 << 20 + gibibyte = 1 << 30 + duration = end - start + rate = size / duration + print( + f"{prefix} {size / gibibyte:.2f} GiB model in {duration:.2f} seconds," + f" {rate / mebibyte:.2f} MiB/s" + ) + + +print_speed( + "Serialized and encrypted", + serialization_start, + serialization_end, + serializer.total_tensor_bytes, +) + +print_speed( + "Deserialized encrypted", + deserialization_start, + deserialization_end, + deserializer.total_tensor_bytes, +) diff --git a/pyproject.toml b/pyproject.toml index 25085ce6..1668b3d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "boto3>=1.26.0", "redis>=5.0.0", "hiredis>=2.2.0", - "pynacl>=1.5.0", + "libnacl>=2.1.0" ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/requirements.txt b/requirements.txt index 2e32128b..9bd5aafe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ numpy>=1.19.5 protobuf>=3.19.5 psutil>=5.9.4 boto3>=1.26.0 +redis==5.0.0 hiredis -redis==5.0.0 \ No newline at end of file +libnacl>=2.1.0 diff --git a/tensorizer/_NumpyTensor.py b/tensorizer/_NumpyTensor.py index b4d6dd8a..37cb62c2 100644 --- a/tensorizer/_NumpyTensor.py +++ b/tensorizer/_NumpyTensor.py @@ -13,39 +13,74 @@ 8: torch.int64, } -# torch types with no numpy equivalents -# i.e. the only ones that need to be opaque -# Uses a comprehension to filter out any dtypes -# that don't exist in older torch versions -_ASYMMETRIC_TYPES = { - getattr(torch, t) +# Listing of types from a static copy of: +# tuple( +# dict.fromkeys( +# str(t) +# for t in vars(torch).values() +# if isinstance(t, torch.dtype) +# ) +# ) +_ALL_TYPES = { + f"torch.{t}": v for t in ( - "bfloat16", - "quint8", + "uint8", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", + "complex32", + "complex64", + "complex128", + "bool", "qint8", + "quint8", "qint32", + "bfloat16", "quint4x2", "quint2x4", - "complex32", ) - if hasattr(torch, t) + if isinstance(v := getattr(torch, t, None), torch.dtype) +} + +# torch types with no numpy equivalents +# i.e. 
the only ones that need to be opaque
+# Uses a comprehension to filter out any dtypes
+# that don't exist in older torch versions
+_ASYMMETRIC_TYPES = {
+    _ALL_TYPES[t]
+    for t in {
+        "torch.bfloat16",
+        "torch.quint8",
+        "torch.qint8",
+        "torch.qint32",
+        "torch.quint4x2",
+        "torch.quint2x4",
+        "torch.complex32",
+    }
+    & _ALL_TYPES.keys()
 }

 # These types aren't supported yet because they require supplemental
 # quantization parameters to deserialize correctly
 _UNSUPPORTED_TYPES = {
-    getattr(torch, t)
-    for t in (
-        "quint8",
-        "qint8",
-        "qint32",
-        "quint4x2",
-        "quint2x4",
-    )
-    if hasattr(torch, t)
+    _ALL_TYPES[t]
+    for t in {
+        "torch.quint8",
+        "torch.qint8",
+        "torch.qint32",
+        "torch.quint4x2",
+        "torch.quint2x4",
+    }
+    & _ALL_TYPES.keys()
 }

-_DECODE_MAPPING = {str(t): t for t in _ASYMMETRIC_TYPES}
+_DECODE_MAPPING = {
+    k: v for k, v in _ALL_TYPES.items() if v not in _UNSUPPORTED_TYPES
+}


 class _NumpyTensor(NamedTuple):
@@ -85,14 +120,12 @@ def from_buffer(
             buffer=buffer,
             offset=offset,
         )
-        return cls(data=data,
-                   numpy_dtype=numpy_dtype,
-                   torch_dtype=torch_dtype)
+        return cls(data=data, numpy_dtype=numpy_dtype, torch_dtype=torch_dtype)

     @classmethod
-    def from_tensor(cls,
-                    tensor: Union[torch.Tensor,
-                                  torch.nn.Module]) -> "_NumpyTensor":
+    def from_tensor(
+        cls, tensor: Union[torch.Tensor, torch.nn.Module]
+    ) -> "_NumpyTensor":
         """
         Converts a torch tensor into a `_NumpyTensor`.
         May use an opaque dtype for the numpy array stored in
diff --git a/tensorizer/_crypt_info.py b/tensorizer/_crypt_info.py
new file mode 100644
index 00000000..40ba1f6f
--- /dev/null
+++ b/tensorizer/_crypt_info.py
@@ -0,0 +1,345 @@
+import abc
+import dataclasses
+import io
+import struct
+import typing
+import weakref
+from functools import partial
+from typing import ClassVar, List, Optional, Sequence, Union
+
+from tensorizer._internal_utils import _unpack_memoryview_from, _variable_read
+
+
+class CryptInfoChunk(abc.ABC):
+    _chunk_types: ClassVar[
+        typing.MutableMapping[int, typing.Type["CryptInfoChunk"]]
+    ] = weakref.WeakValueDictionary()
+    chunk_type: ClassVar[int]
+    _length_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<Q"  # Little-endian chunk length
+    )
+    _chunk_type_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<H"  # Little-endian chunk type
+    )
+
+    @classmethod
+    def unpack_from(cls, buffer, offset: int = 0) -> "CryptInfoChunk":
+        chunk_type = CryptInfoChunk._chunk_type_segment.unpack_from(
+            buffer, offset
+        )[0]
+        return CryptInfoChunk._chunk_types[chunk_type].unpack_from(
+            buffer, offset + CryptInfoChunk._chunk_type_segment.size
+        )
+
+    @abc.abstractmethod
+    def pack_into(self, buffer, offset: int = 0) -> int:
+        CryptInfoChunk._chunk_type_segment.pack_into(
+            buffer, offset, self.chunk_type
+        )
+        return offset + CryptInfoChunk._chunk_type_segment.size
+
+    def pack(self) -> bytes:
+        buffer = io.BytesIO(bytes(self.size))
+        self.pack_into(buffer.getbuffer(), 0)
+        return buffer.getvalue()
+
+    def sized_pack(self) -> bytes:
+        buffer = io.BytesIO(bytes(self.sized_size))
+        self.sized_pack_into(buffer.getbuffer(), 0)
+        return buffer.getvalue()
+
+    def sized_pack_into(self, buffer, offset: int = 0) -> int:
+        start = offset
+        offset += CryptInfoChunk._length_segment.size
+        ret = self.pack_into(buffer, offset)
+        CryptInfoChunk._length_segment.pack_into(buffer, start, ret - start)
+        return ret
+
+    @property
+    def sized_size(self) -> int:
+        return self.size + CryptInfoChunk._length_segment.size
+
+    @property
+    @abc.abstractmethod
+    def size(self) -> int:
+        return CryptInfoChunk._chunk_type_segment.size
+
+    # noinspection PyMethodOverriding
+    def __init_subclass__(
+        cls, /, *, chunk_type: Optional[int] = None, **kwargs
+    ):
+        super().__init_subclass__(**kwargs)
+        if chunk_type is not None:
+            cls.chunk_type = chunk_type
+            CryptInfoChunk._chunk_types[chunk_type] = cls
+
+
+class KeyDerivationChunk(CryptInfoChunk, abc.ABC, chunk_type=1):
+    _derivation_methods: ClassVar[
+        typing.MutableMapping[int, typing.Type["KeyDerivationChunk"]]
+    ] = weakref.WeakValueDictionary()
+    derivation_method: ClassVar[int]
+    _derivation_method_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<H"  # Little-endian key derivation method
+    )
+
+    # noinspection PyMethodOverriding
+    def __init_subclass__(
+        cls, /, *, derivation_method: Optional[int] = None, **kwargs
+    ):
+        super().__init_subclass__(**kwargs)
+        if derivation_method is not None:
+            cls.derivation_method = derivation_method
+            KeyDerivationChunk._derivation_methods[derivation_method] = cls
+
+    @classmethod
+    def unpack_from(cls, buffer, offset: int = 0) -> "KeyDerivationChunk":
+        derivation_method = (
+            KeyDerivationChunk._derivation_method_segment.unpack_from(
+                buffer, offset
+            )[0]
+        )
+        return KeyDerivationChunk._derivation_methods[
+            derivation_method
+        ].unpack_from(
+            buffer, offset + KeyDerivationChunk._derivation_method_segment.size
+        )
+
+    @abc.abstractmethod
+    def pack_into(self, buffer, offset: int = 0) -> int:
+        offset = super().pack_into(buffer, offset)
+        KeyDerivationChunk._derivation_method_segment.pack_into(
+            buffer, offset, self.derivation_method
+        )
+        return offset + KeyDerivationChunk._derivation_method_segment.size
+
+    @property
+    @abc.abstractmethod
+    def size(self) -> int:
+        return KeyDerivationChunk._derivation_method_segment.size + super().size
+
+
+@dataclasses.dataclass
+class FastKeyDerivationChunk(KeyDerivationChunk, derivation_method=1):
+    salt: Union[bytes, bytearray, memoryview]
+
+    __slots__ = ("salt",)
+
+    _contents_segment_template: ClassVar[str] = (
+        "<"  # Little-endian
+        "H"  # Salt length
+        "{salt_len:d}s"  # Salt
+    )
+
+    read_salt: ClassVar = partial(_variable_read, length_fmt="H")
+
+    @property
+    def _contents_segment(self):
+        return struct.Struct(
+            self._contents_segment_template.format(salt_len=len(self.salt))
+        )
+
+    @classmethod
+    def unpack_from(cls, buffer, offset: int = 0) -> "FastKeyDerivationChunk":
+        return cls(salt=cls.read_salt(buffer, offset)[0])
+
+    def pack_into(self, buffer, offset: int = 0) -> int:
+        offset = super().pack_into(buffer, offset)
+        segment = self._contents_segment
+        segment.pack_into(
+            buffer,
+            offset,
+            len(self.salt),
+            self.salt,
+        )
+        return offset + segment.size
+
+    @property
+    def size(self) -> int:
+        return 2 + len(self.salt) + super().size
+
+
+@dataclasses.dataclass
+class XSalsa20ParallelChunk(CryptInfoChunk, chunk_type=2):
+    chunk_size: int
+    nonce: Union[bytes, bytearray, memoryview]
+    num_macs: int = dataclasses.field(init=False)
+    macs: Sequence[Union[bytes, bytearray, memoryview]]
+
+    __slots__ = ("chunk_size", "nonce", "macs", "__dict__")
+
+    NONCE_BYTES: ClassVar[int] = 24
+    MAC_BYTES: ClassVar[int] = 16
+    CHUNK_QUANTUM: ClassVar[int] = 64
+    MINIMUM_CHUNK_SIZE: ClassVar[int] = 1024
+
+    _header_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<"  # Little-endian
+        "Q"  # Chunk size
+        f"{NONCE_BYTES:d}s"  # Initial nonce
+        "Q"  # Number of MACs
+    )
+
+    _mac_segment: ClassVar[struct.Struct] = struct.Struct(f"<{MAC_BYTES:d}s")
+
+    def __post_init__(self):
+        if len(self.nonce) != self.NONCE_BYTES:
+            raise ValueError("Invalid nonce size")
+        if not (
+            isinstance(self.chunk_size, int)
+            and (self.chunk_size % self.CHUNK_QUANTUM == 0)
+            and self.chunk_size >= self.MINIMUM_CHUNK_SIZE
+        ):
+            raise ValueError("Invalid chunk size")
+        self.num_macs = len(self.macs)
+        for mac in self.macs:
+            if len(mac) != self.MAC_BYTES:
+                raise ValueError("Invalid MAC size")
+
+    @classmethod
+    def unpack_from(cls, buffer, offset: int = 0) -> "XSalsa20ParallelChunk":
+        chunk_size, nonce, num_macs = (
+            XSalsa20ParallelChunk._header_segment.unpack_from(buffer, offset)
+        )
+        offset += XSalsa20ParallelChunk._header_segment.size
+        macs = []
+        for i in range(num_macs):
+            macs.append(
+                _unpack_memoryview_from(
+                    XSalsa20ParallelChunk._mac_segment.size, buffer, offset
+                )
+            )
+            offset += XSalsa20ParallelChunk._mac_segment.size
+        return cls(chunk_size, nonce, macs)
+
+    def pack_into(self, buffer, offset: int = 0) -> int:
+        offset = super().pack_into(buffer, offset)
+        XSalsa20ParallelChunk._header_segment.pack_into(
+            buffer, offset, self.chunk_size, self.nonce, self.num_macs
+        )
+        offset += XSalsa20ParallelChunk._header_segment.size
+        for mac in self.macs:
+            XSalsa20ParallelChunk._mac_segment.pack_into(buffer, offset, mac)
+            del mac
+            offset += XSalsa20ParallelChunk._mac_segment.size
+        return offset
+
+    @property
+    def size(self) -> int:
+        return (
+            XSalsa20ParallelChunk._header_segment.size
+            + XSalsa20ParallelChunk._mac_segment.size * self.num_macs
+            + super().size
+        )
+
+
+@dataclasses.dataclass
+class XSalsa20SequentialChunk(CryptInfoChunk, chunk_type=3):
+    nonce: Union[bytes, bytearray, memoryview]
+    mac: Union[bytes, bytearray, memoryview]
+
+    __slots__ = ("nonce", "mac")
+
+    NONCE_BYTES: ClassVar[int] = 24
+    MAC_BYTES: ClassVar[int] = 16
+
+    _contents_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<"  # Little-endian
+        f"{NONCE_BYTES:d}s"  # Nonce
+        f"{MAC_BYTES:d}s"  # MAC
+    )
+
+    def __post_init__(self):
+        if len(self.nonce) != self.NONCE_BYTES:
+            raise ValueError("Invalid nonce size")
+        if len(self.mac) != self.MAC_BYTES:
+            raise ValueError("Invalid MAC size")
+
+    @classmethod
+    def unpack_from(cls, buffer, offset: int = 0) -> "XSalsa20SequentialChunk":
+        nonce, mac = XSalsa20SequentialChunk._contents_segment.unpack_from(
+            buffer, offset
+        )
+        return cls(nonce, mac)
+
+    def pack_into(self, buffer, offset: int = 0) -> int:
+        offset = super().pack_into(buffer, offset)
+        XSalsa20SequentialChunk._contents_segment.pack_into(
+            buffer, offset, self.nonce, self.mac
+        )
+        return offset + XSalsa20SequentialChunk._contents_segment.size
+
+    @property
+    def size(self) -> int:
+        return XSalsa20SequentialChunk._contents_segment.size + super().size
+
+
+@dataclasses.dataclass
+class CryptInfo:
+    num_chunks: int = dataclasses.field(init=False)
+    chunks: Sequence[CryptInfoChunk] = ()
+
+    _length_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<Q"  # Little-endian total field length
+    )
+    _count_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<H"  # Little-endian chunk count
+    )
+    _chunk_length_segment: ClassVar[struct.Struct] = struct.Struct(
+        "<Q"  # Little-endian chunk length
+    )
+
+    def __post_init__(self):
+        self.num_chunks = len(self.chunks)
+
+    @property
+    def sized_size(self) -> int:
+        return self._length_segment.size + self.size
+
+    @property
+    def size(self) -> int:
+        return self._count_segment.size + sum(c.sized_size for c in self.chunks)
+
+    def find_chunks(
+        self,
+        typ: Union[
+            typing.Type[CryptInfoChunk],
+            typing.Tuple[typing.Type[CryptInfoChunk], ...],
+        ],
+    ) -> Sequence[CryptInfoChunk]:
+        return tuple(c for c in self.chunks if isinstance(c, typ))
+
+    def pack_into(self, buffer, offset: int = 0) -> int:
+        CryptInfo._count_segment.pack_into(buffer, offset, self.num_chunks)
+        offset += CryptInfo._count_segment.size
+        for chunk in self.chunks:
+            offset = chunk.sized_pack_into(buffer, offset)
+        return offset
+
+    def sized_pack_into(self, buffer, offset: int = 0) -> int:
+        length_offset = offset
+        offset += CryptInfo._length_segment.size
+        ret = self.pack_into(buffer, offset)
+        CryptInfo._length_segment.pack_into(buffer, length_offset, ret - offset)
+        return ret
+
+    @classmethod
+    def unpack_from(cls, buffer, offset: int = 0) -> "CryptInfo":
+        num_chunks: int = CryptInfo._count_segment.unpack_from(buffer, offset)[
+            0
+        ]
+        offset += CryptInfo._count_segment.size
+        if num_chunks < 0:
+            raise ValueError(
+                "Invalid CryptInfo chunk count, cannot be negative"
+            )
+        chunks: List[CryptInfoChunk] = []
+        with memoryview(buffer) as mv:
+            for i in range(num_chunks):
+                chunk_size: int = CryptInfo._chunk_length_segment.unpack_from(
+                    buffer, offset
+                )[0]
+                chunk_end: int = offset + chunk_size
+                offset += CryptInfo._chunk_length_segment.size
+                with mv[offset:chunk_end] as chunk_mv:
+                    # Blocks out-of-bounds accesses
+                    chunks.append(CryptInfoChunk.unpack_from(chunk_mv))
+                offset = chunk_end
+        return cls(chunks)
diff --git a/tensorizer/_internal_utils.py
b/tensorizer/_internal_utils.py new file mode 100644 index 00000000..c4945572 --- /dev/null +++ b/tensorizer/_internal_utils.py @@ -0,0 +1,71 @@ +import dataclasses +import struct +import typing +from typing import Tuple, Union + +_Buffer = Union[bytes, bytearray, memoryview] # type: typing.TypeAlias + + +@dataclasses.dataclass(init=False) +class Chunked: + __slots__ = ("count", "total_size", "chunk_size", "remainder") + count: int + total_size: int + chunk_size: int + remainder: int + + def __init__(self, total_size: int, chunk_size: int): + self.total_size = total_size + self.chunk_size = chunk_size + self.remainder = total_size % chunk_size + self.count = total_size // chunk_size + (self.remainder != 0) + + +def _variable_read( + data: bytes, offset: int = 0, length_fmt: str = "B", data_fmt: str = "s" +) -> Tuple[Union[memoryview, Tuple], int]: + """ + Reads a variable-length field preceded by a length from a buffer. + + Returns: + A tuple of the data read, and the offset in the buffer + following the end of the field. + """ + assert length_fmt in ("B", "H", "I", "Q") + if length_fmt == "B": + length: int = data[offset] + offset += 1 + else: + length_struct = struct.Struct("<" + length_fmt) + length: int = length_struct.unpack_from(data, offset)[0] + offset += length_struct.size + if data_fmt == "s": + # When the data is read as bytes, just return a memoryview + end = offset + length + return _unpack_memoryview_from(length, data, offset), end + else: + data_struct = struct.Struct(f"<{length:d}{data_fmt}") + data = data_struct.unpack_from(data, offset) + offset += data_struct.size + return data, offset + + +def _unpack_memoryview_from( + length: int, buffer: _Buffer, offset: int +) -> memoryview: + # Grabbing a memoryview with bounds checking. + # Bounds checking is normally provided by the struct module, + # but it can't return memoryviews. 
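+    # For comparison, an illustrative struct-based equivalent:
+    #   struct.Struct(f"<{length:d}s").unpack_from(buffer, offset)[0]
+    # performs the same bounds check but returns a copied bytes object
+    # rather than a zero-copy view.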
+ with memoryview(buffer) as mv: + end = offset + length + view = mv[offset:end] + if len(view) < length: + view.release() + mv.release() + # Simulate a struct.error message for consistency + raise struct.error( + "unpack_from requires a buffer of at least" + f" {length:d} bytes for unpacking {length:d} bytes at offset" + f" {offset:d} (actual buffer size is {len(buffer):d})" + ) + return view diff --git a/tensorizer/_version.py b/tensorizer/_version.py index e5e59e38..2614ce9d 100644 --- a/tensorizer/_version.py +++ b/tensorizer/_version.py @@ -1 +1 @@ -__version__ = "2.6.0" +__version__ = "2.7.0" diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 237cd6c7..becbe4d0 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -8,11 +8,11 @@ import contextlib import ctypes import dataclasses +import functools import hashlib import io import itertools import logging -import math import mmap import os import queue @@ -40,16 +40,21 @@ Union, ) -import nacl.secret -import nacl.utils import numpy import redis import torch +import tensorizer._crypt as _crypt +import tensorizer._crypt_info as _crypt_info import tensorizer.stream_io as stream_io import tensorizer.utils as utils +from tensorizer._crypt._cgroup_cpu_count import ( + effective_cpu_count as _effective_cpu_count, +) +from tensorizer._internal_utils import Chunked as _Chunked +from tensorizer._internal_utils import _variable_read from tensorizer._NumpyTensor import _NumpyTensor -from tensorizer.stream_io import CURLStreamFile, DecryptedStream +from tensorizer.stream_io import CURLStreamFile if torch.cuda.is_available(): cudart = torch.cuda.cudart() @@ -58,12 +63,51 @@ lz4 = None -__all__ = ["TensorSerializer", "TensorDeserializer", "TensorType"] +__all__ = [ + "TensorSerializer", + "TensorDeserializer", + "TensorType", + "CryptographyError", + "EncryptionParams", + "DecryptionParams", +] # Setup logger logger = logging.getLogger(__name__) +# Get CPU count +cpu_count: int = _effective_cpu_count() + + +class CryptographyError(_crypt.CryptographyError): + pass + + +def _require_libsodium() -> None: + if not _crypt.available: + raise RuntimeError( + "libsodium shared object library not found or outdated." + " libsodium is a required dependency when using tensor encryption." + " Install an up-to-date version using the instructions at" + " https://doc.libsodium.org/installation or through" + ' a package manager (e.g. "apt-get install libsodium23")' + ) + + +def _requires_libsodium(func): + if _crypt.available: + return func + else: + + @functools.wraps(func) + def wrapper(*args, **kwargs): + _require_libsodium() + return func(*args, **kwargs) + + return wrapper + + # Whether the tensor is a parameter or a buffer on the model. class TensorType(Enum): PARAM = 0 @@ -74,18 +118,20 @@ class TensorType(Enum): # If tensors with "opaque" dtypes (those that are not supported by numpy) are # saved, then a tensorizer data version of 2 is required to (de)serialize the # file. 
Otherwise, the file is compatible with tensorizer data version 1 -TENSORIZER_VERSION = 2 +TENSORIZER_VERSION = 3 +OPAQUE_TENSORIZER_VERSION = 2 NON_OPAQUE_TENSORIZER_VERSION = 1 TENSORIZER_MAGIC = b"|TZR|" OPAQUE_DTYPE_SEP = "\0" +_TIMEOUT: typing.Final[int] = 3600 + class HashType(Enum): CRC32 = 0 SHA256 = 1 - XSALSA20 = 2 @dataclasses.dataclass(order=True) @@ -106,7 +152,7 @@ class TensorEntry: "data_offset", "data_length", "hashes", - "raw_headers", + "header_hashes", ) name: str type: TensorType @@ -116,7 +162,7 @@ class TensorEntry: data_offset: int data_length: int hashes: Optional[List[TensorHash]] - raw_headers: Optional[bytes] + header_hashes: Optional[Dict[HashType, Any]] @dataclasses.dataclass @@ -136,6 +182,13 @@ class _FileHeader: tensor_size: int tensor_count: int + class InvalidVersionError(ValueError): + version: int + + def __init__(self, *args, version: int): + super().__init__(*args) + self.version = version + def to_bytes(self) -> bytes: return self.version_number_format.pack( self.version_number @@ -149,12 +202,14 @@ def from_io( reader.read(cls.version_number_format.size) )[0] if version_number not in accepted_versions: - raise ValueError( + message = ( "Unsupported version: this data stream uses tensorizer" f" data version {version_number}, which is not supported" - " in this release of tensorizer." + " in this release of tensorizer, or" + " for the serialization/deserialization features selected." f"\nSupported data versions: {tuple(accepted_versions)}" ) + raise cls.InvalidVersionError(message, version=version_number) data = reader.read(cls.format.size) if len(data) < cls.format.size: raise ValueError( @@ -206,6 +261,7 @@ class _TensorHeaderSerializer: "I" # CRC32 hash value ) crc32_hash_offset: int + has_crc32: bool sha256_hash_segment: ClassVar[struct.Struct] = struct.Struct( "<" @@ -214,16 +270,10 @@ class _TensorHeaderSerializer: "32s" # SHA256 hash value ) sha256_hash_offset: int + has_sha256: bool - xsalsa20_hash_segment: ClassVar[struct.Struct] = struct.Struct( - "<" - "B" # XSalsa20 hash type - "B" # XSalsa20 length - "32s" # 32-byte salt - "24s" # 24-byte nonce - "B" # Crypto block size in # of shifts - ) - xsalsa20_hash_offset: int + crypt_info: Optional[_crypt_info.CryptInfo] + crypt_info_offset: int data_length_segment: ClassVar[struct.Struct] = struct.Struct( " int: + crc32 = 0 + for view in self._hashable_segment_views(): + with view: + crc32 = zlib.crc32(view, crc32) + return crc32 + + def compute_sha256(self): + sha256 = hashlib.sha256() + for view in self._hashable_segment_views(): + with view: + sha256.update(view) + return sha256 + def add_crc32(self, value: int): + if not self.has_crc32: + raise ValueError( + "Cannot add CRC32 to header defined without a CRC32 field" + ) self.crc32_hash_segment.pack_into( self.buffer, self.crc32_hash_offset, @@ -391,6 +466,10 @@ def add_crc32(self, value: int): ) def add_sha256(self, value: bytes): + if not self.has_sha256: + raise ValueError( + "Cannot add SHA256 to header defined without a SHA256 field" + ) self.sha256_hash_segment.pack_into( self.buffer, self.sha256_hash_offset, @@ -399,52 +478,9 @@ def add_sha256(self, value: bytes): value, # Hash value ) - def add_xsalsa20(self, salt: bytes, nonce: bytes, block_size: int): - shifts = int(math.log2(block_size)) - self.xsalsa20_hash_segment.pack_into( - self.buffer, - self.xsalsa20_hash_offset, - HashType.XSALSA20.value, # Hash type - 57, # XSalsa20 metadata length - salt, # Salt - nonce, # Nonce - shifts, # Crypto block size in # of shifts - ) - - def 
update_data_length(self, value: int):
-        self.data_length_segment.pack_into(
-            self.buffer, self.data_length_offset, value
-        )
-
-
-def _variable_read(
-    data: bytes, offset: int = 0, length_fmt: str = "B", data_fmt: str = "s"
-) -> Tuple[Union[memoryview, Tuple], int]:
-    """
-    Reads a variable-length field preceded by a length from a buffer.
-
-    Returns:
-        A tuple of the data read, and the offset in the buffer
-        following the end of the field.
-    """
-    assert length_fmt in ("B", "H", "I", "Q")
-    if length_fmt == "B":
-        length: int = data[offset]
-        offset += 1
-    else:
-        length_struct = struct.Struct("<" + length_fmt)
-        length: int = length_struct.unpack_from(data, offset)[0]
-        offset += length_struct.size
-    if data_fmt == "s":
-        # When the data is read as bytes, just return a memoryview
-        end = offset + length
-        with memoryview(data) as mv:
-            return mv[offset:end], end
-    else:
-        data_struct = struct.Struct(f"<{length:d}{data_fmt}")
-        data = data_struct.unpack_from(data, offset)
-        offset += data_struct.size
-        return data, offset
+    def update_crypt_info(self):
+        if self.crypt_info is not None:
+            self.crypt_info.sized_pack_into(self.buffer, self.crypt_info_offset)


 @dataclasses.dataclass(init=False)
@@ -456,8 +492,11 @@ class _TensorHeaderDeserializer:
     dtype: str
     shape: Tuple[int, ...]
     hashes: List[TensorHash]
+    crypt_info: Optional[_crypt_info.CryptInfo]
     data_length: int

+    _hashable_segments: Sequence[slice]
+
     header_len_segment: ClassVar[struct.Struct] = struct.Struct("<Q")
     tensor_info_segment: ClassVar[struct.Struct] = struct.Struct(
         "<HB"  # module_idx, tensor_type
     )

     @classmethod
     def from_io(
-        cls, reader: io.BufferedIOBase, zero_hashes: bool = True
+        cls,
+        reader: io.BufferedIOBase,
+        zero_hashes: bool = True,
+        check_crypt_info: bool = False,
     ) -> Optional["_TensorHeaderDeserializer"]:
         # We read the entire header into memory rather than reading
         # it piecewise to avoid the overhead of many small reads,
@@ -486,9 +531,16 @@
         buffer[:offset] = header_len_bytes
         with memoryview(buffer) as mv:
             reader.readinto(mv[offset:])
-        return cls(buffer, zero_hashes=zero_hashes)
+        return cls(
+            buffer, zero_hashes=zero_hashes, check_crypt_info=check_crypt_info
+        )

-    def __init__(self, buffer: bytearray, zero_hashes: bool = True):
+    def __init__(
+        self,
+        buffer: bytearray,
+        zero_hashes: bool = True,
+        check_crypt_info: bool = False,
+    ):
         self.buffer = buffer
         offset = self.header_len_segment.size
         self.module_idx, tensor_type = self.tensor_info_segment.unpack_from(
@@ -517,12 +569,58 @@ def __init__(self, buffer: bytearray, zero_hashes: bool = True):
         if zero_hashes:
             self._zero_hashes(hashes_slice)

+        if check_crypt_info:
+            crypt_info_start = offset
+            crypt_info_slice, offset = self.read_crypt_info_block(
+                buffer, offset
+            )
+            self._hashable_segments = (
+                slice(None, crypt_info_start),
+                slice(offset, None),
+            )
+            with crypt_info_slice:
+                self.crypt_info = _crypt_info.CryptInfo.unpack_from(
+                    crypt_info_slice
+                )
+        else:
+            self.crypt_info = None
+            self._hashable_segments = (slice(None, None),)
+
         # Finally, get the tensor data length.
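        # The data length segment is the last field in the header,
        # so it starts data_length_segment.size bytes before the end.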
offset = len(buffer) - self.data_length_segment.size self.data_length = self.data_length_segment.unpack_from(buffer, offset)[ 0 ] + def _hashable_segment_views(self): + for segment_slice in self._hashable_segments: + yield memoryview(self.buffer)[segment_slice] + + def compute_crc32(self) -> int: + crc32 = 0 + for view in self._hashable_segment_views(): + with view: + crc32 = zlib.crc32(view, crc32) + return crc32 + + def compute_sha256(self): + sha256 = hashlib.sha256() + for view in self._hashable_segment_views(): + with view: + sha256.update(view) + return sha256 + + def compute_hashes(self) -> Dict[HashType, Any]: + hashes = {} + for hash_type in self.hashes: + if hash_type.type in hashes: + continue + elif hash_type.type == HashType.CRC32: + hashes[hash_type.type] = self.compute_crc32() + elif hash_type.type == HashType.SHA256: + hashes[hash_type.type] = self.compute_sha256() + return hashes + @staticmethod def _decode_hashes(b: memoryview) -> List[TensorHash]: """ @@ -646,7 +744,7 @@ def _read_entry(cls, buffer: bytes, offset: int) -> Tuple[TensorEntry, int]: data_length=data_length, # The following fields are only available in the per-tensor headers hashes=None, - raw_headers=None, + header_hashes=None, ), offset, ) @@ -656,6 +754,307 @@ class HashMismatchError(Exception): pass +@dataclasses.dataclass(init=False) +class EncryptionParams: + """ + Defines encryption parameters for a TensorSerializer. + + There are three ways to use this class, mainly using its factory functions: + + #. Using `EncryptionParams.random()` + + This will generate a random encryption key. + This is the fastest and most secure option, but you must + save it somewhere to be able to use it for decryption later. + + #. Using `EncryptionParams.from_passphrase_fast()` + + This will generate a reproducible encryption key from a passphrase string, + using a fast algorithm (one round of SHA256 with salt). + This is a good choice if your passphrase is already strong, + such as when using a long randomly-generated string. + This does not provide the same protection against brute-force attempts + as an intentionally slow password hashing function. + + #. Using `EncryptionParams(key=...)` directly + + You can supply an exact key to use for encryption by directly invoking + the `EncryptionParams` constructor. This must be a `bytes` object of the + correct length to be used as an XSalsa20 cipher key. + This is more complicated and risky to use than the other options. + Do not use this with an insecure key. + + Examples: + + Using `EncryptionParams.from_passphrase_fast()` with + an environment variable:: + + passphrase: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD") + encryption_params = EncryptionParams.from_passphrase_fast( + passphrase + ) + + # Use this to encrypt something: + serializer = TensorSerializer( + "model.tensors", encryption=encryption_params + ) + serializer.write_module(...) + serializer.close() + + # Then decrypt it again + decryption_params = DecryptionParams.from_passphrase(passphrase) + deserializer = TensorDeserializer( + "model.tensors", encryption=decryption_params + ) + deserializer.load_into_module(...) + deserializer.close() + + + Using `EncryptionParams.random()`:: + + encryption_params = EncryptionParams.random() + + # Use this to encrypt something: + serializer = TensorSerializer( + "model.tensors", encryption=encryption_params + ) + serializer.write_module(...) 
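+            # (each tensor's payload is encrypted transparently on write)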
+ serializer.close() + + # Then decrypt it again + key: bytes = encryption_params.key + decryption_params = DecryptionParams.from_key(key) + deserializer = TensorDeserializer( + "model.tensors", encryption=decryption_params + ) + deserializer.load_into_module(...) + deserializer.close() + """ + + # Not yet fully implemented: + # #. Using `EncryptionParams.from_passphrase_slow(passphrase)` + # + # This will generate a reproducible encryption key from a passphrase string, + # using a slow algorithm that is resistant against brute-force attacks. + # This should be used if the source passphrase is weak (or human-written). + # + # This algorithm is intentionally very slow and will slow down decryption + # and deserialization later. Use `EncryptionParams.from_passphrase_fast()` + # instead with a strong passphrase for better performance. + + key: bytes + salt: Optional[bytes] + _method: int + + _METHOD_FROM_PASSPHRASE_FAST: ClassVar[int] = 1 + + @_requires_libsodium + def __init__(self, key: bytes): + if not isinstance(key, (bytes, bytearray, memoryview)): + raise TypeError( + "Encryption key must be binary (bytes type)." + " To derive an encryption key from a string or passphrase," + " use EncryptionParams.from_passphrase_fast()" + ) + + if len(key) != _crypt.ChunkedEncryption.KEY_BYTES: + raise ValueError( + "Invalid encryption key length," + f" should be {_crypt.ChunkedEncryption.KEY_BYTES} bytes;" + f" got {len(key)} bytes instead." + " To generate a valid encryption key from any string" + " or bytes object," + " use EncryptionParams.from_passphrase_fast()" + ) + self.key = bytes(key) + self.salt = None + self._method = 0 + + @classmethod + @_requires_libsodium + def random(cls) -> "EncryptionParams": + return cls(_crypt.random_bytes(_crypt.ChunkedEncryption.KEY_BYTES)) + + @staticmethod + @_requires_libsodium + def _derive_salt( + salt: Union[str, bytes, bytearray, memoryview, None], + encoding, + fallback_size: int = 32, + ) -> bytes: + if salt is None: + return _crypt.random_bytes(fallback_size) + elif isinstance(salt, (bytes, bytearray, memoryview)): + return salt + elif isinstance(salt, str): + return salt.encode(encoding) + else: + raise TypeError("Invalid object type provided for salt") + + @_requires_libsodium + def _crypt_info_chunk(self) -> Optional[_crypt_info.CryptInfoChunk]: + if self._method == self._METHOD_FROM_PASSPHRASE_FAST: + return _crypt_info.FastKeyDerivationChunk(self.salt) + else: + return None + + @classmethod + @_requires_libsodium + def from_passphrase_fast( + cls, + passphrase: Union[str, bytes], + salt: Union[str, bytes, None] = None, + encoding="utf-8", + ) -> "EncryptionParams": + """ + Generates an encryption key from a password and salt. + + Args: + passphrase: The source passphrase from which to derive a key. + salt: A non-secret cryptographic salt to be stored in the model. + If None (the default), a secure random salt is used. + encoding: The encoding to use to convert `passphrase` to bytes + if provided as a ``str``. Defaults to UTF-8. 
+
+        Returns:
+            An `EncryptionParams` instance whose key is derived
+            from the passphrase and salt.
+        """
+        if not passphrase:
+            raise ValueError("Passphrase cannot be empty")
+        salt = cls._derive_salt(salt, encoding)
+        if isinstance(passphrase, str):
+            passphrase = passphrase.encode(encoding)
+        key_hash = hashlib.sha256(passphrase)
+        key_hash.update(salt)
+        params = cls(key=key_hash.digest())
+        params.salt = salt
+        params._method = cls._METHOD_FROM_PASSPHRASE_FAST
+        return params
+
+    # @classmethod
+    # def from_passphrase_slow(
+    #     cls,
+    #     passphrase: Union[str, bytes],
+    #     encoding="utf-8",
+    # ) -> "EncryptionParams":
+    #     if not passphrase:
+    #         raise ValueError("Passphrase cannot be empty")
+    #     salt = _crypt.random_bytes(_crypt.crypto_pwhash_SALTBYTES)
+    #     if isinstance(passphrase, str):
+    #         passphrase = passphrase.encode(encoding)
+    #     key, params = _crypt.pwhash(passphrase, salt)
+    #     return cls(key=key, salt=salt)
+
+
+@dataclasses.dataclass(init=False)
+class DecryptionParams:
+    """
+    Defines decryption parameters for a TensorDeserializer.
+
+    There are two ways to use this class, using its factory functions:
+
+    #. Using `DecryptionParams.from_passphrase()`
+
+    This will decrypt tensors using the specified passphrase string.
+    This may be used if `EncryptionParams.from_passphrase_fast()`
+    was used during encryption.
+
+    #. Using `DecryptionParams.from_key()`
+
+    This will decrypt tensors using an exact binary key.
+    This may always be used with the `key` from an `EncryptionParams` object,
+    regardless of whether the key was generated with
+    `EncryptionParams.from_passphrase_fast()`, `EncryptionParams.random()`, etc.
+
+    Examples:
+
+        Using `DecryptionParams.from_passphrase()` with
+        an environment variable::
+
+            passphrase: str = os.getenv("SUPER_SECRET_STRONG_PASSWORD")
+            encryption_params = EncryptionParams.from_passphrase_fast(
+                passphrase
+            )
+
+            # Use this to encrypt something:
+            serializer = TensorSerializer(
+                "model.tensors", encryption=encryption_params
+            )
+            serializer.write_module(...)
+            serializer.close()
+
+            # Then decrypt it again
+            decryption_params = DecryptionParams.from_passphrase(passphrase)
+            deserializer = TensorDeserializer(
+                "model.tensors", encryption=decryption_params
+            )
+            deserializer.load_into_module(...)
+            deserializer.close()
+
+
+        Using `DecryptionParams.from_key()`::
+
+            encryption_params = EncryptionParams.random()
+
+            # Use this to encrypt something:
+            serializer = TensorSerializer(
+                "model.tensors", encryption=encryption_params
+            )
+            serializer.write_module(...)
+            serializer.close()
+
+            # Then decrypt it again
+            key: bytes = encryption_params.key
+            decryption_params = DecryptionParams.from_key(key)
+            deserializer = TensorDeserializer(
+                "model.tensors", encryption=decryption_params
+            )
+            deserializer.load_into_module(...)
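+            # (each tensor is decrypted transparently as it is loaded)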
+ deserializer.close() + """ + + key: Optional[bytes] + passphrase: Optional[bytes] + + def __init__(self): + self.key = None + self.passphrase = None + + @classmethod + @_requires_libsodium + def from_passphrase( + cls, passphrase: Union[str, bytes], encoding="utf-8" + ) -> "DecryptionParams": + if not passphrase: + raise ValueError("Passphrase cannot be empty") + if isinstance(passphrase, str): + passphrase = passphrase.encode(encoding) + elif not isinstance(passphrase, bytes): + raise TypeError("Invalid passphrase type: must be str or bytes") + params = cls() + params.passphrase = passphrase + return params + + @classmethod + @_requires_libsodium + def from_key(cls, key: bytes) -> "DecryptionParams": + if not key: + raise ValueError("Key cannot be empty") + elif len(key) != _crypt.ChunkedEncryption.KEY_BYTES: + raise ValueError( + "Invalid decryption key length," + f" should be {_crypt.ChunkedEncryption.KEY_BYTES} bytes;" + f" got {len(key)} bytes instead." + " DecryptionParams.from_key() should be used with a key" + " read from EncryptionParams.key." + " To decrypt with a passphrase instead of a binary key," + " use DecryptionParams.from_passphrase() instead." + ) + params = cls() + params.key = key + return params + + class TensorDeserializer( collections.abc.Mapping, contextlib.AbstractContextManager ): @@ -690,8 +1089,8 @@ class TensorDeserializer( verify_hash: If True, the hashes of each tensor will be verified against the hashes stored in the metadata. A `HashMismatchError` will be raised if any of the hashes do not match. - passphrase: The passphrase to use to decrypt the tensors. If None, - the tensors will not be decrypted. + encryption: A `DecryptionParams` object holding a password or key + to use for decryption. ``None`` (the default) means no decryption. Raises: HashMismatchError: If ``verify_hash=True`` and a deserialized tensor @@ -754,7 +1153,7 @@ def __init__( plaid_mode: bool = False, plaid_mode_buffers: Optional[int] = None, verify_hash: bool = False, - passphrase: Optional[str] = None, + encryption: Optional[DecryptionParams] = None, ): # Whether to verify the hashes of the tensors when they are loaded. 
# This value is used when no verify_hash argument is passed to the @@ -765,7 +1164,23 @@ def __init__( # pre-emptively and cancel it if __init__ is successful with self._cleanup: self._verify_hash = verify_hash - self._passphrase = passphrase + if encryption is not None and not isinstance( + encryption, DecryptionParams + ): + raise TypeError( + "encryption parameter: expected DecryptionParams instance" + f" or None, {encryption.__class__.__name__} found" + ) + self._encryption = encryption + self._encrypted = encryption is not None + if self._encrypted: + _require_libsodium() + self._decryption_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=cpu_count, + thread_name_prefix="TensorizerDecryption", + ) + else: + self._decryption_pool = None if isinstance(file_obj, (str, bytes, os.PathLike, int)): self._file = stream_io.open_stream(file_obj, "rb") @@ -806,12 +1221,35 @@ def __init__( raise ValueError("Not a tensorizer file") # Read the file header - self._file_header = _FileHeader.from_io( - self._file, - accepted_versions=( + if self._encrypted: + accepted_versions = (TENSORIZER_VERSION,) + else: + accepted_versions = ( NON_OPAQUE_TENSORIZER_VERSION, + OPAQUE_TENSORIZER_VERSION, TENSORIZER_VERSION, - ), + ) + try: + self._file_header = _FileHeader.from_io( + self._file, accepted_versions=accepted_versions + ) + except _FileHeader.InvalidVersionError as e: + if self._encrypted and e.version in ( + NON_OPAQUE_TENSORIZER_VERSION, + OPAQUE_TENSORIZER_VERSION, + ): + raise CryptographyError( + "Tensor decryption was requested," + " but the file provided comes from a tensorizer version" + " predating encryption, so it must not be encrypted." + " Either set encryption=None on the TensorDeserializer," + " or ensure that the correct file was provided." + ) from e + else: + raise + + self._has_crypt_info: bool = ( + self._file_header.version_number >= TENSORIZER_VERSION ) # The total size of the file. @@ -1107,25 +1545,11 @@ def keys(self): # it as not implemented. return self._metadata.keys() - @staticmethod - def _get_salt_nonce(hashes: List[TensorHash]) -> Tuple[bytes, bytes, int]: - salt = None - nonce = None - block_sz = None - for hash_entry in hashes: - if hash_entry.type == HashType.XSALSA20: - salt = hash_entry.hash[:32] - nonce = hash_entry.hash[32:56] - shifts = struct.unpack(" None: """ @@ -1140,7 +1564,7 @@ def _verify_hashes( hash_type = tensor_hash.type hash_body = tensor_hash.hash if hash_type == HashType.CRC32: - crc = zlib.crc32(mv, zlib.crc32(headers)) + crc = zlib.crc32(mv, header_hashes[hash_type]) hash_crc = struct.unpack(" _crypt_info.CryptInfoChunk: + encryption_method = crypt_info.find_chunks( + ( + _crypt_info.XSalsa20ParallelChunk, + _crypt_info.XSalsa20SequentialChunk, + ) + ) + if not encryption_method: + raise CryptographyError("No known encryption method found in file") + elif len(encryption_method) > 1: + raise CryptographyError( + "Could not interpret encryption method of the file" + ) + return encryption_method[0] + + def _derive_encryption_key( + self, crypt_info: _crypt_info.CryptInfo + ) -> bytes: + if self._encryption.key is not None: + return self._encryption.key + # Requires key derivation from a passphrase + if self._encryption.passphrase is None: + raise ValueError("Invalid DecryptionParams") + # Check for a KeyDerivationChunk + kd = crypt_info.find_chunks(_crypt_info.KeyDerivationChunk) + if not kd: + raise CryptographyError( + "Passphrase was provided, but the tensor was" + " not originally encrypted using a passphrase" + " (i.e. 
EncryptionParams.from_passphrase_*)" + ) + elif len(kd) > 1: + raise CryptographyError( + "Could not interpret encryption key derivation" + " method of the file" + ) + else: + method = kd[0] + if isinstance( + method, + _crypt_info.FastKeyDerivationChunk, + ): + return EncryptionParams.from_passphrase_fast( + passphrase=self._encryption.passphrase, + salt=method.salt, + ).key + else: + # Only FastKeyDerivationChunk is currently + # recognized + raise CryptographyError("Unknown key derivation method") + + def _get_decryption_manager( + self, encryption_method: _crypt_info.CryptInfoChunk, key: bytes, buffer + ) -> Union["_crypt.ChunkedEncryption", "_crypt.SequentialEncryption"]: + if isinstance(encryption_method, _crypt_info.XSalsa20ParallelChunk): + if encryption_method.num_macs == 1: + return _crypt.SequentialEncryption( + key=key, + buffer=buffer, + nonce=encryption_method.nonce, + mac=encryption_method.macs[0], + intent=_crypt.SequentialEncryption.INTENT.DECRYPTION, + ) + else: + nonces = _crypt.ChunkedEncryption.sequential_nonces( + initial_nonce=encryption_method.nonce, + count=encryption_method.num_macs, + ) + return _crypt.ChunkedEncryption( + key=key, + buffer=buffer, + chunk_size=encryption_method.chunk_size, + nonces=nonces, + macs=encryption_method.macs, + executor=self._decryption_pool, + intent=_crypt.ChunkedEncryption.INTENT.DECRYPTION, + ) + elif isinstance(encryption_method, _crypt_info.XSalsa20SequentialChunk): + return _crypt.SequentialEncryption( + key=key, + buffer=buffer, + nonce=encryption_method.nonce, + mac=encryption_method.mac, + intent=_crypt.SequentialEncryption.INTENT.DECRYPTION, + ) + else: + raise CryptographyError("Unknown encryption method") + + def _stream_decrypt( + self, encryption_method: _crypt_info.CryptInfoChunk, key: bytes, buffer + ): + try: + with self._get_decryption_manager( + encryption_method, key, buffer + ) as crypto: + if isinstance(crypto, _crypt.ChunkedEncryption): + fs = [] + for chunk in range(crypto.num_chunks): + with crypto.chunk_view(chunk) as view: + self._file.readinto(view) + fs.append(crypto.decrypt_chunk(chunk)) + crypto.wait_or_raise(fs, timeout=_TIMEOUT) + else: + self._file.readinto(buffer) + crypto.decrypt() + except _crypt.CryptographyError as e: + raise CryptographyError("Tensor decryption failed") from e + finally: + del crypto + + @staticmethod + @contextlib.contextmanager + def _release_on_exc(mv: memoryview): + try: + yield mv + except GeneratorExit: + del mv + raise + except BaseException: + mv.release() + del mv + raise + def _read_numpytensors( self, filter_func: Optional[Callable[[str], Union[bool, Any]]] = None, @@ -1208,7 +1755,9 @@ def _read_numpytensors( tensors_read = 0 while num_tensors == -1 or tensors_read < num_tensors: header = _TensorHeaderDeserializer.from_io( - self._file, zero_hashes=True + self._file, + zero_hashes=True, + check_crypt_info=self._has_crypt_info, ) if header is None: @@ -1239,34 +1788,30 @@ def _read_numpytensors( self._metadata[header.name].hashes = header.hashes - # Encryption handling - salt, nonce, block_size = self._get_salt_nonce(header.hashes) - if self._passphrase is None and salt is not None: - # We are an encrypted tensor and we were not provided - # a passphrase to decrypt with - raise ValueError( - "Tensor is encrypted, but no passphrase was provided" + header_hashes = header.compute_hashes() + self._metadata[header.name].header_hashes = header_hashes + + is_encrypted: bool = ( + header.crypt_info is not None + and header.crypt_info.num_chunks != 0 + ) + if self._encrypted 
and not is_encrypted: + raise CryptographyError( + "Tensor is not encrypted, but decryption was requested" ) - elif self._passphrase is not None and salt is None: - raise ValueError( - "Tensor is not encrypted, but a passphrase was provided" + elif is_encrypted and not self._encrypted: + raise CryptographyError( + "Tensor is encrypted, but decryption was not requested" ) - elif self._passphrase is not None: - if isinstance(self._passphrase, str): - passphrase = self._passphrase.encode("utf-8") - else: - passphrase = self._passphrase - key = hashlib.sha256(passphrase + salt).digest() - - wrapped_io = DecryptedStream( - self._file, key, nonce, block_size + elif self._encrypted or is_encrypted: + assert self._encrypted and is_encrypted + encryption_method = self._get_encryption_method( + header.crypt_info ) + key = self._derive_encryption_key(header.crypt_info) else: - wrapped_io = self._file - - # Store our raw headers with hashes zeroed out - # for model verification - self._metadata[header.name].raw_headers = header.buffer + key = None + encryption_method = None # We use memoryview to avoid copying the data. mv: memoryview @@ -1276,7 +1821,6 @@ def _read_numpytensors( # all the tensors. We just need to slice out the # memoryview for the current tensor. mv = self._buffers[header.name] - wrapped_io.readinto(mv) self._allocated += header.data_length elif self._plaid_mode: # In plaid_mode, we don't allocate a buffer, we just @@ -1286,7 +1830,6 @@ def _read_numpytensors( # the buffer contents after we yield the tensor, which # is loaded straight into the GPU memory. mv = self._buffers[header.name] - wrapped_io.readinto(mv) else: # In lazy_load mode, we allocate a new buffer for each # tensor. This is a bit slower, but it's the only way @@ -1295,17 +1838,20 @@ def _read_numpytensors( if len(buffer) != header.data_length: raise RuntimeError("Header data length mismatch") mv = memoryview(buffer) - wrapped_io.readinto(mv) + + if not self._encrypted or mv.nbytes == 0: + self._file.readinto(mv) + elif self._encrypted and mv.nbytes > 0: + with self._release_on_exc(mv): + self._stream_decrypt(encryption_method, key, mv) if verify_hash: - try: + with self._release_on_exc(mv): + # Releasing on an exception is necessary to prevent + # a BufferError on close() self._verify_hashes( - header.name, header.hashes, header.buffer, mv + header.name, header.hashes, header_hashes, mv ) - except HashMismatchError: - # Necessary to prevent a BufferError on close() - mv.release() - raise if raw: tensor = mv @@ -1578,6 +2124,8 @@ def _optimize_plaid_mode_buffers( self._buffers = unoptimized_buffers class _AtomicCountdown: + __slots__ = ("_count", "_condition", "_cancelled", "_initial") + def __init__(self, count: int, initial: Optional[int] = None): if count <= 0: raise ValueError("Invalid count.") @@ -1610,10 +2158,12 @@ def trigger(self) -> None: if self._count == 0: self._condition.notify_all() - def reset(self) -> None: + def reset(self, count: Optional[int] = None) -> None: """Resets the internal counter.""" + if count is not None and count <= 0: + raise ValueError("Invalid count.") with self._condition: - self._count = self._initial + self._count = self._initial if count is None else count self._cancelled = False def cancel(self) -> None: @@ -1678,17 +2228,17 @@ def _bulk_load( transfer_out_queue = queue.SimpleQueue() sentinel = object() - num_hash_tasks = 2 + max_num_hash_tasks = 2 tasks_per_tensor = 1 if verify_hash: - tasks_per_tensor += num_hash_tasks + tasks_per_tensor += max_num_hash_tasks atomic_countdown: 
typing.Type = TensorDeserializer._AtomicCountdown if self._plaid_mode: - countdowns = [ - atomic_countdown(tasks_per_tensor, 0) + countdowns = tuple( + atomic_countdown(tasks_per_tensor, initial=0) for _ in range(self._plaid_mode_buffer_count) - ] + ) countdown_cycle = itertools.cycle(countdowns) else: countdown_cycle = itertools.cycle((None,)) @@ -1705,7 +2255,7 @@ def cancel_thread(thread: threading.Thread): ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, exc) == 1 ) - def receive_and_check(timeout: int = 3600) -> torch.nn.Parameter: + def receive_and_check(timeout: int) -> torch.nn.Parameter: outcome = transfer_out_queue.get(timeout=timeout) if outcome is sentinel: raise RuntimeError("Loading failed") @@ -1723,7 +2273,7 @@ def transfer() -> None: ): while True: next_tensor, countdown = transfer_in_queue.get( - timeout=3600 + timeout=_TIMEOUT ) next_tensor: torch.Tensor countdown: Optional[atomic_countdown] @@ -1732,7 +2282,7 @@ def transfer() -> None: try: transfer_out_queue.put( self._to_torch_parameter(next_tensor), - timeout=3600, + timeout=_TIMEOUT, ) finally: if countdown is not None: @@ -1777,7 +2327,7 @@ def ready_buffers( if verify_hash: computation_threads = concurrent.futures.ThreadPoolExecutor( - max_workers=min(len(keys) * 2, os.cpu_count()), + max_workers=min(len(keys) * max_num_hash_tasks, cpu_count), thread_name_prefix="TensorizerComputation", ) else: @@ -1792,7 +2342,7 @@ def check_hash( with memoryview(data).cast("B") as mv: try: self._verify_hashes( - metadata.name, hashes, metadata.raw_headers, mv + metadata.name, hashes, metadata.header_hashes, mv ) finally: if countdown is not None: @@ -1828,7 +2378,9 @@ def read_into_buffers() -> None: break metadata: TensorEntry = self._metadata[name] self._file.seek(metadata.offset) - if countdown is not None and not countdown.wait(): + if countdown is not None and not countdown.wait( + _TIMEOUT + ): break for *_, tensor in self.read_tensors( num_tensors=1, verify_hash=False @@ -1837,9 +2389,12 @@ def read_into_buffers() -> None: if stop: break tensor: torch.Tensor + hashing_required = computation_threads is not None if countdown is not None: - countdown.reset() - if computation_threads is not None: + countdown.reset( + 1 + len(metadata.hashes) * hashing_required + ) + if hashing_required: check_hashes(name, tensor, countdown) transfer_in_queue.put_nowait((tensor, countdown)) except (Cancelled, Exception) as e: @@ -1860,22 +2415,22 @@ def read_into_buffers() -> None: try: for key in keys[:-1]: - self._cache[key] = receive_and_check() + self._cache[key] = receive_and_check(_TIMEOUT) yield self._cache[key] # Stop before yielding the final tensor # to catch up on hash verification - read_thread.join(timeout=3600) + read_thread.join(timeout=_TIMEOUT) if computation_threads is not None: # At this point, all checks have been added to `checks` computation_threads.shutdown(wait=True) # At this point, all checks have finished for check in checks: # This will raise if any of the checks failed - check.result(timeout=3600) + check.result(timeout=_TIMEOUT) checks.clear() - self._cache[keys[-1]] = receive_and_check() + self._cache[keys[-1]] = receive_and_check(_TIMEOUT) transfer_in_queue.put_nowait((sentinel, None)) - transfer_thread.join(timeout=3600) + transfer_thread.join(timeout=_TIMEOUT) yield self._cache[keys[-1]] except Exception: stop = True @@ -1888,13 +2443,13 @@ def read_into_buffers() -> None: transfer_thread.join(timeout=4) if transfer_thread.is_alive(): cancel_thread(transfer_thread) - transfer_thread.join(timeout=3600) + 
transfer_thread.join(timeout=_TIMEOUT) if read_thread.is_alive(): # A graceful exit is again preferred, but not necessary read_thread.join(timeout=2) if read_thread.is_alive(): cancel_thread(read_thread) - read_thread.join(timeout=3600) + read_thread.join(timeout=_TIMEOUT) for check in checks: check.cancel() raise @@ -2032,7 +2587,7 @@ def verify_module( self._verify_hashes( name, entry.hashes, - entry.raw_headers, + entry.header_hashes, mv, ) results.append((name, True)) @@ -2091,11 +2646,10 @@ class TensorSerializer: Args: file_obj: A file-like object or path to a file to write to. The path can be a S3 URI. - passphrase: A passphrase to use for encryption. If None, no encryption - will be used. - salt: A salt to use for encryption. If None, a random salt will be - generated. - compress_tensors: If True, compress the tensors using lz4. This + encryption: An `EncryptionParams` object holding a password or key + to use for encryption. If None, no encryption will be used. + compress_tensors: Not implemented. Specifying this option does nothing. + Previously, if True, compress the tensors using lz4. This exists as an internal curiosity as it doesn't seem to make much of a difference in practice. @@ -2139,9 +2693,8 @@ def __init__( int, ], compress_tensors: bool = False, - passphrase: Optional[Union[str, bytes]] = None, - salt: Optional[Union[str, bytes]] = None, - crypt_chunk_size: int = 2048, + *, + encryption: Optional[EncryptionParams] = None, ) -> None: if isinstance(file_obj, (str, bytes, os.PathLike, int)): self._file = stream_io.open_stream(file_obj, "wb+") @@ -2149,23 +2702,23 @@ def __init__( self._mode_check(file_obj) self._file = file_obj - self._cleartext_chunk_size = crypt_chunk_size - self._crypt_chunk_size = ( - crypt_chunk_size + nacl.secret.SecretBox.MACBYTES - ) - if passphrase is not None: - if salt is None: - salt = os.urandom(32) - elif isinstance(salt, str): - salt = salt.encode("utf-8") - self._salt = salt - # Convert our passphrase to bytes - if isinstance(passphrase, str): - passphrase = passphrase.encode("utf-8") - self._crypto_key = hashlib.sha256(passphrase + salt).digest() - self._lockbox = nacl.secret.SecretBox(self._crypto_key) + if encryption is not None and not isinstance( + encryption, EncryptionParams + ): + raise TypeError( + "encryption parameter: expected EncryptionParams instance" + f" or None, {encryption.__class__.__name__} found" + ) + self._encryption = encryption + self._encrypted = encryption is not None + self._used_nonces: Optional[Set[bytes]] + if self._encrypted: + _require_libsodium() + self._crypt_chunk_size = 2 << 20 + self._used_nonces = set() else: - self._lockbox = None + self._crypt_chunk_size = None + self._used_nonces = None # Get information about the file object's capabilities _fd_getter = getattr(self._file, "fileno", None) @@ -2201,7 +2754,7 @@ def __init__( # multithreading in spite of the GIL because CPython's hash function # implementations release the GIL during longer hash computations. 
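        # (CPython's hashlib documents releasing the GIL when hashing
        # data larger than 2047 bytes.)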
self._computation_pool = concurrent.futures.ThreadPoolExecutor( - max_workers=os.cpu_count(), + max_workers=cpu_count, thread_name_prefix="TensorizerComputation", ) @@ -2229,6 +2782,19 @@ def __init__( thread_name_prefix="TensorizerHeaderWriter", ) + if self._encrypted: + self._encryption_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max_concurrent_writers, + thread_name_prefix="TensorizerEncryption", + ) + + self._decryption_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max_concurrent_writers, + thread_name_prefix="TensorizerDecryption", + ) + else: + self._encryption_pool = self._decryption_pool = None + # Implementation detail for CPython: ThreadPoolExecutor objects # use an instance of queue.SimpleQueue as a FIFO work queue, # so the order that tasks are started (but not necessarily finished) @@ -2253,15 +2819,23 @@ def __init__( self._file.write(TENSORIZER_MAGIC) # Write file header metadata + if not self._encrypted: + # Can't tell if OPAQUE_TENSORIZER_VERSION is needed + # until a tensor is written later with an opaque dtype, + # so assume it is non-opaque until then. + version_number = NON_OPAQUE_TENSORIZER_VERSION + else: + # File encryption requires a newer tensorizer version + version_number = TENSORIZER_VERSION self._file_header_loc = self._file.tell() self._file_header = _FileHeader( - version_number=NON_OPAQUE_TENSORIZER_VERSION, + version_number=version_number, tensor_size=0, tensor_count=0, ) self._file.write(self._file_header.to_bytes()) - # Reserve 256kb for metadata. + # Reserve 256 KiB for metadata. metadata_size = 256 * 1024 self._file.write(struct.pack(" None: @@ -2415,6 +2989,28 @@ def close(self) -> None: # logger.info(f"Comp'd bytes: {self.total_compressed_tensor_bytes}") # logger.info(f"Ratio: {compression_ratio:.2f}") + def _new_nonces(self, count: int) -> Tuple[bytes, ...]: + if count < 0: + raise ValueError("Invalid nonce count") + elif count == 0: + return () + elif self._used_nonces is None: + raise RuntimeError( + "Tried to create cryptographic nonces while" + " encryption is disabled" + ) + nonces = tuple( + _crypt.ChunkedEncryption.sequential_nonces( + initial_nonce=_crypt.ChunkedEncryption.random_nonce(), + count=count, + ) + ) + + if self._used_nonces.intersection(nonces): + raise RuntimeError("Illegal nonce reuse") + self._used_nonces.update(nonces) + return nonces + def write_tensor( self, idx, @@ -2468,6 +3064,7 @@ def _write_tensor( *, _synchronize: bool = True, _start_pos: Optional[int] = None, + _temporary_buffer: bool = False, ) -> int: """ Underlying implementation for `write_tensor()`, @@ -2488,16 +3085,31 @@ def _write_tensor( writes starting at the current file offset. """ if isinstance(tensor, torch.Tensor): - numpy_tensor = _NumpyTensor.from_tensor(tensor) + if not tensor.is_contiguous(): + _temporary_buffer = True + numpy_tensor = _NumpyTensor.from_tensor(tensor.contiguous()) else: - numpy_tensor = _NumpyTensor.from_array(tensor) + if ( + isinstance(tensor, numpy.ndarray) + and not tensor.flags.c_contiguous + and hasattr(numpy, "ascontiguousarray") + ): + numpy_tensor = _NumpyTensor.from_array( + numpy.ascontiguousarray(tensor) + ) + _temporary_buffer = True + else: + numpy_tensor = _NumpyTensor.from_array(tensor) dtype_name = numpy_tensor.numpy_dtype if numpy_tensor.is_opaque: # The datatype name needs to contain both the numpy dtype that the # data is serialized as and the original torch dtype. 
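[Note] `_new_nonces` above pairs one random initial nonce with sequentially derived successors, then asserts that no nonce is ever issued twice for the file. The helper below is a hypothetical stand-in for `_crypt.ChunkedEncryption.sequential_nonces`, whose real implementation is not part of this hunk; it assumes a 24-byte XSalsa20 nonce treated as a big-endian counter:

```python
import os
from typing import Iterator, Set, Tuple

NONCE_BYTES = 24  # XSalsa20 nonce size


def random_nonce() -> bytes:
    return os.urandom(NONCE_BYTES)


def sequential_nonces(initial_nonce: bytes, count: int) -> Iterator[bytes]:
    # Derive one nonce per chunk by treating the random initial nonce
    # as a big-endian counter and incrementing it, wrapping on overflow.
    n = int.from_bytes(initial_nonce, "big")
    mask = (1 << (NONCE_BYTES * 8)) - 1
    for i in range(count):
        yield ((n + i) & mask).to_bytes(NONCE_BYTES, "big")


used_nonces: Set[bytes] = set()


def new_nonces(count: int) -> Tuple[bytes, ...]:
    nonces = tuple(sequential_nonces(random_nonce(), count))
    # A (key, nonce) pair must never encrypt two different plaintexts,
    # so every nonce issued for this file is remembered and checked.
    if used_nonces.intersection(nonces):
        raise RuntimeError("Illegal nonce reuse")
    used_nonces.update(nonces)
    return nonces


chunk_nonces = new_nonces(4)  # e.g. one nonce per 2 MiB chunk
assert len(set(chunk_nonces)) == 4
```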
dtype_name += OPAQUE_DTYPE_SEP + numpy_tensor.torch_dtype - self._file_header.version_number = TENSORIZER_VERSION + self._file_header.version_number = max( + OPAQUE_TENSORIZER_VERSION, + self._file_header.version_number, + ) tensor = numpy_tensor.data tensor_memory = numpy_tensor.data.data @@ -2508,6 +3120,41 @@ def _write_tensor( raise ValueError("dtype name length should be less than 256") shape = tensor.shape header_pos = self._file.tell() if _start_pos is None else _start_pos + + if self._encrypted: + chunks = _Chunked( + total_size=tensor_memory.nbytes, + chunk_size=self._crypt_chunk_size, + ) + nonces = self._new_nonces(chunks.count) + encryptor = _crypt.ChunkedEncryption( + key=self._encryption.key, + buffer=tensor_memory, + chunk_size=self._crypt_chunk_size, + nonces=nonces, + executor=self._computation_pool, + ) + + key_derivation_chunk = self._encryption._crypt_info_chunk() + encryption_algorithm_chunk = _crypt_info.XSalsa20ParallelChunk( + chunk_size=self._crypt_chunk_size, + nonce=nonces[0], + macs=encryptor.macs, + ) + if key_derivation_chunk is not None: + chunks = (key_derivation_chunk, encryption_algorithm_chunk) + else: + chunks = (encryption_algorithm_chunk,) + crypt_info = _crypt_info.CryptInfo(chunks) + else: + encryptor = None + if self._file_header.version_number == TENSORIZER_VERSION: + crypt_info = _crypt_info.CryptInfo() + else: + crypt_info = None + + include_crc32: bool = not self._encrypted + header = _TensorHeaderSerializer( idx, tensor_type, @@ -2516,7 +3163,9 @@ def _write_tensor( shape, tensor_size, header_pos, - self._lockbox is not None, + include_crc32=include_crc32, + include_sha256=True, + crypt_info=crypt_info, ) tensor_pos = header_pos + header.data_offset @@ -2542,113 +3191,114 @@ def write_metadata(): # These two tasks are CPU-bound and don't block the GIL, # so they go into the computation thread pool. def compute_crc32(): - crc32 = zlib.crc32(header.buffer) + crc32 = header.compute_crc32() return zlib.crc32(tensor_memory, crc32) def compute_sha256(): - sha256 = hashlib.sha256(header.buffer) + sha256 = header.compute_sha256() sha256.update(tensor_memory) return sha256.digest() - def encrypt_tensor() -> Tuple[Optional[bytes], Optional[bytes]]: - start = time.monotonic() - if self._lockbox is None: - return None, None - nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE) - nonce_int = int.from_bytes( - nonce, - "big", - signed=False, - ) - - num_chunks = math.ceil( - tensor_memory.nbytes / self._cleartext_chunk_size - ) - cryptotext_size = num_chunks * self._crypt_chunk_size - cryptotext = bytearray(cryptotext_size) - cryptotext_end = 0 - tensor_bytes = memoryview(tensor_memory.tobytes()) - setup_end = time.monotonic() - - chunk_repr = "" - - for i in range(num_chunks): - # We XOR the nonce with the chunk index to avoid nonce reuse. 
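[Note] Before encryption, the tensor buffer is divided into fixed-size chunks (`_crypt_chunk_size` is 2 MiB here) so that each chunk can be sealed in parallel under its own nonce. A rough sketch of the layout arithmetic, with `Chunked` as a hypothetical stand-in for the internal `_Chunked` helper that is not shown in this diff:

```python
from dataclasses import dataclass


@dataclass
class Chunked:
    """Hypothetical stand-in for tensorizer's internal _Chunked helper."""

    total_size: int
    chunk_size: int

    @property
    def count(self) -> int:
        # ceil(total_size / chunk_size) without floating point
        return -(-self.total_size // self.chunk_size)

    def bounds(self, i: int):
        start = i * self.chunk_size
        return start, min(start + self.chunk_size, self.total_size)


buf = memoryview(bytearray(5 * (1 << 20)))  # 5 MiB of tensor data
chunks = Chunked(total_size=buf.nbytes, chunk_size=2 << 20)
assert chunks.count == 3  # 2 MiB + 2 MiB + 1 MiB
views = [buf[slice(*chunks.bounds(i))] for i in range(chunks.count)]
# Each view can now be sealed independently, under its own nonce,
# by a separate task in the computation pool.
```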
- step_nonce = nonce_int ^ i - step_nonce_bytes = step_nonce.to_bytes( - nacl.secret.SecretBox.NONCE_SIZE, "big", signed=False - ) - plaintext_begin = i * self._cleartext_chunk_size - plaintext_end = plaintext_begin + self._cleartext_chunk_size - if plaintext_end > tensor_memory.nbytes: - plaintext_end = tensor_memory.nbytes - to_encrypt = tensor_bytes[plaintext_begin:plaintext_end] - chunk = self._lockbox.encrypt( - to_encrypt, - step_nonce_bytes, - ).ciphertext - cryptotext_begin = i * self._crypt_chunk_size - cryptotext_end = cryptotext_begin + len(chunk) - cryptotext[cryptotext_begin:cryptotext_end] = chunk - encryption_header_size = len(chunk) - len(to_encrypt) - end = time.monotonic() - duration_setup_ms = (setup_end - start) * 1000 - duration_ms = (end - start) * 1000 - print( - f"Pos: {tensor_pos} - Size: {tensor_size} - Encryption time:" - f" {duration_ms:.2f}ms, setup: {duration_setup_ms:.2f}ms," - f" {cryptotext_end} bytes, {encryption_header_size} header" - f" size, {chunk_repr}" - ) - - return nonce, cryptotext[:cryptotext_end] - # This task is I/O-bound and dependent on the previous two tasks, # so it goes into the header writer pool. def commit_header( - crc32_future: concurrent.futures.Future, - sha256_future: concurrent.futures.Future, - encrypt_future: concurrent.futures.Future, + crc32_future: Optional[concurrent.futures.Future], + sha256_future: Optional[concurrent.futures.Future], + encrypt_future: Optional[concurrent.futures.Future], ): - crc32 = crc32_future.result(3600) - sha256 = sha256_future.result(3600) - nonce, encrypted = encrypt_future.result(3600) - header.add_crc32(crc32) - header.add_sha256(sha256) - if encrypted is not None: - header.add_xsalsa20(self._salt, nonce, self._crypt_chunk_size) - header.update_data_length(len(encrypted)) + crc32 = sha256 = None + if crc32_future is not None: + crc32 = crc32_future.result(_TIMEOUT) + if sha256_future is not None: + sha256 = sha256_future.result(_TIMEOUT) + if encrypt_future is not None: + encrypt_future.result(_TIMEOUT) + # These must be written only after all other futures complete + # to prevent a race condition from other threads hashing + # a partially-filled-in hash section + if crc32_future is not None: + header.add_crc32(crc32) + if sha256_future is not None: + header.add_sha256(sha256) + if encrypt_future is not None: + header.update_crypt_info() self._pwrite(header.buffer, header_pos) - crc32_task = self._computation_pool.submit(compute_crc32) + hash_tasks = [] + if include_crc32: + crc32_task = self._computation_pool.submit(compute_crc32) + hash_tasks.append(crc32_task) + else: + crc32_task = None sha256_task = self._computation_pool.submit(compute_sha256) - encrypt_task = self._computation_pool.submit(encrypt_tensor) + hash_tasks.append(sha256_task) + self._jobs.extend(hash_tasks) + + def encrypt(prerequisites: Iterable[concurrent.futures.Future]): + fs = concurrent.futures.wait(prerequisites, timeout=_TIMEOUT) + for f in fs.done: + # Raise exceptions + f.result() + for f in fs.not_done: + # Raise timeouts + f.result(0) + try: + encryptor.encrypt_all( + wait=True, + timeout=_TIMEOUT, + ) + except _crypt.CryptographyError as e: + raise CryptographyError("Tensor encryption failed") from e + + # This task is I/O-bound, so it goes into the regular writer pool. 
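[Note] `commit_header` and `encrypt` above follow the same dependency pattern: block on a set of prerequisite futures, surface the first failure or timeout, and only then run the dependent step, because encryption scrambles the very buffer the hash tasks read. A generic sketch of that pattern, with illustrative names:

```python
import concurrent.futures
from typing import Callable, Iterable

_TIMEOUT = 3600


def after_all(
    prerequisites: Iterable[concurrent.futures.Future],
    action: Callable[[], object],
):
    # Block until every prerequisite settles, re-raising the first
    # failure (or a TimeoutError) before running the dependent step.
    fs = concurrent.futures.wait(prerequisites, timeout=_TIMEOUT)
    for f in fs.done:
        f.result()   # propagate exceptions from finished tasks
    for f in fs.not_done:
        f.result(0)  # convert stragglers into TimeoutError
    return action()


with concurrent.futures.ThreadPoolExecutor() as pool:
    # e.g. hash the plaintext first; only encrypt once both digests
    # exist, since encryption rewrites the same buffer in place
    hashes = [pool.submit(sum, range(1000)), pool.submit(sum, range(2000))]
    encrypted = pool.submit(after_all, hashes, lambda: "ciphertext")
    print(encrypted.result(_TIMEOUT))
```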
+ def write_tensor_data( + prerequisite: Optional[concurrent.futures.Future], + ): + if prerequisite is not None: + prerequisite.result(_TIMEOUT) + bytes_written = self._pwrite(tensor_memory, tensor_pos) + with self._tensor_count_update_lock: + self._file_header.tensor_count += 1 + self._file_header.tensor_size += bytes_written + + def decrypt(prerequisite: concurrent.futures.Future): + try: + prerequisite.result(_TIMEOUT) + finally: + # Try to decrypt again even if writing to disk failed + # to avoid exiting with the tensor memory in a modified state + fs = encryptor.decrypt_all(wait=False) + try: + _crypt.ChunkedEncryption.wait_or_raise( + fs, + timeout=_TIMEOUT, + return_when=concurrent.futures.ALL_COMPLETED, + ) + except _crypt.CryptographyError as e: + raise CryptographyError( + "Restoring encrypted tensor data in memory failed" + ) from e + + # Encrypt the tensor memory in-place before writing + if self._encrypted: + encrypt_task = self._encryption_pool.submit(encrypt, hash_tasks) + self._jobs.append(encrypt_task) + else: + encrypt_task = None commit_header_task = self._header_writer_pool.submit( commit_header, crc32_task, sha256_task, encrypt_task ) - self._jobs.extend( - (encrypt_task, crc32_task, sha256_task, commit_header_task) - ) + self._jobs.append(commit_header_task) - # This task is I/O-bound and has no prerequisites, - # so it goes into the regular writer pool. - def write_tensor_data(): - _, encrypted = encrypt_task.result(3600) - if encrypted is not None: - self._pwrite(encrypted, tensor_pos) - else: - self._pwrite(tensor_memory, tensor_pos) - with self._tensor_count_update_lock: - self._file_header.tensor_count += 1 - self._file_header.tensor_size += tensor_memory.nbytes + # Write the potentially-encrypted tensor memory to the file + write_task = self._writer_pool.submit(write_tensor_data, encrypt_task) + self._jobs.append(write_task) + # Decrypt the memory after writing is finished, if it was encrypted + if self._encrypted and not _temporary_buffer: + self._jobs.append(self._decryption_pool.submit(decrypt, write_task)) - self._jobs.append(self._writer_pool.submit(write_tensor_data)) - tensor_encrypted_payload = encrypt_task.result(3600)[1] - if tensor_encrypted_payload is not None: - tensor_endpos = tensor_pos + len(tensor_encrypted_payload) - else: - tensor_endpos = tensor_pos + tensor_size + tensor_endpos = tensor_pos + tensor_size # Update our prologue. 
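[Note] The write path now encrypts the tensor's memory in place, writes the ciphertext, and then decrypts in place so the caller's buffer comes back unmodified; the restore step is skipped for temporary buffers nobody will read again. A toy illustration of that buffer lifecycle, using a self-inverse XOR keystream purely as a stand-in for the real XSalsa20-Poly1305 encrypt/decrypt steps:

```python
import hashlib


def xor_keystream(buf: bytearray, key: bytes) -> None:
    # Self-inverse toy transform: applying it twice restores the buffer.
    # Stand-in only; the real code uses chunked XSalsa20-Poly1305.
    stream = hashlib.shake_256(key).digest(len(buf))
    for i, b in enumerate(stream):
        buf[i] ^= b


key = b"\x01" * 32
tensor_memory = bytearray(b"tensor weights live here")
plaintext = bytes(tensor_memory)

xor_keystream(tensor_memory, key)           # 1. encrypt in place
with open("/tmp/encrypted-tensor.bin", "wb") as f:
    f.write(tensor_memory)                  # 2. write ciphertext to disk
xor_keystream(tensor_memory, key)           # 3. restore the plaintext so the
                                            #    caller's tensor is unchanged
assert bytes(tensor_memory) == plaintext    # (step 3 is skipped for
                                            #  temporary buffers)
```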
if _synchronize: @@ -2713,7 +3363,7 @@ def _transfer(): for t in tensors: if transfer_finished: break - transferred.put(t.cpu().detach(), timeout=3600) + transferred.put(t.cpu().detach(), timeout=_TIMEOUT) else: # Sentinel transferred.put(None) @@ -2736,7 +3386,7 @@ def _interrupt_transfer(): pass return ( - iter(lambda: transferred.get(timeout=3600), None), + iter(lambda: transferred.get(timeout=_TIMEOUT), None), _interrupt_transfer, ) @@ -2754,7 +3404,9 @@ def _bulk_write(self, tensors: Iterable[_WriteSpec]): fallocate = getattr(os, "posix_fallocate", None) if fallocate and self._fd: size = sum(len(t.name) for t in tensors) - size += sum(t.tensor.untyped_storage().size() for t in tensors) + size += sum( + t.tensor.element_size() * t.tensor.nelement() for t in tensors + ) # Rough underestimate of header size header_min_size = 24 size += header_min_size * len(tensors) @@ -2780,6 +3432,9 @@ def _bulk_write(self, tensors: Iterable[_WriteSpec]): idx, name, tensor_type, tensor, callback = tensors.popleft() if tensor.device.type == "cuda": tensor = next(transferred) + temp_tensor = True + else: + temp_tensor = False next_pos = self._write_tensor( self._idx, name, @@ -2787,6 +3442,7 @@ def _bulk_write(self, tensors: Iterable[_WriteSpec]): tensor, _synchronize=False, _start_pos=next_pos, + _temporary_buffer=temp_tensor, ) if callback is not None: callback() diff --git a/tensorizer/stream_io.py b/tensorizer/stream_io.py index 80240df0..dbac71a8 100644 --- a/tensorizer/stream_io.py +++ b/tensorizer/stream_io.py @@ -18,7 +18,6 @@ import boto3 import botocore -import nacl.secret import redis import tensorizer._version as _version @@ -187,104 +186,6 @@ def __hash__(self): return hash(self._curl_flags) -class DecryptedStream(io.RawIOBase): - """ - This class is a file-like object that wraps a mixed stream of encrypted and - decrypted data. It is intended to be called when it is known that the next - read is going to be a decryption operation. - """ - - def __init__( - self, - stream: io.RawIOBase, - key: bytes, - nonce: bytes, - chunk_size: int = 1024 << 8, - ): - self._stream = stream - self._lockbox = nacl.secret.SecretBox(key) - self._nonce_int = int.from_bytes(nonce, "big", signed=False) - self._chunk_size = chunk_size - self._ciphertext_chunk_sz = chunk_size + self._lockbox.MACBYTES - self._ciphertext_buffer = bytearray(self._ciphertext_chunk_sz) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def __del__(self): - self.close() - - def tell(self) -> int: - return self._stream.tell() - - def readinto(self, ba: bytearray) -> int: - goal = len(ba) - if goal == 0: - return 0 - # Read in chunks of self._chunk_size and decrypt them into ba - # until we have enough bytes. 
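[Note] The CUDA transfer thread above hands tensors to the writer through a bounded queue terminated by a `None` sentinel, which the consumer drains with two-argument `iter()`. A minimal reproduction of the idiom:

```python
import queue
import threading

_TIMEOUT = 3600
transferred: "queue.Queue" = queue.Queue(maxsize=2)  # bounded buffer


def producer(items):
    for item in items:
        transferred.put(item, timeout=_TIMEOUT)
    transferred.put(None)  # sentinel marks the end of the stream


threading.Thread(target=producer, args=([1, 2, 3],)).start()

# Two-argument iter() calls the function until it returns the sentinel,
# turning the queue into an ordinary iterable for the consumer.
for item in iter(lambda: transferred.get(timeout=_TIMEOUT), None):
    print(item)  # 1, 2, 3
```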
- ciphertext_offset = 0 - plaintext_offset = 0 - num_chunks = goal // self._chunk_size - if goal % self._chunk_size: - num_chunks += 1 - ciphertext_goal = goal + (num_chunks * self._lockbox.MACBYTES) - - step = 0 - while ciphertext_offset < ciphertext_goal: - step_nonce = self._nonce_int ^ step - step_nonce_bytes = step_nonce.to_bytes( - nacl.secret.SecretBox.NONCE_SIZE, "big", signed=False - ) - if ciphertext_offset + self._ciphertext_chunk_sz > ciphertext_goal: - ciphertext_read_sz = ciphertext_goal - ciphertext_offset - ciphertext = memoryview(self._ciphertext_buffer)[ - :ciphertext_read_sz - ] - else: - ciphertext_read_sz = self._ciphertext_chunk_sz - ciphertext = self._ciphertext_buffer - if ciphertext_read_sz == 0: - break - plaintext_sz = self._chunk_size - if goal - plaintext_offset < plaintext_sz: - plaintext_sz = goal - plaintext_offset - self._stream.readinto(ciphertext) - ba[plaintext_offset : plaintext_offset + plaintext_sz] = ( - self._lockbox.decrypt(ciphertext, step_nonce_bytes) - ) - step += 1 - ciphertext_offset += ciphertext_read_sz - plaintext_offset += plaintext_sz - - def read(self, size=-1) -> bytes: - buf = bytearray(size) - bytes_read = self.readinto(buf) - return bytes(buf[:bytes_read]) - - def writable(self) -> bool: - return False - - def fileno(self) -> int: - return self._stream.fileno() - - def close(self): - # We are a passive wrapper, so we don't close the underlying stream. - pass - - def closed(self): - return self._stream.closed - - def readline(self, size=-1) -> bytes: - return self._stream.readline(size) - - def seek(self, position, whence=SEEK_SET): - self._stream.seek(position, whence) - - class CURLStreamFile(io.RawIOBase): """ CURLStreamFile implements a file-like object around an HTTP download, the @@ -589,6 +490,7 @@ def close(self): self._curl.stdout.close() self._curl = None + @property def closed(self): return self._closed diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 31e4cf56..085c910c 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -11,7 +11,7 @@ import tempfile import time import unittest -from typing import Mapping, NamedTuple, Tuple +from typing import Mapping, NamedTuple, Optional from unittest.mock import patch import torch @@ -29,8 +29,13 @@ stream_io, utils, ) +from tensorizer._crypt import available as encryption_available from tensorizer.serialization import TensorHash, TensorType -from test_stream_io import start_redis, teardown_redis + +try: + from test_stream_io import start_redis, teardown_redis +except ImportError: + from .test_stream_io import start_redis, teardown_redis model_name = "EleutherAI/gpt-neo-125M" num_hellos = 400 @@ -43,30 +48,28 @@ class SerializeMethod(enum.Enum): Module = 1 StateDict = 2 - EncryptedModule = 3 - EncryptedStateDict = 4 + + +class SerializationResult(NamedTuple): + filename: str + orig_sd: dict def serialize_model( model_name: str, device: str, method: SerializeMethod = SerializeMethod.Module, -) -> Tuple[str, dict]: + encryption: Optional[serialization.EncryptionParams] = None, +) -> SerializationResult: model = AutoModelForCausalLM.from_pretrained(model_name).to(device) sd = model.state_dict() out_file = tempfile.NamedTemporaryFile("wb+", delete=False) try: start_time = time.monotonic() - if method is SerializeMethod.EncryptedModule: - serializer = TensorSerializer(out_file, passphrase="test") - else: - serializer = 
TensorSerializer(out_file) - if method in (SerializeMethod.Module, SerializeMethod.EncryptedModule): + serializer = TensorSerializer(out_file, encryption=encryption) + if method is SerializeMethod.Module: serializer.write_module(model) - elif method in ( - SerializeMethod.StateDict, - SerializeMethod.EncryptedStateDict, - ): + elif method is SerializeMethod.StateDict: serializer.write_state_dict(sd) else: raise ValueError("Invalid serialization method") @@ -76,7 +79,17 @@ def serialize_model( except Exception: os.unlink(out_file.name) raise - return out_file.name, sd + return SerializationResult(out_file.name, sd) + + +@contextlib.contextmanager +@functools.wraps(serialize_model) +def serialize_model_temp(*args, **kwargs): + filename = serialize_model(*args, **kwargs).filename + try: + yield filename + finally: + os.unlink(filename) # Reducing a tensor to a hash makes it faster to compare against the reference @@ -117,10 +130,7 @@ def check_deserialized( allow_subset: bool = False, include_non_persistent_buffers: bool = True, ): - orig_sd = model_digest( - model_name, - include_non_persistent_buffers, - ) + orig_sd = model_digest(model_name, include_non_persistent_buffers) if not allow_subset: test_case.assertEqual( @@ -239,29 +249,109 @@ def test_serialization(self): finally: os.unlink(serialized_model) + @unittest.skipUnless( + encryption_available, + reason="libsodium must be installed to test encryption", + ) def test_encryption(self): - unencrypted_model, orig_sd = serialize_model( - model_name, "cpu", method=SerializeMethod.Module + fixed_salt = bytes(32) + encryption = serialization.EncryptionParams.from_passphrase_fast( + passphrase="test", salt=fixed_salt ) - encrypted_model, orig_sd = serialize_model( - model_name, "cpu", method=SerializeMethod.EncryptedModule + decryption = serialization.DecryptionParams.from_passphrase( + passphrase="test" ) - try: - with open(encrypted_model, "rb") as in_file: - deserialized = TensorDeserializer( - in_file, device="cpu", passphrase="test" - ) - check_deserialized( - self, - deserialized, - model_name, - include_non_persistent_buffers=(True), - ) - deserialized.close() + incorrect_decryption = serialization.DecryptionParams.from_passphrase( + passphrase="tset" + ) + + def _serialize(enc: Optional[serialization.EncryptionParams]): + return serialize_model_temp( + model_name, + default_device, + method=SerializeMethod.Module, + encryption=enc, + ) + + def _test_first_key(obj): + k = next(iter(deserialized.keys())) + if obj[k] is not None: + raise RuntimeError() + + with _serialize(encryption) as encrypted_model: + for lazy_load, plaid_mode in ( + (False, False), + (False, True), + (True, True), + ): + # Ensure that it works when given a passphrase + with self.subTest( + msg="Deserializing with a correct passphrase", + device=default_device, + lazy_load=lazy_load, + plaid_mode=plaid_mode, + ), open(encrypted_model, "rb") as in_file, TensorDeserializer( + in_file, + device=default_device, + lazy_load=lazy_load, + plaid_mode=plaid_mode, + verify_hash=True, + encryption=decryption, + ) as deserialized: + check_deserialized( + self, + deserialized, + model_name, + ) + del deserialized + gc.collect() + + # Ensure that it fails to load when not given a passphrase + with self.subTest( + msg="Deserializing with a missing passphrase" + ), self.assertRaises(serialization.CryptographyError), open( + encrypted_model, "rb" + ) as in_file, TensorDeserializer( + in_file, + device=default_device, + lazy_load=True, + encryption=None, + ) as deserialized: + 
_test_first_key(deserialized) del deserialized - finally: - os.unlink(unencrypted_model) - os.unlink(encrypted_model) + gc.collect() + + # Ensure that it fails to load when given the wrong passphrase + with self.subTest( + msg="Deserializing with an incorrect passphrase" + ), self.assertRaises(serialization.CryptographyError), open( + encrypted_model, "rb" + ) as in_file, TensorDeserializer( + in_file, + device=default_device, + lazy_load=True, + encryption=incorrect_decryption, + ) as deserialized: + _test_first_key(deserialized) + del deserialized + gc.collect() + + with _serialize(None) as unencrypted_model: + # Ensure that it fails to load an unencrypted model + # when expecting encryption + with self.subTest( + msg="Deserializing an unencrypted model with a passphrase" + ), self.assertRaises(serialization.CryptographyError), open( + unencrypted_model, "rb" + ) as in_file, TensorDeserializer( + in_file, + device=default_device, + lazy_load=True, + encryption=decryption, + ) as deserialized: + _test_first_key(deserialized) + del deserialized + gc.collect() def test_bfloat16(self): shape = (50, 50) @@ -342,8 +432,7 @@ class TestDeserialization(unittest.TestCase): @classmethod def setUpClass(cls): - serialized_model_path, sd = serialize_model(model_name, "cpu") - del sd + serialized_model_path = serialize_model(model_name, "cpu").filename cls._serialized_model_path = serialized_model_path gc.collect() @@ -562,7 +651,7 @@ class TestVerification(unittest.TestCase): @classmethod def setUpClass(cls): - serialized_model_path = serialize_model(model_name, "cpu")[0] + serialized_model_path = serialize_model(model_name, "cpu").filename cls._serialized_model_path = serialized_model_path gc.collect()
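[Note] The encryption tests above pin `salt=bytes(32)` so that key derivation is reproducible across subtests while wrong passphrases still fail. The sketch below captures that property using `hashlib.scrypt` as a deterministic stand-in for `EncryptionParams.from_passphrase_fast`, whose actual KDF is not part of this diff:

```python
import hashlib
import unittest


def derive_key(passphrase: str, salt: bytes) -> bytes:
    # Deterministic stand-in for EncryptionParams.from_passphrase_fast;
    # any KDF works for the test's purposes as long as it is a pure
    # function of (passphrase, salt).
    return hashlib.scrypt(
        passphrase.encode(), salt=salt, n=2**14, r=8, p=1, dklen=32
    )


class TestKeyDerivation(unittest.TestCase):
    fixed_salt = bytes(32)  # all-zero salt keeps runs reproducible

    def test_fixed_salt_is_deterministic(self):
        self.assertEqual(
            derive_key("test", self.fixed_salt),
            derive_key("test", self.fixed_salt),
        )

    def test_wrong_passphrase_differs(self):
        self.assertNotEqual(
            derive_key("test", self.fixed_salt),
            derive_key("tset", self.fixed_salt),
        )


if __name__ == "__main__":
    unittest.main()
```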