Fix edge cases in (de)serialize_torch_tensor #591

Merged: 7 commits, merged Sep 5, 2023

Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/run-benchmarks.yml
@@ -28,7 +28,7 @@ jobs:
pip install -r requirements-dev.txt
- name: Build bitsandbytes
run: |
pip install bitsandbytes==0.37.0
pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install .
4 changes: 2 additions & 2 deletions .github/workflows/run-tests.yml
@@ -31,7 +31,7 @@ jobs:
pip install -r requirements-dev.txt
- name: Build bitsandbytes
run: |
pip install bitsandbytes==0.37.0
pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install .
@@ -93,7 +93,7 @@ jobs:
pip install -r requirements-dev.txt
- name: Build bitsandbytes
run: |
pip install bitsandbytes==0.37.0
pip install bitsandbytes==0.41.1
- name: Build hivemind
run: |
pip install -e . --no-use-pep517
3 changes: 2 additions & 1 deletion hivemind/compression/base.py
@@ -82,6 +82,7 @@ class NoCompression(CompressionBase):
compression_type = runtime_pb2.CompressionType.NONE

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
requires_grad = tensor.requires_grad
tensor = tensor.detach()
shape = tensor.shape
dtype_name = str(tensor.dtype).replace("torch.", "")
@@ -98,7 +99,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
buffer=raw_data.numpy().tobytes(),
size=shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
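The two hunks above fix a subtle ordering bug: `Tensor.detach()` returns a tensor with `requires_grad=False`, so reading the flag after detaching always reports False and the serialized message loses it. A minimal standalone sketch of the failure mode and the fix (not part of the PR):

```python
import torch

x = torch.randn(3, requires_grad=True)

# Reading the flag after detach() always yields False, which is what the
# old code ended up serializing:
print(x.detach().requires_grad)  # False

# The fix: capture the flag before detaching, then restore it on the
# deserialized tensor.
requires_grad = x.requires_grad  # True, captured before detach()
restored = x.detach().clone().requires_grad_(requires_grad)
print(restored.requires_grad)  # True
```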
27 changes: 19 additions & 8 deletions hivemind/compression/floating.py
@@ -12,22 +12,29 @@
FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")

requires_grad = tensor.requires_grad
tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
tensor = tensor.detach().cpu().float()
tensor = tensor if allow_inplace else tensor.clone()
tensor = tensor.to(torch.float32, copy=not allow_inplace)
tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=tensor.numpy().tobytes(),
size=tensor.shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
original_dtype = np.dtype(serialized_tensor.dtype)
array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16)
return torch.as_tensor(np.asarray(array, dtype=original_dtype)).reshape(tuple(serialized_tensor.size))
return (
torch.as_tensor(np.asarray(array, dtype=original_dtype))
.reshape(tuple(serialized_tensor.size))
.requires_grad_(serialized_tensor.requires_grad)
)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return 16.0 / get_num_bits(info.descriptor.dtype)
@@ -41,9 +48,12 @@
FP32_EPS = torch.finfo(torch.float32).eps

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")

requires_grad = tensor.requires_grad
tensor = tensor.detach().cpu()
dtype_name = tensor.numpy().dtype.name
tensor = tensor.detach().cpu().float()
tensor = tensor if allow_inplace else tensor.clone()
tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
means = torch.mean(tensor, dim=-1, keepdim=True)
tensor.sub_(means)
stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1])
@@ -58,7 +68,7 @@
buffer=data,
size=tensor.shape,
dtype=dtype_name,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
)

def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
@@ -77,7 +87,8 @@
tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape(
list(serialized_tensor.size)
)
return tensor.mul_(stds).add_(means)
dtype = getattr(torch, serialized_tensor.dtype)
return tensor.mul_(stds).add_(means).to(dtype).requires_grad_(serialized_tensor.requires_grad)


def get_num_bits(dtype: torch.dtype) -> int:
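Taken together, these hunks make the float16 codecs reject unsupported inputs up front (bfloat16 and non-float tensors), do the lossy math in float32, and stamp the original dtype and gradient flag back on during extraction. A self-contained sketch of that round trip, with hypothetical helper names standing in for `Float16Compression.compress`/`extract`:

```python
import numpy as np
import torch

def compress_fp16(tensor: torch.Tensor) -> tuple[bytes, str, bool]:
    # Mirror of the guard added above: bfloat16 has no numpy equivalent here.
    if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
        raise ValueError(f"float16 codec does not support {tensor.dtype} tensors")
    requires_grad = tensor.requires_grad    # capture before detach()
    tensor = tensor.detach().cpu()
    dtype_name = tensor.numpy().dtype.name  # original dtype, e.g. "float64"
    fp16_min, fp16_max = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
    tensor = tensor.to(torch.float32).clamp_(fp16_min, fp16_max).to(torch.float16)
    return tensor.numpy().tobytes(), dtype_name, requires_grad

def extract_fp16(buffer: bytes, dtype_name: str, requires_grad: bool, shape: torch.Size) -> torch.Tensor:
    array = np.frombuffer(buffer, dtype=np.float16)
    return (
        torch.as_tensor(np.asarray(array, dtype=np.dtype(dtype_name)))
        .reshape(shape)
        .requires_grad_(requires_grad)
    )

x = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
y = extract_fp16(*compress_fp16(x), x.shape)
assert y.dtype == x.dtype and y.requires_grad
```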
25 changes: 15 additions & 10 deletions hivemind/compression/quantization.py
@@ -25,12 +25,14 @@
...

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
if not torch.is_floating_point(tensor):
raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")

quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
return runtime_pb2.Tensor(
compression=self.compression_type,
buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
size=tensor.shape,
dtype=tensor.numpy().dtype.name,
dtype=tensor.data.numpy().dtype.name if tensor.dtype != torch.bfloat16 else "bfloat16",
requires_grad=tensor.requires_grad,
)

@@ -39,8 +41,8 @@
codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
return codebook[quantized]
codebook = torch.as_tensor(codebook).to(dtype=getattr(torch, serialized_tensor.dtype))
return codebook[quantized].requires_grad_(serialized_tensor.requires_grad)

def estimate_compression_ratio(self, info: CompressionInfo) -> float:
return self.n_bits / torch.finfo(info.descriptor.dtype).bits
@@ -59,8 +61,10 @@
compression_type = runtime_pb2.UNIFORM_8BIT

def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
assert torch.is_floating_point(tensor)
offset = self.n_bins // 2
shift = tensor.mean()
tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
@@ -125,6 +129,7 @@
class BlockwiseQuantization(Quantization):
compression_type = runtime_pb2.BLOCKWISE_8BIT
codebook_dtype, indices_dtype = np.float32, np.uint8
EXTRA_PARAMS = (4096, False, torch.float32, None, None)

def quantize(
self, tensor: torch.Tensor, allow_inplace: bool = False
@@ -135,14 +140,15 @@
except ImportError:
raise ImportError(BNB_MISSING_MESSAGE)

quantized, (absmax, codebook) = quantize_blockwise(tensor)
quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
assert tuple(extra_params) == self.EXTRA_PARAMS # blocksize, nested, dtype, offset, state2
return quantized.numpy(), (absmax.numpy(), codebook.numpy())

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
requires_grad = tensor.requires_grad
tensor = tensor.detach()
dtype_name = str(tensor.dtype).replace("torch.", "")
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)
tensor = tensor.to(torch.float32)

quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)

@@ -157,7 +163,7 @@
return runtime_pb2.Tensor(
buffer=b"".join(serialized_data),
size=tensor.shape,
requires_grad=tensor.requires_grad,
requires_grad=requires_grad,
dtype=dtype_name,
compression=self.compression_type,
)
@@ -181,6 +187,5 @@
absmax = torch.as_tensor(absmax)
codebook = torch.as_tensor(codebook)
quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
result = dequantize_blockwise(quantized, (absmax, codebook)) # Always returns a float32 tensor
result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
return result
result = dequantize_blockwise(quantized, (absmax, codebook, *self.EXTRA_PARAMS))
return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)
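The quantizers above share one wire format: an int64 codebook length, the codebook entries, then the quantized indices. A numpy-only sketch of that framing and of the lookup-table dequantization performed in `extract` (all values here are made up for illustration):

```python
import numpy as np

# Serialize: [int64 codebook length][codebook][indices], as in compress() above.
codebook = np.array([-1.5, -0.5, 0.5, 1.5], dtype=np.float32)
quantized = np.array([0, 3, 2, 1, 1, 2], dtype=np.uint8)
buffer = b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes()))

# Deserialize, mirroring Quantization.extract(): read the length, then slice by offsets.
(codebook_size,) = np.frombuffer(buffer, count=1, dtype=np.int64)
codebook_out = np.frombuffer(buffer, offset=8, count=codebook_size, dtype=np.float32)
indices = np.frombuffer(buffer, offset=8 + codebook_out.nbytes, dtype=np.uint8)
restored = codebook_out[indices]  # dequantization is a table lookup: codebook[quantized]
```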
2 changes: 1 addition & 1 deletion setup.py
@@ -156,7 +156,7 @@ def run(self):
with open("requirements-docs.txt") as docs_requirements_file:
extras["docs"] = list(map(str, parse_requirements(docs_requirements_file)))

extras["bitsandbytes"] = ["bitsandbytes~=0.37.0"]
extras["bitsandbytes"] = ["bitsandbytes~=0.41.1"]

extras["all"] = extras["dev"] + extras["docs"] + extras["bitsandbytes"]

39 changes: 37 additions & 2 deletions tests/test_compression.py
@@ -52,22 +52,57 @@ def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
assert len(chunks) == max((len(serialized_tensor.buffer) - 1) // chunk_size + 1, 1)
restored = combine_from_streaming(chunks)
result = deserialize_torch_tensor(restored)
assert result.dtype == tensor.dtype, compression
assert result.requires_grad == tensor.requires_grad
assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
assert result.dtype == tensor.dtype


@pytest.mark.forked
def test_serialize_tensor():
tensor = torch.randn(512, 12288)
tensor = torch.randn(512, 12288, requires_grad=True)
for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
_check(tensor, CompressionType.NONE, chunk_size=chunk_size)

_check(tensor, CompressionType.FLOAT16, rtol=0.0, atol=1e-2)
_check(torch.randint(0, 100, (512, 1, 1)), CompressionType.NONE)
_check(torch.randn(10, 20), CompressionType.MEANSTD_16BIT, atol=0.1)
_check(torch.tensor(1.0), CompressionType.NONE)
_check(torch.tensor(1.0), CompressionType.FLOAT16)


@pytest.mark.parametrize(
"dtype",
[
torch.float32,
torch.float16,
torch.bfloat16,
torch.float64,
torch.complex64,
torch.int64,
torch.int32,
torch.uint8,
torch.bool,
],
)
@pytest.mark.parametrize("requires_grad", [False, True])
@pytest.mark.forked
def test_serialize_tensor_properties(dtype: torch.dtype, requires_grad: bool):
tensor = torch.randn(123, 45, requires_grad=requires_grad).to(dtype)
if dtype == torch.bfloat16:
compression_types = [
type
for type in CompressionType.values()
if type not in (CompressionType.FLOAT16, CompressionType.MEANSTD_16BIT)
]
elif torch.is_floating_point(tensor): # nb: complex and qint data types are not is_floating_point
compression_types = CompressionType.values()
else:
compression_types = [CompressionType.NONE]

for compression_type in compression_types:
_check(tensor, compression_type, atol=float("inf"))


@pytest.mark.parametrize("use_legacy_bfloat16", [True, False])
@pytest.mark.parametrize("tensor_size", [(4096, 16), (0, 0)])
@pytest.mark.forked
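For reference, a usage sketch of the round trip these tests now pin down, assuming hivemind from this branch is installed (CompressionType.NONE is used so the example does not need bitsandbytes):

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(123, 45, requires_grad=True).to(torch.bfloat16)
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.NONE))
assert restored.dtype == torch.bfloat16  # dtype survives the round trip
assert restored.requires_grad            # gradient flag survives as well
```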