diff --git a/hivemind/compression/floating.py b/hivemind/compression/floating.py
index ee0c62898..73c37522a 100644
--- a/hivemind/compression/floating.py
+++ b/hivemind/compression/floating.py
@@ -12,7 +12,8 @@ class Float16Compression(CompressionBase):
     FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        assert torch.is_floating_point(tensor) and tensor.dtype != torch.bfloat16
+        if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
+            raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
         requires_grad = tensor.requires_grad
         tensor = tensor.detach().cpu()
         dtype_name = tensor.numpy().dtype.name
@@ -47,7 +48,8 @@ class ScaledFloat16Compression(Float16Compression):
     FP32_EPS = torch.finfo(torch.float32).eps

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        assert torch.is_floating_point(tensor) and tensor.dtype != torch.bfloat16
+        if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16:
+            raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
         requires_grad = tensor.requires_grad
         tensor = tensor.detach().cpu()
         dtype_name = tensor.numpy().dtype.name
diff --git a/hivemind/compression/quantization.py b/hivemind/compression/quantization.py
index acb3bc804..257d09bca 100644
--- a/hivemind/compression/quantization.py
+++ b/hivemind/compression/quantization.py
@@ -25,7 +25,8 @@ def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[n
         ...

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        assert torch.is_floating_point(tensor)
+        if not torch.is_floating_point(tensor):
+            raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors")
         quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
         return runtime_pb2.Tensor(
             compression=self.compression_type,
@@ -128,6 +129,7 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_siz
 class BlockwiseQuantization(Quantization):
     compression_type = runtime_pb2.BLOCKWISE_8BIT
     codebook_dtype, indices_dtype = np.float32, np.uint8
+    EXTRA_PARAMS = (4096, False, torch.float32, None, None)

     def quantize(
         self, tensor: torch.Tensor, allow_inplace: bool = False
@@ -139,7 +141,7 @@ def quantize(
             raise ImportError(BNB_MISSING_MESSAGE)

         quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
-        assert tuple(extra_params) == (4096, False, tensor.dtype, None, None)  # blocksize, nested, dtype, offset, s2
+        assert tuple(extra_params) == self.EXTRA_PARAMS  # blocksize, nested, dtype, offset, state2
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
@@ -185,5 +187,5 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         absmax = torch.as_tensor(absmax)
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
-        result = dequantize_blockwise(quantized, (absmax, codebook, 4096, False, torch.float32, None, None))
+        result = dequantize_blockwise(quantized, (absmax, codebook, *self.EXTRA_PARAMS))
         return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)