Fix edge cases in (de)serialize_torch_tensor #591

Merged · 7 commits · Sep 5, 2023
Changes from 5 commits
3 changes: 2 additions & 1 deletion hivemind/compression/base.py
@@ -82,6 +82,7 @@ class NoCompression(CompressionBase):
compression_type = runtime_pb2.CompressionType.NONE

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
requires_grad = tensor.requires_grad
tensor = tensor.detach()
shape = tensor.shape
dtype_name = str(tensor.dtype).replace("torch.", "")
@@ -98,7 +99,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
buffer=raw_data.numpy().tobytes(),
size=shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
25 changes: 17 additions & 8 deletions hivemind/compression/floating.py
@@ -12,22 +12,28 @@ class Float16Compression(CompressionBase):
FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
assert torch.is_floating_point(tensor) and tensor.dtype != torch.bfloat16
Reviewer comment (Member):
Is there a reason why we should fail with an error in case of bf16 inputs? It is indeed not sensible, but if the user wants to do so, it's probably better to issue a warning instead of flat-out refusing to pass it through quantization.

Author reply (Member):
Added a ValueError with a more user-legible reason.
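For context, a minimal sketch of the kind of input check the author describes, assuming the check rejects both non-floating-point and bfloat16 inputs; the exact condition and message belong to later commits not shown in this diff:

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
    # reject inputs that cannot be meaningfully compressed to fp16, with a readable reason
    if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
        raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
    ...  # the rest of compress() proceeds as in the diff above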

requires_grad = tensor.requires_grad
tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
tensor = tensor.detach().cpu().float()
tensor = tensor if allow_inplace else tensor.clone()
tensor = tensor.to(torch.float32, copy=not allow_inplace)
tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=tensor.numpy().tobytes(),
size=tensor.shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
original_dtype = np.dtype(serialized_tensor.dtype)
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
return (
torch.as_tensor(np.asarray(array, dtype=original_dtype))
.reshape(tuple(serialized_tensor.size))
.requires_grad_(serialized_tensor.requires_grad)
)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return 16.0 / get_num_bits(info.descriptor.dtype)
@@ -41,9 +47,11 @@ class ScaledFloat16Compression(Float16Compression):
FP32_EPS = torch.finfo(torch.float32).eps

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
assert torch.is_floating_point(tensor) and tensor.dtype != torch.bfloat16
requires_grad = tensor.requires_grad
tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
tensor = tensor.detach().cpu().float()
tensor = tensor if allow_inplace else tensor.clone()
tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
means = torch.mean(tensor, dim=-1, keepdim=True)
tensor.sub_(means)
stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
@@ -58,7 +66,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
buffer=data,
size=tensor.shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
@@ -77,7 +85,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
list(serialized_tensor.size)
)
return tensor.mul_(stds).add_(means)
dtype = getattr(torch, serialized_tensor.dtype)
return tensor.mul_(stds).add_(means).to(dtype).requires_grad_(serialized_tensor.requires_grad)


def get_num_bits(dtype: torch.dtype) -> int:
23 changes: 13 additions & 10 deletions hivemind/compression/quantization.py
@@ -25,12 +25,13 @@ def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[n
...

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
assert torch.is_floating_point(tensor)
quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
size=tensor.shape,
dtype=tensor.numpy().dtype.name,
dtype=tensor.data.numpy().dtype.name if tensor.dtype != torch.bfloat16 else "bfloat16",
requires_grad=tensor.requires_grad,
)

@@ -39,8 +40,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
return codebook[quantized]
codebook = torch.as_tensor(codebook).to(dtype=getattr(torch, serialized_tensor.dtype))
return codebook[quantized].requires_grad_(serialized_tensor.requires_grad)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return self.n_bits / torch.finfo(info.descriptor.dtype).bits
@@ -59,8 +60,10 @@ class Uniform8BitQuantization(Quantization):
compression_type = runtime_pb2.UNIFORM_8BIT

def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
assert torch.is_floating_point(tensor)
offset = self.n_bins // 2
shift = tensor.mean()
tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
@@ -135,14 +138,15 @@ def quantize(
except ImportError:
raise ImportError(BNB_MISSING_MESSAGE)

quantized, (absmax, codebook) = quantize_blockwise(tensor)
quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
assert tuple(extra_params) == (4096, False, tensor.dtype, None, None) # blocksize, nested, dtype, offset, s2
Reviewer comment (Member):
Maybe we can make that tuple on the right a module-level constant? It's used twice in the code, so it's better to make it clear we're using some predefined values.

Author reply (@justheuristic, Sep 5, 2023):
Done, thanks for the suggestion.
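For context, a hedged sketch of one way such a module-level constant could be shared between quantize() and extract(); the constant name is illustrative, not necessarily the one used in the merged commit:

# blocksize, nested, dtype, offset, state2 expected from bitsandbytes blockwise (de)quantization
_BLOCKWISE_QUANT_STATE = (4096, False, torch.float32, None, None)

# in quantize(): the dtype entry depends on the input tensor, everything else is fixed
assert tuple(extra_params) == (*_BLOCKWISE_QUANT_STATE[:2], tensor.dtype, *_BLOCKWISE_QUANT_STATE[3:])

# in extract(): reuse the same predefined values when rebuilding the quantization state
result = dequantize_blockwise(quantized, (absmax, codebook, *_BLOCKWISE_QUANT_STATE))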

return quantized.numpy(), (absmax.numpy(), codebook.numpy())

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
requires_grad = tensor.requires_grad
tensor = tensor.detach()
dtype_name = str(tensor.dtype).replace("torch.", "")
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)
tensor = tensor.to(torch.float32)

quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)

@@ -157,7 +161,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
return runtime_pb2.Tensor(
buffer=b"".join(serialized_data),
size=tensor.shape,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
dtype=dtype_name,
compression=self.compression_type,
)
@@ -181,6 +185,5 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
absmax = torch.as_tensor(absmax)
codebook = torch.as_tensor(codebook)
quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
result = dequantize_blockwise(quantized, (absmax, codebook)) # Always returns a float32 tensor
result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
return result
result = dequantize_blockwise(quantized, (absmax, codebook, 4096, False, torch.float32, None, None))
return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)
39 changes: 37 additions & 2 deletions tests/test_compression.py
@@ -52,22 +52,57 @@ def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
restored = combine_from_streaming(chunks)
result = deserialize_torch_tensor(restored)
assert result.dtype == tensor.dtype, compression
assert result.requires_grad == tensor.requires_grad
assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
assert result.dtype == tensor.dtype


@pytest.mark.forked
def test_serialize_tensor():
tensor = torch.randn(512, 12288)
tensor = torch.randn(512, 12288, requires_grad=True)
for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
_check(tensor, CompressionType.NONE, chunk_size=chunk_size)

_check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
_check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
_check(torch.randn(10, 20), CompressionType.MEANSTD_16BIT, atol=0.1)
_check(torch.tensor(1.0), CompressionType.NONE)
_check(torch.tensor(1.0), CompressionType.FLOAT16)


@pytest.mark.parametrize(
"dtype",
[
torch.float32,
torch.float16,
torch.bfloat16,
torch.float64,
torch.complex64,
torch.int64,
torch.int32,
torch.uint8,
torch.bool,
],
)
@pytest.mark.parametrize("requires_grad", [False, True])
@pytest.mark.forked
def test_serialize_tensor_properties(dtype: torch.dtype, requires_grad: bool):
tensor = torch.randn(123, 45, requires_grad=requires_grad).to(dtype)
if dtype == torch.bfloat16:
compression_types = [
type
for type in CompressionType.values()
if type not in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT)
]
elif torch.is_floating_point(tensor): # nb: complex and qint data types are not is_floating_point
compression_types = CompressionType.values()
else:
compression_types = [CompressionType.NONE]

for compression_type in compression_types:
_check(tensor, compression_type, atol=float("inf"))


@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
@pytest.mark.forked