Fix edge cases in (de)serialize_torch_tensor #591
Changes from 5 commits
```diff
@@ -25,12 +25,13 @@ def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
         ...

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        assert torch.is_floating_point(tensor)
         quantized, codebook = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
         return runtime_pb2.Tensor(
             compression=self.compression_type,
             buffer=b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), quantized.tobytes())),
             size=tensor.shape,
-            dtype=tensor.numpy().dtype.name,
+            dtype=tensor.data.numpy().dtype.name if tensor.dtype != torch.bfloat16 else "bfloat16",
             requires_grad=tensor.requires_grad,
         )
```
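For context on the dtype change above: NumPy has no bfloat16 dtype, so calling `.numpy()` on a bfloat16 tensor raises a `TypeError`, and the dtype name has to be recorded by hand (`.data`, or equivalently `.detach()`, also sidesteps the restriction on calling `.numpy()` on a tensor that requires grad). A minimal sketch of the pattern; the helper name is made up for illustration:

```python
import torch

def _dtype_name(tensor: torch.Tensor) -> str:
    # Hypothetical helper illustrating the branch above: NumPy cannot represent
    # bfloat16, so .numpy() would raise; record the name explicitly instead.
    if tensor.dtype == torch.bfloat16:
        return "bfloat16"
    return tensor.detach().numpy().dtype.name  # e.g. "float32" or "float16"

print(_dtype_name(torch.zeros(3, dtype=torch.bfloat16)))  # bfloat16
print(_dtype_name(torch.zeros(3, dtype=torch.float16)))   # float16
```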
```diff
@@ -39,8 +40,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         codebook = np.frombuffer(serialized_tensor.buffer, offset=8, count=codebook_size, dtype=self.codebook_dtype)
         quantized = np.frombuffer(serialized_tensor.buffer, offset=8 + codebook.nbytes, dtype=self.indices_dtype)
         quantized = torch.as_tensor(quantized, dtype=torch.int64).reshape(tuple(serialized_tensor.size))
-        codebook = torch.as_tensor(np.asarray(codebook, dtype=serialized_tensor.dtype))
-        return codebook[quantized]
+        codebook = torch.as_tensor(codebook).to(dtype=getattr(torch, serialized_tensor.dtype))
+        return codebook[quantized].requires_grad_(serialized_tensor.requires_grad)

     def estimate_compression_ratio(self, info: CompressionInfo) -> float:
         return self.n_bits / torch.finfo(info.descriptor.dtype).bits
```
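The buffer that `compress()` builds and `extract()` parses is a flat concatenation: an int64 codebook length, the raw codebook values, then the quantized indices. A standalone sketch of that round trip, assuming a float32 codebook and uint8 indices (the real `codebook_dtype`/`indices_dtype` are class attributes of the quantizer):

```python
import numpy as np
import torch

# Assumed dtypes for illustration only.
codebook = np.linspace(-1.0, 1.0, 256, dtype=np.float32)
indices = np.random.randint(0, 256, size=(4, 5), dtype=np.uint8)

# Layout: [int64 codebook length | codebook bytes | index bytes]
buffer = b"".join((np.int64(len(codebook)).tobytes(), codebook.tobytes(), indices.tobytes()))

codebook_size = int(np.frombuffer(buffer, count=1, dtype=np.int64)[0])
codebook_out = np.frombuffer(buffer, offset=8, count=codebook_size, dtype=np.float32)
indices_out = np.frombuffer(buffer, offset=8 + codebook_out.nbytes, dtype=np.uint8)

# Dequantize by codebook lookup, then restore dtype and grad flag as extract() now does.
restored = torch.as_tensor(codebook_out)[torch.as_tensor(indices_out, dtype=torch.int64)]
restored = restored.reshape(4, 5).to(getattr(torch, "bfloat16")).requires_grad_(True)
assert restored.shape == (4, 5) and restored.dtype == torch.bfloat16
```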
```diff
@@ -59,8 +60,10 @@ class Uniform8BitQuantization(Quantization):
     compression_type = runtime_pb2.UNIFORM_8BIT

     def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+        assert torch.is_floating_point(tensor)
         offset = self.n_bins // 2
         shift = tensor.mean()
+        tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace)
         centered_tensor = tensor.sub_(shift) if allow_inplace else tensor - shift
         std_unbiased = centered_tensor.norm() / math.sqrt(centered_tensor.numel() - 1)
         scale = self.RANGE_IN_SIGMAS * std_unbiased / self.n_bins
```
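The two lines added here make the uniform 8-bit path robust to half-precision inputs: the binning arithmetic now runs in float32, and the explicit copy avoids mutating the caller's tensor unless `allow_inplace` is set. Below is a simplified, self-contained sketch of the scheme for intuition only, not the library's exact kernel; `n_bins = 256` and `RANGE_IN_SIGMAS = 6` are assumed values:

```python
import math
import torch

n_bins, RANGE_IN_SIGMAS = 256, 6                  # assumed constants
tensor = torch.randn(1000, dtype=torch.bfloat16)  # bf16 input, now cast to float32 first

offset = n_bins // 2
shift = tensor.mean()
tensor32 = tensor.to(dtype=torch.float32)         # the copy added in this hunk
centered = tensor32 - shift
std_unbiased = centered.norm() / math.sqrt(centered.numel() - 1)
scale = RANGE_IN_SIGMAS * std_unbiased / n_bins

# Map each value to one of 256 bins around the mean; tails beyond ~3 sigma get clamped.
quantized = torch.clamp(torch.round(centered / scale) + offset, 0, n_bins - 1).to(torch.uint8)
dequantized = (quantized.to(torch.float32) - offset) * scale + shift
print((dequantized - tensor32).abs().max())  # roughly on the order of `scale`, larger for clamped outliers
```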
```diff
@@ -135,14 +138,15 @@ def quantize(
         except ImportError:
             raise ImportError(BNB_MISSING_MESSAGE)

-        quantized, (absmax, codebook) = quantize_blockwise(tensor)
+        quantized, (absmax, codebook, *extra_params) = quantize_blockwise(tensor, blocksize=4096, nested=False)
+        assert tuple(extra_params) == (4096, False, tensor.dtype, None, None)  # blocksize, nested, dtype, offset, s2
```
Review comment: Maybe we can make that tuple on the right a module-level constant? It's used twice in the code, so it's better to make it clear we're using some predefined values.

Reply: done, thanks for the suggestion.
```diff
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        requires_grad = tensor.requires_grad
+        tensor = tensor.detach()
         dtype_name = str(tensor.dtype).replace("torch.", "")
         if tensor.dtype == torch.bfloat16:
             tensor = tensor.to(torch.float32)

         quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)
```
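Following up on the review thread above, the repeated quant-state tuple could be pinned as a module-level constant shared by `quantize()` and `extract()`. A hypothetical sketch of that refactoring; the name, and the `torch.float32` entry (valid once `compress()` has already cast the tensor), are assumptions based on this diff rather than the final committed code:

```python
import torch

# Quant-state tail as produced by this bitsandbytes version's quantize_blockwise
# and expected back by dequantize_blockwise: (blocksize, nested, dtype, offset, state2).
# torch.float32 assumes the tensor was already cast, as compress() above guarantees.
BLOCKWISE_QUANT_STATE = (4096, False, torch.float32, None, None)

# quantize():  assert tuple(extra_params) == BLOCKWISE_QUANT_STATE
# extract():   dequantize_blockwise(quantized, (absmax, codebook, *BLOCKWISE_QUANT_STATE))
```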
```diff
@@ -157,7 +161,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
         return runtime_pb2.Tensor(
             buffer=b"".join(serialized_data),
             size=tensor.shape,
-            requires_grad=tensor.requires_grad,
+            requires_grad=requires_grad,
             dtype=dtype_name,
             compression=self.compression_type,
         )
```
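For context on why `requires_grad` is captured before `detach()` in the hunk above: the detached tensor always reports `requires_grad=False`, so reading the flag after detaching (as the old `requires_grad=tensor.requires_grad` effectively did) would record `False` for every tensor. A two-line illustration:

```python
import torch

t = torch.randn(3, requires_grad=True)

requires_grad = t.requires_grad  # capture the flag first...
t = t.detach()                   # ...because the detached tensor always reports False
print(t.requires_grad, requires_grad)  # False True
```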
```diff
@@ -181,6 +185,5 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         absmax = torch.as_tensor(absmax)
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
-        result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
-        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
-        return result
+        result = dequantize_blockwise(quantized, (absmax, codebook, 4096, False, torch.float32, None, None))
+        return result.to(getattr(torch, serialized_tensor.dtype)).requires_grad_(serialized_tensor.requires_grad)
```
Review comment: Is there a reason why we should fail with an error in case of bf16 inputs? It is indeed not sensible, but if the user wants to do so, it's probably better to issue a warning instead of flat-out refusing to pass it through quantization.

Reply: added a ValueError with a more user-legible reason.
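Taken together, these changes mean the dtype (including bfloat16) and the `requires_grad` flag should survive a compress/extract round trip. A test-style sketch of that property; the import paths and the `CompressionInfo.from_tensor(...)` call are assumptions about the surrounding hivemind API and may need adjusting:

```python
import torch
from hivemind.compression.quantization import Uniform8BitQuantization  # path assumed
from hivemind.compression import CompressionInfo                        # path assumed

quantizer = Uniform8BitQuantization()
tensor = torch.randn(64, 8, dtype=torch.bfloat16, requires_grad=True)

# Exact CompressionInfo constructor is an assumption; any valid info object should do.
serialized = quantizer.compress(tensor, CompressionInfo.from_tensor(tensor, key="demo"), allow_inplace=False)
restored = quantizer.extract(serialized)

assert restored.shape == tensor.shape
assert restored.dtype == torch.bfloat16               # preserved via the explicit dtype name
assert restored.requires_grad == tensor.requires_grad
```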