What should .dtype for tensor subclass return? #442

Open
gau-nernst opened this issue Jun 26, 2024 · 15 comments

@gau-nernst
Collaborator

What is the recommended way to show the dtype that the tensor appears to be, i.e. when calling subclass_tensor.dtype?

I see that the current AffineQuantizedTensor and NF4Tensor will show the original dtype. I understand that this helps with compatibility for existing code (e.g. in gpt-fast, KVCache dtype is taken from weight dtype)

dtype = self.output.weight.dtype

However, I personally feel that this is a bit unintuitive, because the weight is not actually FP32/BF16 anymore (it only appears to be, for compatibility reasons I suppose).
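For concreteness, here is a minimal sketch (not torchao's actual implementation) of how a wrapper subclass ends up reporting the original dtype: the outer tensor is constructed with the original floating-point dtype while the real payload lives in an int8 attribute.

import torch

class Int8WeightTensor(torch.Tensor):
    """Hypothetical int8 weight-only subclass, only to illustrate the dtype question."""

    @staticmethod
    def __new__(cls, int_data, scale, orig_dtype):
        # The wrapper advertises the original floating-point dtype ...
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=orig_dtype, device=int_data.device
        )

    def __init__(self, int_data, scale, orig_dtype):
        # ... while the actual storage is int8 plus a scale.
        self.int_data = int_data
        self.scale = scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError(f"{func} is not implemented in this sketch")

    @classmethod
    def from_float(cls, w):
        scale = w.abs().amax(dim=-1, keepdim=True) / 127
        int_data = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return cls(int_data, scale, w.dtype)

qw = Int8WeightTensor.from_float(torch.randn(32, 64, dtype=torch.bfloat16))
print(qw.dtype)           # torch.bfloat16 -- the "pretend" (original) dtype
print(qw.int_data.dtype)  # torch.int8     -- what is actually stored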

@msaroufim also mentions that

This is unfortunately a big limitation with subclasses mostly because of limitations with autograd that are very difficult to get rid of

@jerryzh168

@jerryzh168
Contributor

jerryzh168 commented Jun 26, 2024

yeah good question, so as @msaroufim mentioned, the reason we return the original (floating point) dtype is to make the autograd engine happy; the autograd engine today only works with floating point dtypes. The NF4 tensor needs to work with autograd, I think, so that's why it has to return a float dtype.

To really support this properly, we'd need a generalized torch.dtype (e.g. "nf4", "affine_quantized" (with some extra information)) and basically enable open registration of torch.dtype. I remember the blocker, according to @ezyang, is that the change would touch a lot of PyTorch C++ core code and it's not high priority enough for people to work on it. I'm not exactly sure about all the options we have here though. @cpuhrsch has also been pushing for nf4_tensor.dtype to return something like "nf4" instead of a floating point dtype. In summary, I think there are two questions if we want to make this work:

  1. How do we add a new torch.dtype? Does this have to touch C++ or can it stay in Python?
  2. How does the autograd engine work with new torch.dtypes, and how do other subsystems (export, dispatch, etc.) work with them?

Maybe @ezyang or @albanD can talk about the work that might be involved here, and share thoughts on the benefits of enabling this?

@albanD

albanD commented Jun 27, 2024

The main dilemma in my head for this is that we don't have a single thing we expect from this dtype field:

  • We want it to sometimes show a pretend dtype, for example for autograd, but surely also for other Module methods (like .to(bfloat16)) or any other subsystem (autocast, optimizer, distributed) that expects a consistent dtype and would otherwise wrongfully complain here.
  • We want it to show the dtype used to store the data, as a way to differentiate it from other ways of storing the data.

I don't think there is a right or wrong answer here, only a tradeoff depending on what you want to do. But that makes the large amount of work needed for custom dtypes a lot less attractive, unfortunately.

Maybe another option here could be to introduce a new concept of storage type, which would represent how the data is stored for a particular Tensor.
You can then query the dtype (precision semantic of each element) or the stype (how this Tensor is represented in memory) depending on what your intent is?
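As a rough sketch of what that split could look like on a wrapper subclass (stype here is purely hypothetical, not an existing torch API):

import torch

class NF4LikeTensor(torch.Tensor):
    """Toy NF4-style wrapper, only to illustrate the dtype/stype split."""

    @staticmethod
    def __new__(cls, packed_uint8, shape, dtype):
        # dtype: the precision semantics of each element (what autograd & co. see)
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=dtype, device=packed_uint8.device
        )

    def __init__(self, packed_uint8, shape, dtype):
        self.packed = packed_uint8  # 4-bit codes packed two per byte

    @property
    def stype(self):
        # stype: how this Tensor is actually represented in memory
        return "nf4 (uint8-packed 4-bit codes + per-block scales)"

t = NF4LikeTensor(torch.zeros(32 * 32 // 2, dtype=torch.uint8), (32, 32), torch.bfloat16)
print(t.dtype)  # torch.bfloat16
print(t.stype)  # nf4 (uint8-packed 4-bit codes + per-block scales)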

@msaroufim
Member

msaroufim commented Jun 27, 2024

You can then query the dtype (precision semantic of each element) or the stype (how this Tensor is represented in memory) depending on what your intent is?

This sounds pretty reasonable to me. Separating out the concepts of stype and dtype seems like an important distinction to start explaining to end users, because we don't want to abstract away the concept of bitpacking completely. Stypes would be something that changes rarely because they rely on hardware support, but dtypes are something that's changing quite rapidly. For example, with this work we're supporting intX and fpX.

I also had a recent example here where I'm quantizing to int8 weight-only, and just inspecting the state dict is not enough for me to know that the quantization happened correctly, since at no point does int8 show up when I print(m) or print(m.state_dict()).

Granted, we can work around this by having a better __repr__ for AffineQuantizedTensor (a rough sketch of that idea follows the output below).

import torch 
import torchao
import copy
import tempfile

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = torchao.quantize(m, torchao.quantization.quant_api.int8_weight_only())

print(m.state_dict())
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

m_copy.load_state_dict(state_dict, assign=True)

print(m_copy.state_dict())

res = m_copy(*example_inputs)
(new) [marksaroufim@devgpu003.cco3 ~/test]$ python aa.py 
OrderedDict([('linear1.weight', AffineQuantizedTensor(data=tensor([[ 0.0815, -0.0544,  0.0952,  ..., -0.1216,  0.0320,  0.0623],
        [ 0.0408, -0.0776,  0.0302,  ...,  0.0564,  0.1226, -0.0505],
        [ 0.0693,  0.1226,  0.0029,  ..., -0.1040,  0.0811,  0.0212],
        ...,
        [ 0.1157,  0.0233, -0.0339,  ..., -0.0991,  0.1147,  0.0933],
        [-0.0894,  0.0391, -0.0266,  ...,  0.0238,  0.0161,  0.0894],
        [-0.1230,  0.0039, -0.0378,  ...,  0.0019, -0.1240,  0.0232]],
       dtype=torch.bfloat16), shape=torch.Size([32, 64]), device=cpu, dtype=torch.bfloat16, requires_grad=False)), ('linear2.weight', AffineQuantizedTensor(data=tensor([[-0.1445, -0.0103,  0.0593,  ..., -0.1484,  0.0708,  0.0078],
        [ 0.0454, -0.1396,  0.1455,  ...,  0.1621,  0.0344, -0.0757],
        [ 0.0776,  0.1416,  0.0845,  ...,  0.0055, -0.1475, -0.1064],
        ...,
        [-0.0591,  0.1235,  0.1729,  ..., -0.1445, -0.1484, -0.0796],
        [-0.1250, -0.0693, -0.0068,  ...,  0.0732, -0.0354,  0.1553],
        [ 0.1025, -0.1299, -0.0432,  ..., -0.0297, -0.1367, -0.1377]],
       dtype=torch.bfloat16), shape=torch.Size([64, 32]), device=cpu, dtype=torch.bfloat16, requires_grad=False))])
/home/marksaroufim/test/aa.py:31: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(f)
OrderedDict([('linear1.weight', AffineQuantizedTensor(data=tensor([[ 0.0815, -0.0544,  0.0952,  ..., -0.1216,  0.0320,  0.0623],
        [ 0.0408, -0.0776,  0.0302,  ...,  0.0564,  0.1226, -0.0505],
        [ 0.0693,  0.1226,  0.0029,  ..., -0.1040,  0.0811,  0.0212],
        ...,
        [ 0.1157,  0.0233, -0.0339,  ..., -0.0991,  0.1147,  0.0933],
        [-0.0894,  0.0391, -0.0266,  ...,  0.0238,  0.0161,  0.0894],
        [-0.1230,  0.0039, -0.0378,  ...,  0.0019, -0.1240,  0.0232]],
       dtype=torch.bfloat16), shape=torch.Size([32, 64]), device=cpu, dtype=torch.bfloat16, requires_grad=False)), ('linear2.weight', AffineQuantizedTensor(data=tensor([[-0.1445, -0.0103,  0.0593,  ..., -0.1484,  0.0708,  0.0078],
        [ 0.0454, -0.1396,  0.1455,  ...,  0.1621,  0.0344, -0.0757],
        [ 0.0776,  0.1416,  0.0845,  ...,  0.0055, -0.1475, -0.1064],
        ...,
        [-0.0591,  0.1235,  0.1729,  ..., -0.1445, -0.1484, -0.0796],
        [-0.1250, -0.0693, -0.0068,  ...,  0.0732, -0.0354,  0.1553],
        [ 0.1025, -0.1299, -0.0432,  ..., -0.0297, -0.1367, -0.1377]],
       dtype=torch.bfloat16), shape=torch.Size([64, 32]), device=cpu, dtype=torch.bfloat16, requires_grad=False))])
(new) [marksaroufim@devgpu003.cco3 ~/test]$ 
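Coming back to the __repr__ idea above, a rough sketch of the kind of information such an override could surface (the attribute names int_data and scale are illustrative, not AffineQuantizedTensor's actual internals):

def quantized_repr(self) -> str:
    # Hypothetical __repr__ for a quantized wrapper subclass: show the storage dtype
    # so that "did quantization actually happen?" is visible at a glance.
    return (
        f"{self.__class__.__name__}(shape={tuple(self.shape)}, "
        f"storage_dtype={self.int_data.dtype}, "   # e.g. torch.int8
        f"scale_dtype={self.scale.dtype}, "
        f"pretend_dtype={self.dtype}, "            # e.g. torch.bfloat16
        f"device={self.device})"
    )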

@gau-nernst
Collaborator Author

I think the dtype and stype concepts will suffice for INT8, but they are not quite satisfactory for sub-byte dtypes. Pretty much all sub-byte dtypes will use "untyped" uint8 as storage. Even if there is hardware support in the future, I think it will come in the form of specialized instructions instead of storage (e.g. matmul for uint4x2) - correct me if I'm wrong.
So introducing the stype concept doesn't seem to solve many problems. Taking the simplest example, there should be an intuitive way to tell that an INT4 subclass is int4, but behaves like float32/bfloat16, and is stored as uint8 (in that case, is the storage type really important anymore?).

At least for now, is it correct to say that .dtype should return the type that the tensor should behave like, for integration with existing systems?

@ezyang

ezyang commented Jun 28, 2024

@gau-nernst That's right. It is kind of unintuitive, but it's what integrates best with how the autograd system works right now. In fact, you are best off picking the dtype that you want the gradients to be.

cc @albanD @vkuzo

@vayuda
Collaborator

vayuda commented Jun 28, 2024

My implementation abstracts away the details of whether the Intx Tensor is int2, 3, 4, etc., because from torch's perspective, that number is only important for internal bit packing. As long as the user knows what tensor it is (I print it out in the repr), I think it's fine that dtype returns the container dtype.

@jerryzh168
Contributor

jerryzh168 commented Jun 28, 2024

You can then query the dtype (precision semantic of each element) or the stype (how this Tensor is represented in memory) depending on what your intent is?

I think the main question I have for stype is that it does not capture all the meaning a "dtype" is trying to convey, as mentioned by @gau-nernst.

I have a slightly different proposal, for the use cases we have:

  1. to(dtype)
    This should just be a real (generalized) dtype that has all the information, e.g. "affine_quantized(dtype=uint4, quant_min=.., quant_max=...)", "nf4", "mx", "fp6", "uint3", "int2".

  2. autograd that requires semantic dtype
    this can query tensor.semantic_dtype, or it could even just be tensor.corresponding_floating_point_dtype

Would this work?
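A rough sketch of the shape of this proposal (all names are hypothetical, nothing here exists in torch today): a rich generalized dtype object for conversion/inspection, plus a plain floating-point dtype for the autograd engine.

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class GeneralizedDType:
    # 1. the "real" dtype carrying all the information (used by to(dtype), repr, ...)
    name: str                                       # "nf4", "affine_quantized", "mx", "fp6", ...
    container_dtype: Optional[torch.dtype] = None   # e.g. torch.uint8 for packed formats
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None

class QuantizedTensorInterface:
    """Interface sketch only (not a torch.Tensor subclass)."""
    generalized_dtype: GeneralizedDType      # what .dtype "should really" describe
    semantic_dtype: torch.dtype              # 2. what autograd treats the tensor as

print(GeneralizedDType("nf4", container_dtype=torch.uint8))
print(GeneralizedDType("affine_quantized", container_dtype=torch.uint8,
                       quant_min=0, quant_max=15))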

@gau-nernst
Collaborator Author

I like what @jerryzh168 proposed. But I think if we are to make changes in PyTorch core, we can afford to be even more specific. Even phrases like "semantic dtype" or "pretend dtype" are somewhat ambiguous. At least for now, we know it is used by autograd, so I think we can use something like Tensor.autograd_dtype or Tensor.grad_dtype.

For the other PyTorch subsystems @albanD mentioned, like autocast, optimizer, and distributed, we should investigate how each interacts with the tensor subclass dtype. For the optimizer, I'm experimenting with implementing 8-bit Adam using a tensor subclass, and it seems like it doesn't care about dtype, as it only involves math calculations (at least for single-GPU training, no AMP). There are some checks for complex and sparse tensors, but for now that shouldn't be a concern, I think.

For Tensor.to(dtype) and Module.to(dtype), I don't think we should support converting to a tensor subclass directly (e.g. fp32_tensor.to("nf4")). Users should call an explicit method to convert to the tensor subclass (e.g. NF4Tensor.from_float() or the quantize(module, int4_weight_only()) API). This would avoid having to register custom dtypes with PyTorch's Tensor and nn.Module globally, as well as avoid strange combinations like converting from one tensor subclass to another.

For converting from a tensor subclass to a standard dtype via .to(dtype), it can already be done by implementing aten.to() for the tensor subclass (which basically calls .dequantize()) -> no changes needed to PyTorch core. E.g. the existing NF4 implementation:

@implements([torch.ops.aten.to.dtype])
def to_dtype(func, *args, **kwargs):
    if not args[0][0].is_contiguous():
        assert args[0][0].t().is_contiguous()
        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
    return args[0][0].get_original_weight().to(args[0][1])

But I'm not entirely sure if that is sufficient for nn.Module.to(dtype) and won't cause other problems (need to check).
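As a small usage sketch of that direction (assuming torchao's to_nf4 constructor; exact behavior may vary across versions):

import torch
from torchao.dtypes.nf4tensor import to_nf4

w = torch.randn(64, 64, dtype=torch.bfloat16)
nf4_w = to_nf4(w)                          # explicit conversion into the subclass
print(type(nf4_w).__name__, nf4_w.dtype)   # NF4Tensor torch.bfloat16 (the pretend dtype)

# Per the aten.to override quoted above, converting back to a standard dtype
# should dequantize into a plain tensor, with no changes needed in PyTorch core.
back = nf4_w.to(torch.bfloat16)
print(type(back).__name__, back.dtype)     # expected: Tensor torch.bfloat16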

@albanD

albanD commented Jul 11, 2024

We could add such a field to Tensor, but I'm not sure it's going to help much for a few reasons:

  • Using it for non-leafs (activations) is going to be very tricky, as you will have to access all of these Tensors (this is a challenge: not all of them exist on the Python side, depending on the op in question). Also, making this field retroactive (on a graph that is already created) is going to be a technical challenge.
  • Using it on leafs would work fine, but that will only help for the final .grad field value. So all the computation will be done in the original dtype and only the final result will be downcasted. You can get the same behavior as post-processing after autograd is done by having a Tensor.low_dtype_grad that you populate by hand (see the sketch after this list).
  • Even if you manage to get all intermediates and have all the intermediates be of the dtype you want, we would need to fix the following for it to work:
    • Your method to capture activations will need to compose with all other features you use (all the hooks, reparametrization, AC, Distributed, compile are the highest risk I would guess)
    • The current autograd formulas that assume same tensor/grad dtype will need fixing and extensive testing
    • You will have to roll out your own Optimizer to handle these grads
    • You will have to roll out your own Distributed constructs to handle these mixed grads (FSDP/DDP won't work out of the box for sure, and FSDP1 won't ever work by design)
    • You won't be able to use other mixed dtype systems like Autocast
    • You won't be able to use the standard Module.to() and co. APIs.
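A rough sketch of the "populate by hand" post-processing from the second bullet (low_dtype_grad is not a real Tensor attribute, it is attached ad hoc here; the float8 cast assumes a PyTorch build with float8 dtypes):

import torch

w = torch.randn(16, 16, dtype=torch.bfloat16, requires_grad=True)
x = torch.randn(4, 16, dtype=torch.bfloat16)

loss = (x @ w.t()).sum()
loss.backward()                 # autograd runs entirely in bfloat16

# Downcast only the final result, by hand, after autograd is done.
w.low_dtype_grad = w.grad.to(torch.float8_e4m3fn)
print(w.grad.dtype, w.low_dtype_grad.dtype)   # torch.bfloat16 torch.float8_e4m3fn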

From there, I always come back to a simpler solution: have a subclass with a pretend dtype and its cls defining the behavior; then you can "hide" from the systems that can "just work" and change the behavior of the ones you want via the subclass overrides that all these systems have (or that we're building in FSDP2, for example).

I'm not sure what the best way forward is here though, curious what you think!

@gau-nernst
Collaborator Author

I can't comment much on the implementation and internal sides of PyTorch, as I'm not qualified for it and not knowledgeable about it. I just want to re-emphasize that from a user's perspective, having quant_subclass.dtype return a "pretend dtype" is unintuitive. I understand the dtype system was built without this in mind, and changing it would be an enormous task that probably no benefit can justify. But if the PyTorch team wants to continue pushing for exotic dtype implementations via tensor subclasses, some rethinking and work probably needs to be done.

On the implementation side, just a noob question: can't we introduce a concept of "legacy dtype" that basically means what dtype is today (always a native PyTorch dtype; the name is up for discussion) and a "(new) dtype" which subclasses can override to provide some useful metadata for subclass consumers (can't find a better word than "new", but you get the idea)? Then there are 2 options:

  1. Keep "legacy dtype" as .dtype, and make the "new dtype" another field. If the "new dtype" field is not implemented, it falls back to "legacy dtype". Subclass consumers should use the new field to inquire about dtype information (this needs to be communicated to people). No changes in PyTorch core / C++ backend. No breaking changes. Everyone is happy. .dtype is still confusing/unintuitive.
  2. Move "legacy dtype" to another field, and make the "new dtype" .dtype. If the "legacy dtype" field is not implemented, it falls back to the "new dtype". Subclass devs, if they override .dtype (to provide a "new dtype" with custom dtype metadata), must also implement the "legacy dtype" field. (Probably) only (some) renames in PyTorch core (likely underestimating the complexity here). Functions with dtype as an argument (e.g. .to(dtype=...)) should only accept "legacy dtype" (which also opens up the question of whether the argument name dtype should be changed in this case). Old code should still work in newer versions of PyTorch. .dtype now reflects either a native PyTorch dtype or the subclass's exotic dtype.

Writing it out like this, option 1 seems more attractive as it is less intrusive. Hope to hear opinions from the PyTorch team.
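A rough sketch of what option 1 could look like from the consumer side (the field name _detailed_dtype is purely hypothetical):

import torch

def detailed_dtype(t: torch.Tensor):
    # Consumers query this instead of .dtype; it falls back to the legacy dtype.
    return getattr(t, "_detailed_dtype", t.dtype)

w = torch.randn(8, 8, dtype=torch.bfloat16)
print(detailed_dtype(w))                  # torch.bfloat16 (fallback to the legacy dtype)

qw = w.clone()                            # stand-in for a quantized subclass instance
qw._detailed_dtype = "affine_quantized(int4, group_size=32)"   # what a subclass could expose
print(qw.dtype, "|", detailed_dtype(qw))  # torch.bfloat16 | affine_quantized(int4, group_size=32)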

@vkuzo
Contributor

vkuzo commented Jul 12, 2024

The main dilemma in my head for this is that we don't have a single thing we expect from this dtype field:

Agreed. @gau-nernst, what do you expect to see when you call YourTensor.dtype? A string such as "nf4"? Dtype of the raw stored elements with the corresponding metadata? Something else? Is there something here that makes sense to return which would be consistent across all of the possible tensor subclasses that people can build?

@gau-nernst
Collaborator Author

@vkuzo Having NF4Tensor.dtype return bfloat16 is unintuitive, and I hope many people can agree with me on that. I would expect NF4Tensor.dtype to return something about NF4. Whether it is a string or some Python object, I don't have a specific thing in mind. @jerryzh168 previously gave some suggestions.

"Dtype of the raw stored elements" is not helpful, as I pointed out previously, that most custom dtype impl will just use uint8 as storage (correct me if I misunderstand your example here). For tensor subclasses that only act as containers, dtype of stored elements make sense.

Is there something here that makes sense to return which would be consistent across all of the possible tensor subclasses that people can build?

Tbh I can't think of any at the moment. Tensor subclasses have varied applications other than custom dtypes, such as the torchvision subclasses and DTensor, so enforcing (or recommending) that all subclasses return something consistent also feels restrictive. People with more experience with tensor subclasses can chime in.

I guess having custom dtypes implemented as tensor subclasses will undoubtedly create conflicts with .dtype, and there will always be a difference between a "tensor subclass dtype" and a "native PyTorch dtype".

@jerryzh168
Contributor

jerryzh168 commented Jul 15, 2024

We could add such a field to Tensor, but I'm not sure it's going to help much for a few reasons:

  • Using it for non-leafs (activations) is going to be very tricky, as you will have to access all of these Tensors (this is a challenge: not all of them exist on the Python side, depending on the op in question). Also, making this field retroactive (on a graph that is already created) is going to be a technical challenge.
  • Using it on leafs would work fine, but that will only help for the final .grad field value. So all the computation will be done in the original dtype and only the final result will be downcasted. You can get the same behavior as post-processing after autograd is done by having a Tensor.low_dtype_grad that you populate by hand.

Are these referring to autograd? I feel for autograd we may not necessarily want the same dtype as the forward dtype, e.g. if I'm using int4 for the forward, I may still want higher precision like bfloat16 for the gradient, but I feel it would be better for that to be an explicit decision from the user instead of not being transparent.

  • Even if you manage to get all intermediates and have all the intermediates be of the dtype you want, we would need to fix the following for it to work:

    • Your method to capture activations will need to compose with all other features you use (all the hooks, reparametrization, AC, Distributed, compile are the highest risk I would guess)
    • The current autograd formulas that assume same tensor/grad dtype will need fixing and extensive testing
    • You will have to roll out your own Optimizer to handle these grads
    • You will have to roll out your own Distributed constructs to handle these mixed grads (FSDP/DDP won't work out of the box for sure, and FSDP1 won't ever work by design)
    • You won't be able to use other mixed dtype systems like Autocast
    • You won't be able to use the standard Module.to() and co. APIs.

I feel these are a bit orthogonal to the problem we are discussing here, actually: even if we have a generalized dtype that returns "nf4" or "affine_quantized(dtype=..., ...)", we still need to make sure our new tensor subclass/dtype works with all the subsystems we want to use it in. "Hiding" from the systems that can "just work", at the cost of not being transparent, doesn't feel like a good trade-off. But I understand it might be the lowest-effort solution we can have today. I think maybe we can wait until this issue becomes a blocker for a real use case and then think of proper solutions. But I'd also like to understand the cost if we want to just support this properly (allowing a generalized dtype definition in core).

@nfrumkin

nfrumkin commented Aug 6, 2024

Is there any way we can inspect the dtype that is mimicked at this moment? I am trying to provide checks in my code to ensure we are quantizing properly, however, even the repr() function does not give us important information such as integer type, number of bits, etc:

linear_1.weight.__repr__()
"AffineQuantizedTensor(data=tensor([[-0.0007,  0.0016,  0.0019,  ..., -0.0188,  0.0145, -0.0112],\n        [ 0.0005,  0.0017,  0.0000,  ..., -0.0111,  0.0097, -0.0035],\n        [-0.0012,  0.0012,  0.0022,  ..., -0.0157, -0.0065,  0.0033],\n        ...,\n        [-0.0021, -0.0034, -0.0015,  ..., -0.0108,  0.0099,  0.0095],\n        [-0.0030, -0.0025,  0.0012,  ...,  0.0097,  0.0012,  0.0046],\n        [-0.0014, -0.0022, -0.0003,  ...,  0.0088,  0.0088,  0.0188]],\n       device='cuda:0'), shape=torch.Size([1280, 320]), device=cuda:0, dtype=torch.float32, requires_grad=True)"

@jerryzh168
Contributor

Is there any way we can inspect the dtype that is mimicked at this moment? I am trying to provide checks in my code to ensure we are quantizing properly, however, even the repr() function does not give us important information such as integer type, number of bits, etc:

linear_1.weight.__repr__()
"AffineQuantizedTensor(data=tensor([[-0.0007,  0.0016,  0.0019,  ..., -0.0188,  0.0145, -0.0112],\n        [ 0.0005,  0.0017,  0.0000,  ..., -0.0111,  0.0097, -0.0035],\n        [-0.0012,  0.0012,  0.0022,  ..., -0.0157, -0.0065,  0.0033],\n        ...,\n        [-0.0021, -0.0034, -0.0015,  ..., -0.0108,  0.0099,  0.0095],\n        [-0.0030, -0.0025,  0.0012,  ...,  0.0097,  0.0012,  0.0046],\n        [-0.0014, -0.0022, -0.0003,  ...,  0.0088,  0.0088,  0.0188]],\n       device='cuda:0'), shape=torch.Size([1280, 320]), device=cuda:0, dtype=torch.float32, requires_grad=True)"

Since AffineQuantizedTensor has to support different packing formats, we now use LayoutType to distinguish between them, and an internal dtype may not always make sense for all layout types. The simplest layout type is PlainLayoutType, which stores int_data directly, so if isinstance(tensor.layout_type, PlainLayoutType), then tensor.layout_tensor.int_data.dtype will give you the internal dtype. Or you could just print linear_1.weight.layout_tensor.__repr__().

The number of bits is represented by quantized_tensor.quant_min and quantized_tensor.quant_max, but in the future we do want to use the native uint1 to uint7 dtypes in AffineQuantizedTensor.
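Putting that together, a small sanity-check helper could look something like this (a hedged sketch: attributes like layout_tensor.int_data depend on the layout type and may differ across torchao versions):

import torch
from torchao.dtypes import AffineQuantizedTensor

def describe_weight(w: torch.Tensor) -> str:
    if isinstance(w, AffineQuantizedTensor):
        # For PlainLayoutType, layout_tensor.int_data holds the quantized values directly.
        storage_dtype = w.layout_tensor.int_data.dtype
        return (f"AffineQuantizedTensor(storage={storage_dtype}, "
                f"quant_range=({w.quant_min}, {w.quant_max}), pretend_dtype={w.dtype})")
    return f"plain tensor(dtype={w.dtype})"

# e.g. with the quantized model from earlier in the thread:
# print(describe_weight(m.linear1.weight))
# -> something like "AffineQuantizedTensor(storage=torch.int8, quant_range=(-128, 127), pretend_dtype=torch.bfloat16)"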
