Commit

review
mryab authored and Your Name committed Sep 5, 2023
1 parent d92ca9f commit 12e8fc2
Showing 2 changed files with 9 additions and 5 deletions.
6 changes: 4 additions & 2 deletions hivemind/compression/floating.py
@@ -12,7 +12,8 @@ class Float16Compression(CompressionBase):
     FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        assert torch.is_floating_point(tensor) and tensor.dtype != torch.bfloat16
+        if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
+            raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
         requires_grad = tensor.requires_grad
         tensor = tensor.detach().cpu()
         dtype_name = tensor.numpy().dtype.name
@@ -47,7 +48,8 @@ class ScaledFloat16Compression(Float16Compression):
     FP32_EPS = torch.finfo(torch.float32).eps
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        assert torch.is_floating_point(tensor) and tensor.dtype != torch.bfloat16
+        if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
+            raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
         requires_grad = tensor.requires_grad
         tensor = tensor.detach().cpu()
         dtype_name = tensor.numpy().dtype.name
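Below is a minimal usage sketch, not part of this commit, showing the effect of the change above: compressing a bfloat16 tensor now raises a descriptive ValueError instead of a bare AssertionError. The import path of CompressionInfo and its from_tensor helper are assumptions for illustration.

import torch

from hivemind.compression.base import CompressionInfo  # assumed location of CompressionInfo
from hivemind.compression.floating import Float16Compression

compression = Float16Compression()
tensor = torch.randn(16, dtype=torch.bfloat16)  # a dtype this codec rejects
info = CompressionInfo.from_tensor(tensor, key="demo")  # assumed constructor helper

try:
    compression.compress(tensor, info)
except ValueError as e:
    print(e)  # "Float16Compression does not support torch.bfloat16 tensors"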
8 changes: 5 additions & 3 deletions hivemind/compression/quantization.py
@@ -25,7 +25,8 @@ def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[n
         ...
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        assert torch.is_floating_point(tensor)
+        if not torch.is_floating_point(tensor):
+            raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
         quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
         return runtime_pb2.Tensor(
             compression=self.compression_type,
@@ -128,6 +129,7 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
 class BlockwiseQuantization(Quantization):
     compression_type = runtime_pb2.BLOCKWISE_8BIT
     codebook_dtype, indices_dtype = np.float32, np.uint8
+    EXTRA_PARAMS = (4096, False, torch.float32, None, None)
 
     def quantize(
         self, tensor: torch.Tensor, allow_inplace: bool = False
@@ -139,7 +141,7 @@ def quantize(
             raise ImportError(BNB_MISSING_MESSAGE)
 
         quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
-        assert tuple(extra_params) == (4096, False, tensor.dtype, None, None)  # blocksize, nested, dtype, offset, s2
+        assert tuple(extra_params) == self.EXTRA_PARAMS  # blocksize, nested, dtype, offset, state2
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
@@ -185,5 +187,5 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         absmax = torch.as_tensor(absmax)
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
-        result = dequantize_blockwise(quantized, (absmax, codebook, 4096, False, torch.float32, None, None))
+        result = dequantize_blockwise(quantized, (absmax, codebook, *self.EXTRA_PARAMS))
         return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)
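For context, a minimal round-trip sketch, not part of this commit: the shared EXTRA_PARAMS tuple (blocksize, nested, dtype, offset, state2) is what quantize_blockwise produces inside compress() and what dequantize_blockwise expects inside extract(), so defining it once on the class keeps both code paths consistent. The sketch assumes bitsandbytes is installed and that CompressionInfo.from_tensor exists as a constructor helper; treat it as illustrative rather than as the library's documented usage.

import torch

from hivemind.compression.base import CompressionInfo  # assumed location of CompressionInfo
from hivemind.compression.quantization import BlockwiseQuantization

quantizer = BlockwiseQuantization()
tensor = torch.randn(1024)
info = CompressionInfo.from_tensor(tensor, key="demo")  # assumed constructor helper

serialized = quantizer.compress(tensor, info)  # blockwise 8-bit quantization via quantize_blockwise
restored = quantizer.extract(serialized)       # dequantize_blockwise with the same EXTRA_PARAMS
print(torch.max(torch.abs(tensor - restored)))  # small but nonzero: 8-bit quantization is lossy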
