Fix edge cases in (de)serialize_torch_tensor (#591)
* serialize with requires_grad
* ensure that all compression methods return tensor of the original dtype
* test that all compression methods preserve dtype and requires_grad


---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
3 people authored Sep 5, 2023
1 parent 64f1f1e commit 2873252
Showing 7 changed files with 77 additions and 25 deletions.
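In short: after this commit, the serialize/deserialize round trip preserves both the dtype and the requires_grad flag of the input tensor for every compression method. A minimal sketch of the invariant, assuming hivemind's public serialize_torch_tensor/deserialize_torch_tensor helpers:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(128, 64, dtype=torch.float64, requires_grad=True)
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.FLOAT16))

# The two properties this commit fixes:
assert restored.dtype == tensor.dtype          # float64, not float16
assert restored.requires_grad == tensor.requires_grad
```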
2 changes: 1 addition & 1 deletion .github/workflows/run-benchmarks.yml
@@ -28,7 +28,7 @@ jobs:
pip install -r requirements-dev.txt
- name: Build bitsandbytes
run: |
pip install bitsandbytes==0.37.0
pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install .
4 changes: 2 additions & 2 deletions .github/workflows/run-tests.yml
@@ -31,7 +31,7 @@ jobs:
pip install -r requirements-dev.txt
- name: Build bitsandbytes
run: |
pip install bitsandbytes==0.37.0
pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install .
@@ -93,7 +93,7 @@ jobs:
pip install -r requirements-dev.txt
- name: Build bitsandbytes
run: |
pip install bitsandbytes==0.37.0
pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install -e . --no-use-pep517
3 changes: 2 additions & 1 deletion hivemind/compression/base.py
@@ -82,6 +82,7 @@ class NoCompression(CompressionBase):
compression_type = runtime_pb2.CompressionType.NONE

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
requires_grad = tensor.requires_grad
tensor = tensor.detach()
shape = tensor.shape
dtype_name = str(tensor.dtype).replace("torch.", "")
@@ -98,7 +99,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
buffer=raw_data.numpy().tobytes(),
size=shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
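Note on the change above: Tensor.detach() always returns a tensor with requires_grad=False, so reading the flag after detaching (as the old code did) serialized False for every tensor. Hence the flag is captured first:

```python
import torch

t = torch.randn(3, requires_grad=True)
requires_grad = t.requires_grad           # capture before detach(): True
assert t.detach().requires_grad is False  # after detach() the flag is gone
```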
27 changes: 19 additions & 8 deletions hivemind/compression/floating.py
@@ -12,22 +12,29 @@ class Float16Compression(CompressionBase):
FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
requires_grad = tensor.requires_grad
tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
tensor = tensor.detach().cpu().float()
tensor = tensor if allow_inplace else tensor.clone()
tensor = tensor.to(torch.float32, copy=not allow_inplace)
tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=tensor.numpy().tobytes(),
size=tensor.shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
original_dtype = np.dtype(serialized_tensor.dtype)
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
return (
torch.as_tensor(np.asarray(array, dtype=original_dtype))
.reshape(tuple(serialized_tensor.size))
.requires_grad_(serialized_tensor.requires_grad)
)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return 16.0 / get_num_bits(info.descriptor.dtype)
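Taken together, the new compress/extract pair implements the following round trip. A self-contained sketch (not the library code verbatim; the real methods also validate input dtypes and carry the payload through runtime_pb2.Tensor):

```python
import numpy as np
import torch

def fp16_roundtrip(tensor: torch.Tensor) -> torch.Tensor:
    # Remember dtype and grad flag up front, clamp into the fp16 range for the
    # wire format, then cast back to the original dtype on extraction.
    dtype, requires_grad = tensor.dtype, tensor.requires_grad
    fp16 = torch.finfo(torch.float16)
    compressed = tensor.detach().cpu().float().clamp(fp16.min, fp16.max).to(torch.float16)
    payload = compressed.numpy().tobytes()  # what goes into runtime_pb2.Tensor.buffer

    array = np.frombuffer(payload, dtype=np.float16).copy()  # copy: frombuffer is read-only
    restored = torch.as_tensor(array).reshape(tensor.shape).to(dtype)
    return restored.requires_grad_(requires_grad)

x = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
y = fp16_roundtrip(x)
assert y.dtype == torch.float64 and y.requires_grad
```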
@@ -41,9 +48,12 @@ class ScaledFloat16Compression(Float16Compression):
FP32_EPS = torch.finfo(torch.float32).eps

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
requires_grad = tensor.requires_grad
tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
tensor = tensor.detach().cpu().float()
tensor = tensor if allow_inplace else tensor.clone()
tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
means = torch.mean(tensor, dim=-1, keepdim=True)
tensor.sub_(means)
stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
@@ -58,7 +68,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
buffer=data,
size=tensor.shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
@@ -77,7 +87,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
list(serialized_tensor.size)
)
return tensor.mul_(stds).add_(means)
dtype = getattr(torch, serialized_tensor.dtype)
return tensor.mul_(stds).add_(means).to(dtype).requires_grad_(serialized_tensor.requires_grad)


def get_num_bits(dtype: torch.dtype) -> int:
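For reference, the idea behind MEANSTD_16BIT in one piece: normalize each row to roughly unit scale so that fp16 keeps most of the precision, then denormalize and cast back on extraction. A simplified sketch (the library additionally serializes the means and stds into the buffer alongside the fp16 payload):

```python
import math
import torch

def meanstd_16bit_roundtrip(tensor: torch.Tensor) -> torch.Tensor:
    x = tensor.detach().to(torch.float32, copy=True)
    means = x.mean(dim=-1, keepdim=True)
    x -= means
    stds = x.norm(dim=-1, keepdim=True) / math.sqrt(x.shape[-1])
    stds = stds.clamp_min(torch.finfo(torch.float32).eps)
    payload = (x / stds).to(torch.float16)  # well-conditioned for fp16

    # Extraction: undo the normalization, then restore the original dtype.
    restored = payload.to(torch.float32) * stds + means
    return restored.to(tensor.dtype)

x = torch.randn(8, 16, dtype=torch.float64)
assert torch.allclose(meanstd_16bit_roundtrip(x), x, atol=0.1)
```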
25 changes: 15 additions & 10 deletions hivemind/compression/quantization.py
@@ -25,12 +25,14 @@ def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[n
...

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
if not torch.is_floating_point(tensor):
raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
size=tensor.shape,
dtype=tensor.numpy().dtype.name,
dtype=tensor.data.numpy().dtype.name if tensor.dtype != torch.bfloat16 else "bfloat16",
requires_grad=tensor.requires_grad,
)

@@ -39,8 +41,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
return codebook[quantized]
codebook = torch.as_tensor(codebook).to(dtype=getattr(torch, serialized_tensor.dtype))
return codebook[quantized].requires_grad_(serialized_tensor.requires_grad)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return self.n_bits / torch.finfo(info.descriptor.dtype).bits
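The rewritten extract also fixes a subtle dtype problem: the old path converted the codebook through a numpy dtype, but numpy has no bfloat16, so that conversion cannot represent every torch dtype. Casting the small codebook in torch handles all of them, and dequantization then reduces to indexing:

```python
import numpy as np
import torch

codebook = np.asarray([-1.5, -0.5, 0.5, 1.5], dtype=np.float32)  # centroids off the wire
indices = torch.as_tensor(np.asarray([0, 3, 2, 1], dtype=np.uint8), dtype=torch.int64)

# Cast the (small) codebook once in torch; the gather inherits its dtype:
restored = torch.as_tensor(codebook).to(torch.bfloat16)[indices]
assert restored.dtype == torch.bfloat16
```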
@@ -59,8 +61,10 @@ class Uniform8BitQuantization(Quantization):
compression_type = runtime_pb2.UNIFORM_8BIT

def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
assert torch.is_floating_point(tensor)
offset = self.n_bins // 2
shift = tensor.mean()
tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
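The new float32 cast in quantize() makes the binning work for float16/bfloat16 inputs as well. For intuition, a simplified sketch of the uniform quantizer (the real implementation additionally replaces each bin with the average of the values that landed in it; names here are illustrative):

```python
import torch

def uniform_8bit_sketch(tensor: torch.Tensor, range_in_sigmas: float = 6.0, n_bits: int = 8):
    # Center on the mean and pick a scale so ~6 sigmas span the 256 bins.
    n_bins = 2 ** n_bits
    offset = n_bins // 2
    x = tensor.detach().to(torch.float32, copy=True)
    shift = x.mean()
    x -= shift
    std_unbiased = x.norm() / (x.numel() - 1) ** 0.5
    scale = range_in_sigmas * std_unbiased / n_bins
    indices = (x / scale).round().clamp_(-offset, offset - 1) + offset
    return indices.to(torch.uint8), scale.item(), shift.item()

indices, scale, shift = uniform_8bit_sketch(torch.randn(1000))
```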
@@ -125,6 +129,7 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
class BlockwiseQuantization(Quantization):
compression_type = runtime_pb2.BLOCKWISE_8BIT
codebook_dtype, indices_dtype = np.float32, np.uint8
EXTRA_PARAMS = (4096, False, torch.float32, None, None)

def quantize(
self, tensor: torch.Tensor, allow_inplace: bool = False
@@ -135,14 +140,15 @@ def quantize(
except ImportError:
raise ImportError(BNB_MISSING_MESSAGE)

quantized, (absmax, codebook) = quantize_blockwise(tensor)
quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
assert tuple(extra_params) == self.EXTRA_PARAMS # blocksize, nested, dtype, offset, state2
return quantized.numpy(), (absmax.numpy(), codebook.numpy())
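The pinned EXTRA_PARAMS tuple freezes the bitsandbytes settings this wire format assumes (blocksize, nested, dtype, offset, state2), so an upstream default change fails loudly instead of silently corrupting tensors. For context, the essence of blockwise quantization in pure torch (a sketch, not the bitsandbytes code):

```python
import torch
import torch.nn.functional as F

def blockwise_normalize(tensor: torch.Tensor, blocksize: int = 4096):
    # Each block is scaled by its own max |value|, so a single outlier only
    # hurts precision within its 4096-element block, not the whole tensor.
    flat = tensor.detach().to(torch.float32).flatten()
    pad = (-flat.numel()) % blocksize
    blocks = F.pad(flat, (0, pad)).reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1, keepdim=True).values.clamp_min(1e-12)
    normalized = blocks / absmax  # entries in [-1, 1], ready for an 8-bit codebook
    return normalized, absmax.squeeze(1)

normalized, absmax = blockwise_normalize(torch.randn(10_000))
```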

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
requires_grad = tensor.requires_grad
tensor = tensor.detach()
dtype_name = str(tensor.dtype).replace("torch.", "")
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)
tensor = tensor.to(torch.float32)

quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)

@@ -157,7 +163,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
return runtime_pb2.Tensor(
buffer=b"".join(serialized_data),
size=tensor.shape,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
dtype=dtype_name,
compression=self.compression_type,
)
@@ -181,6 +187,5 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
absmax = torch.as_tensor(absmax)
codebook = torch.as_tensor(codebook)
quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
result = dequantize_blockwise(quantized, (absmax, codebook)) # Always returns a float32 tensor
result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
return result
result = dequantize_blockwise(quantized, (absmax, codebook, *self.EXTRA_PARAMS))
return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)
2 changes: 1 addition & 1 deletion setup.py
@@ -156,7 +156,7 @@ def run(self):
with open("requirements-docs.txt") as docs_requirements_file:
extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))

extras["bitsandbytes"] = ["bitsandbytes~=0.37.0"]
extras["bitsandbytes"] = ["bitsandbytes~=0.41.1"]

extras["all"] = extras["dev"] + extras["docs"] + extras["bitsandbytes"]

39 changes: 37 additions & 2 deletions tests/test_compression.py
@@ -52,22 +52,57 @@ def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
restored = combine_from_streaming(chunks)
result = deserialize_torch_tensor(restored)
assert result.dtype == tensor.dtype, compression
assert result.requires_grad == tensor.requires_grad
assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
assert result.dtype == tensor.dtype


@pytest.mark.forked
def test_serialize_tensor():
tensor = torch.randn(512, 12288)
tensor = torch.randn(512, 12288, requires_grad=True)
for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
_check(tensor, CompressionType.NONE, chunk_size=chunk_size)

_check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
_check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
_check(torch.randn(10, 20), CompressionType.MEANSTD_16BIT, atol=0.1)
_check(torch.tensor(1.0), CompressionType.NONE)
_check(torch.tensor(1.0), CompressionType.FLOAT16)
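A usage-level view of what this helper exercises, assuming the split_for_streaming/combine_from_streaming utilities from hivemind.utils.streaming that the test imports (import location and chunk_size_bytes keyword are assumptions from the test code, not verified API docs):

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

tensor = torch.randn(512, 12288, requires_grad=True)
serialized = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
chunks = list(split_for_streaming(serialized, chunk_size_bytes=64 * 1024))
restored = deserialize_torch_tensor(combine_from_streaming(iter(chunks)))
assert restored.dtype == tensor.dtype and restored.requires_grad
```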


@pytest.mark.parametrize(
"dtype",
[
torch.float32,
torch.float16,
torch.bfloat16,
torch.float64,
torch.complex64,
torch.int64,
torch.int32,
torch.uint8,
torch.bool,
],
)
@pytest.mark.parametrize("requires_grad", [False, True])
@pytest.mark.forked
def test_serialize_tensor_properties(dtype: torch.dtype, requires_grad: bool):
tensor = torch.randn(123, 45, requires_grad=requires_grad).to(dtype)
if dtype == torch.bfloat16:
compression_types = [
type
for type in CompressionType.values()
if type not in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT)
]
elif torch.is_floating_point(tensor): # nb: complex and qint data types are not is_floating_point
compression_types = CompressionType.values()
else:
compression_types = [CompressionType.NONE]

for compression_type in compression_types:
_check(tensor, compression_type, atol=float("inf"))


@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
@pytest.mark.forked
