What should `.dtype` for tensor subclass return? #442
Yeah, good question. As @msaroufim mentioned, the reason we are returning the original (floating point) dtype is to make the autograd engine happy: the autograd engine today only works with floating point dtypes. NF4 tensor needs to work with autograd, I think, so that's why it has to return a float dtype. To really support this properly, we'd need a generalized torch.dtype (e.g. "nf4", "affine_quantized (with some extra information)") and basically enable open registration of torch.dtype. As I remember, the blocker according to @ezyang is that the change would touch a lot of PyTorch C++ core code and it's not high priority enough for people to work on. I'm not exactly sure about all the options we have here though. @cpuhrsch has also been pushing for having nf4_tensor.dtype return something like "nf4" instead of a floating point dtype. In summary, I think there are two things we need to make this work:

Maybe @ezyang or @albanD can talk about the work that might be involved here, and your thoughts on the benefits of enabling this?
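To make the tradeoff concrete, here is a minimal, hypothetical sketch (not the NF4/torchao implementation) of a wrapper subclass that stores uint8 data but advertises a floating-point dtype so that autograd and other systems accept it. The class and attribute names are made up for illustration.

```python
import torch

class MySubByteTensor(torch.Tensor):
    # Hypothetical example: data is stored as uint8, but .dtype reports a
    # floating-point "pretend" dtype because autograd only tracks float tensors.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, packed: torch.Tensor, orig_dtype=torch.bfloat16):
        # The wrapper's metadata (shape, dtype, device) is what the rest of
        # PyTorch sees; orig_dtype is the advertised floating-point dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls, packed.shape, dtype=orig_dtype, device=packed.device
        )

    def __init__(self, packed: torch.Tensor, orig_dtype=torch.bfloat16):
        self.packed = packed  # raw storage, e.g. uint8

    def __repr__(self):
        return f"MySubByteTensor(storage={self.packed.dtype}, dtype={self.dtype})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # A real implementation would dequantize / dispatch per op here.
        raise NotImplementedError(f"{func} is not supported in this sketch")

t = MySubByteTensor(torch.randint(0, 256, (4, 4), dtype=torch.uint8))
print(t)               # MySubByteTensor(storage=torch.uint8, dtype=torch.bfloat16)
print(t.dtype)         # torch.bfloat16 -- the "pretend" dtype autograd sees
print(t.packed.dtype)  # torch.uint8    -- how the data is actually stored
```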
The main dilemma in my head for this is that we don't have a single thing we expect from this dtype field:
I don't think there is a right or wrong answer here, only a tradeoff depending on what you want to use. But that makes the large amount of work for custom dtypes a lot less attractive, unfortunately. Maybe another option here could be to introduce a new concept of storage type, which would represent how the data is stored for a particular Tensor.
This sounds pretty reasonable to me. Separating out the concepts of stype and dtype seems like an important distinction to start explaining to end users, because we don't want to abstract away the concept of bit packing completely. Stypes would be something that changes rarely, because they rely on hardware support, while dtypes are changing quite rapidly. For example, with this work we're supporting intX and fpX.
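As a rough illustration of the separation being discussed (none of this exists in PyTorch; `stype` and the container class are hypothetical names): `.dtype` stays the semantic dtype that autograd and downstream code consume, while a separate field describes how the bits are physically stored.

```python
import torch

class QuantizedWeight:
    # Hypothetical container, just to make the dtype vs. stype split concrete.
    def __init__(self, packed: torch.Tensor, dtype: torch.dtype, stype: str):
        self.packed = packed  # physical storage, typically untyped uint8
        self.dtype = dtype    # semantic / "pretend" dtype, e.g. torch.bfloat16
        self.stype = stype    # storage description, e.g. "int4x2", "nf4", "int8"

w = QuantizedWeight(
    torch.empty(64, 16, dtype=torch.uint8),  # e.g. a 64x32 int4 weight packed 2-per-byte
    dtype=torch.bfloat16,
    stype="int4x2",
)
print(w.dtype, w.stype, w.packed.dtype)  # torch.bfloat16 int4x2 torch.uint8
```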
I also had a recent example here where I'm quantizing to int8 weight only, and just inspecting the state dict is not enough for me to know that the quantization happened correctly, since at no point does int8 show up when I print it. Granted, we can work around this by having a better `repr`:

```python
import torch
import torchao
import copy
import tempfile

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = torchao.quantize(m, torchao.quantization.quant_api.int8_weight_only())
print(m.state_dict())
ref = m(*example_inputs)

with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

m_copy.load_state_dict(state_dict, assign=True)
print(m_copy.state_dict())
res = m_copy(*example_inputs)
```
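Continuing the snippet above, one workaround available today is to check the subclass type rather than `.dtype`: the type name is currently the clearest signal that quantization actually happened, since `.dtype` still reports bfloat16 (the exact printed names depend on the torchao version, so treat this as a sketch).

```python
# Inspect the quantized model: the weight's Python type changes, its .dtype does not.
for name, value in m.state_dict().items():
    print(f"{name}: type={type(value).__name__}, dtype={value.dtype}")
```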
I think the dtype and stype concepts will suffice for INT8, but they are not quite satisfactory for sub-byte dtypes. Pretty much all sub-byte dtypes will use "untyped" uint8 as storage. Even if there is hardware support in the future, I think it will come in the form of specialized instructions instead of storage (e.g. matmul for uint4x2) - correct me if I'm wrong. At least for now, is it correct to say that the subclass should just keep reporting the original floating-point dtype?
@gau-nernst That's right. It is kind of unintuitive, but it's what integrates best with how the autograd system works right now. In fact, it's best to pick the dtype that you want the gradients to be in.
My implementation abstracts away the details of whether the IntX tensor is int2, int3, int4, etc., because from torch's perspective that number is only important for internal bit packing. As long as the user knows what tensor it is (I print it out in the repr), I think it's fine that dtype returns the container dtype.
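For intuition, here is a tiny, self-contained sketch of the kind of bit packing being referred to: two unsigned 4-bit values per uint8 byte. It ignores the signs, scales, and shapes that a real IntX implementation handles; the point is only that the storage dtype (uint8) says nothing about the logical bit width.

```python
import torch

def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor with values in [0, 15] and an even number of elements
    x = x.view(-1, 2)
    return (x[:, 0] | (x[:, 1] << 4)).to(torch.uint8)

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack([low, high], dim=-1).view(-1)

vals = torch.randint(0, 16, (8,), dtype=torch.uint8)
packed = pack_uint4(vals)
assert torch.equal(unpack_uint4(packed), vals)
print(vals.numel(), packed.numel(), packed.dtype)  # 8 4 torch.uint8
```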
I have a slightly different proposal for the use cases we have:
Would this work?
I like what @jerryzh168 proposed. But I think if we are to make changes in PyTorch core, we can afford to be even more specific. Even phrases like "semantic dtype" or "pretend dtype" are somewhat ambiguous. At least for now, we know it is used by autograd, so I think we can use something tied to autograd.

For other PyTorch subsystems, like autocast, optimizer, and distributed as @albanD mentioned, we should investigate how they interact with the tensor subclass dtype. For the optimizer, I'm experimenting with implementing 8-bit Adam using a tensor subclass, and it seems like it doesn't care about dtype, as it only involves math calculations (at least for single-GPU training, no AMP). There are some checks for complex and sparse tensors, but for now that shouldn't be a concern, I think.

For converting from a tensor subclass to a standard dtype, see ao/torchao/dtypes/nf4tensor.py, lines 284 to 289 in dee13e1. But I'm not entirely sure if that is sufficient.
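On converting back to a standard dtype: as a hedged sketch (not the nf4tensor.py code referenced above), a wrapper subclass can intercept `aten._to_copy` in `__torch_dispatch__` and hand back a plain, dequantized tensor. The class, the fake scale-based `dequantize`, and the attribute names are illustrative only.

```python
import torch

class FakeQuantTensor(torch.Tensor):
    # Hypothetical quantized tensor: int8 data + scale, pretend dtype bfloat16.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, int_data, scale, orig_dtype=torch.bfloat16):
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=orig_dtype, device=int_data.device
        )

    def __init__(self, int_data, scale, orig_dtype=torch.bfloat16):
        self.int_data = int_data
        self.scale = scale

    def dequantize(self):
        return self.int_data.to(self.dtype) * self.scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten._to_copy.default:
            # .to(dtype) on the subclass lands here; return a plain tensor.
            self = args[0]
            return self.dequantize().to(kwargs.get("dtype", self.dtype))
        raise NotImplementedError(f"{func} is not handled in this sketch")

q = FakeQuantTensor(torch.randint(-8, 8, (4,), dtype=torch.int8), scale=0.1)
print(q.to(torch.float32))  # plain float32 tensor, dequantized
```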
We could add such a field to Tensor, but I'm not sure it's going to help much for a few reasons:
From there, I always come back to a simpler solution: have a subclass with a pretend-dtype whose cls defines the behavior. Then you can "hide" from the systems that can "just work" and change the behavior of the ones you want via the subclass overrides that all these systems have (or that we're building in FSDP2, for example). I'm not sure what the best way forward here is, though; curious what you think!
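A hedged sketch of that direction (illustrative names, not torchao's implementation): keep a pretend dtype so systems that can "just work" are untouched, and override only the ops you care about, here `F.linear` via `__torch_function__`.

```python
import torch
import torch.nn.functional as F

class OverrideLinearWeight(torch.Tensor):
    # Wrapper that behaves like a normal weight except for F.linear, which it
    # intercepts (a real quantized subclass would call its low-bit kernel here).
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self.unwrapped = data  # plain tensor standing in for packed/quantized data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            return F.linear(x, w.unwrapped, bias)
        # Everything else falls through with default behavior.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = OverrideLinearWeight(torch.randn(8, 4))
print(w.dtype)                               # torch.float32 -- the pretend dtype
print(F.linear(torch.randn(2, 4), w).shape)  # torch.Size([2, 8])
```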
I can't comment much on the implementation and internals of PyTorch, as I'm not qualified for it and not knowledgeable about it. I just want to re-emphasize that, from a user's perspective, having a more informative dtype matters.

On the implementation side, just a noob question: can't we introduce a concept of "legacy dtype" that basically means what dtype is today (always a native PyTorch dtype; the name is up for discussion), and a "(new) dtype" which subclasses can override to provide some useful metadata for subclass consumers (I can't find a better word for "new", but you get the idea)? And then there are 2 options:

Writing it out like this, option 1 seems more attractive as it is less intrusive. Hope to hear opinions from the PyTorch team.
Agreed. @gau-nernst, what do you expect to see when you call `.dtype`?
@vkuzo Having the "dtype of the raw stored elements" is not helpful; as I pointed out previously, most custom dtype implementations will just use untyped uint8 for storage.
Tbh I can't think of any at the moment. Tensor subclasses have varied applications other than custom dtypes, such as the torchvision subclasses and DTensor, so enforcing (or recommending) that all subclasses return something consistent also feels restrictive. People with more experience with tensor subclasses can chime in. I guess having custom dtypes implemented as tensor subclasses will undoubtedly create conflicts with the existing `dtype` field.
Are these referring to autograd? I feel that for autograd we may not necessarily want the same dtype as the forward dtype; e.g. if I'm using int4 for the forward, I may still want higher precision like bfloat16 for the gradient. But I feel it would be better for that to be an explicit decision from the user instead of something that isn't transparent.

I feel these are a bit orthogonal to the problem we are discussing here, actually. Even if we have a generalized dtype that returns "nf4", "affine_quantized(dtype=..., ...)", we still need to make sure our new tensor subclass/dtype works with all the subsystems we want to use it in.
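To make that explicit-decision point concrete, here is a small, hedged sketch (not an existing torchao API) of a custom autograd.Function where the forward consumes an int8 weight but the gradient flowing back is plain bfloat16; the int8-plus-scale "dequantization" stands in for a real quantized kernel.

```python
import torch

class Int8ForwardLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_int8, scale):
        ctx.save_for_backward(w_int8, scale)
        w = w_int8.to(torch.bfloat16) * scale   # dequantize for the matmul
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        w_int8, scale = ctx.saved_tensors
        w = w_int8.to(torch.bfloat16) * scale
        grad_x = grad_out @ w                    # bf16 gradient for the input
        return grad_x, None, None                # no grad for the int8 weight or scale

x = torch.randn(2, 4, dtype=torch.bfloat16, requires_grad=True)
w_int8 = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
scale = torch.tensor(0.05, dtype=torch.bfloat16)
y = Int8ForwardLinear.apply(x, w_int8, scale)
y.sum().backward()
print(x.grad.dtype)  # torch.bfloat16 -- higher precision than the int8 forward weight
```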
Is there any way we can inspect the dtype that is mimicked at this moment? I am trying to add checks in my code to ensure we are quantizing properly; however, even the repr() function does not give us important information such as the integer type, number of bits, etc.: `linear_1.weight.__repr__()`
Since the number of bits is represented by
What is the recommended way to show the dtype that the tensor appears to be, i.e. when calling `subclass_tensor.dtype`? I see that the current `AffineQuantizedTensor` and `NF4Tensor` will show the original dtype. I understand that this helps with compatibility for existing code (e.g. in gpt-fast, the KVCache dtype is taken from the weight dtype; see ao/torchao/_models/llama/model.py, line 122 in f172c47). However, I personally feel that it is a bit unintuitive, because the weight is actually not FP32/BF16 anymore (it only appears to be, for compatibility reasons I suppose).

@msaroufim also mentions that

@jerryzh168