Supporting tensor parallelism for int8 weight only quant #939

Merged: 14 commits, Sep 27, 2024
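For context, a minimal end-to-end sketch of the workflow this change enables: an int8 weight-only quantized linear layer sharded column-wise across GPUs via DTensor, mirroring the test added below. The torchrun launch, process-group setup, layer sizes, and dtype are illustrative assumptions, not part of this PR.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torchao.quantization import quantize_, int8_weight_only

# Assumed launch (not part of this PR): torchrun --nproc_per_node=2 tp_int8_sketch.py
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1024, 2048, bias=False)
    def forward(self, x):
        return self.linear(x)

m = M().to("cuda", torch.bfloat16)
quantize_(m, int8_weight_only())  # weight becomes an int8 AffineQuantizedTensor

# Column-wise shard: slice the quantized (out_features, in_features) weight along dim 0;
# this slicing is what the new aten.slice support on the quantized tensor makes possible
n_local_rows = m.linear.weight.size(0) // mesh.size()
local = m.linear.weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
m.linear.weight = torch.nn.Parameter(DTensor.from_local(local, mesh, [Shard(0)]), requires_grad=False)

# Replicate the activation and run the sharded, quantized linear
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
y = m(DTensor.from_local(x, mesh, [Replicate()]))
print(y.shape)  # global shape (16, 2048), sharded over the last dim
dist.destroy_process_group()

The rowwise variant in the new TorchAOTensorParallelTestCase works the same way, slicing the weight along dim 1 and wrapping it with Shard(1).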
1 change: 1 addition & 0 deletions test/dtypes/test_affine_quantized.py
@@ -135,5 +135,6 @@ def test_print_quantized_module(self, apply_quant):

common_utils.instantiate_parametrized_tests(TestAffineQuantized)


if __name__ == "__main__":
run_tests()
12 changes: 12 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -0,0 +1,12 @@
from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
from torch.testing._internal.common_utils import run_tests
from torchao.quantization import int8_weight_only

class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
pass


copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")

if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/__init__.py
@@ -33,11 +33,13 @@
quantize_,
)
from . import dtypes
from . import testing

__all__ = [
"dtypes",
"autoquant",
"quantize_",
"testing",
]

# test-pytorchbot
52 changes: 49 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -38,7 +38,8 @@
find_multiple,
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
- _is_float8_type
+ _is_float8_type,
+ fill_defaults,
)
import logging

@@ -599,13 +600,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

- if func is aten.t.default:
+ elif func is aten.t.default:
tensor = args[0]
new = tensor.__class__(
- tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type
+ tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor.layout_type
)
return return_and_correct_aliasing(func, args, kwargs, new)

elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
elif dim == 1:
assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type)
else:
raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")

raise NotImplementedError(
f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
)
@@ -1595,6 +1608,39 @@ def _(func, types, args, kwargs):
)
return return_and_correct_aliasing(func, args, kwargs, new)

@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
assert step == 1
assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
if end >= self.shape[dim]:
end = self.shape[dim]
shape = list(self.shape)
shape[dim] = end - start
block_size = self.block_size
assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}"
# with slice, some shape dimension might be smaller than block_size dimension, so
# we need to make sure there is no overflow
block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1]))
new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())
return return_and_correct_aliasing(func, args, kwargs, new)

# this is needed for DTensor.from_local() and for flattening tensor
@implements(aten.view.default)
def _(func, types, args, kwargs):
self, shape = args

if tuple(self.shape) == tuple(shape):
return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())

if len(shape) == 1 and shape[0] == -1:
assert len(self.block_size) == 2 and self.block_size[0] == 1
block_size = (self.block_size[1],)
return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride())

raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]")


to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
120 changes: 120 additions & 0 deletions torchao/testing/utils.py
@@ -3,11 +3,15 @@
import copy
import torch
import torchao
import os

from packaging import version

from torch.testing._internal import common_utils
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization import quantize_, int8_weight_only

"""
How to use:
@@ -33,6 +37,8 @@ class MyTestCase(TorchAOBasicTestCase):
unittest.main()
"""

torch_version = version.Version(torch.__version__)

# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
def copy_tests(
my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
@@ -213,10 +219,124 @@ def test_linear_compile(self, device, dtype):
lp_res = torch.compile(l)(hp_act_tensor)
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
NUM_DEVICES,
)

COMPILED_TENSOR_PARALLEL_REQUIRED_VERSION = version.Version("2.5.0dev")

class TorchAOTensorParallelTestCase(DTensorTestBase):
"""Basic test case for tensor subclasses
"""
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

TENSOR_SUBCLASS = AffineQuantizedTensor
QUANT_METHOD_FN = staticmethod(int8_weight_only)
QUANT_METHOD_KWARGS = {}

@staticmethod
def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
"""
Shard linear layer of the model in column-wise fashion
"""
# Column-wise is wrt to A^T, so for A it is row-wise.
# Number of rows per rank
orig_weight = m.linear.weight
n_local_rows = orig_weight.size(0) // mesh.size()
rank = mesh.get_local_rank()
local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
return m

@staticmethod
def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
"""
Shard linear layer of the model in row-wise fashion
"""
# Row-wise is wrt to A^T, so for A it is column-wise.
# Number of rows per rank
orig_weight = m.linear.weight
n_local_cols = orig_weight.size(1) // mesh.size()
rank = mesh.get_local_rank()
local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
# Construct DTensor from local shard
dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
# Replace parameter in module
m.linear.weight = torch.nn.Parameter(
dtensor, requires_grad=False
)
return m

def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
"""
Quantize the model
"""
quantize_(m, self.QUANT_METHOD_FN(**self.QUANT_METHOD_KWARGS))
return m

@common_utils.parametrize("dtype", COMMON_DTYPES)
@with_comms
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tp(self, dtype):
device = "cuda"
# To make sure different ranks create the same module
torch.manual_seed(5)

class M(torch.nn.Module):
def __init__(self, in_features, out_features, **kwargs) -> None:
super().__init__(**kwargs)
self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)

# Get rank and device
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")

# Original model
proj_up = M(1024, 2048).to(device).to(dtype)
proj_dn = M(2048, 1024).to(device).to(dtype)
example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
y = proj_dn(proj_up(example_input))

# Quantize the model
up_quant = self.quantize(proj_up)
dn_quant = self.quantize(proj_dn)
y_q = dn_quant(up_quant(example_input))

mesh = self.build_device_mesh()
# Shard the models
up_dist = self.colwise_shard(up_quant, mesh)
dn_dist = self.rowwise_shard(dn_quant, mesh)

# We need to turn inputs into DTensor form as well -- just a format change
input_dtensor = DTensor.from_local(
example_input, mesh, [Replicate()]
)

y_d = dn_dist(up_dist(input_dtensor))

if torch_version < COMPILED_TENSOR_PARALLEL_REQUIRED_VERSION:
# Need torch 2.5 to support compiled tensor parallelism
return

up_compiled = torch.compile(up_dist)
y_up = up_compiled(input_dtensor)
dn_compiled = torch.compile(dn_dist)
y_dn = dn_compiled(y_up)

common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
common_utils.instantiate_parametrized_tests(TorchAOTensorParallelTestCase)

if __name__ == "__main__":
unittest.main()
24 changes: 24 additions & 0 deletions torchao/utils.py
@@ -493,6 +493,30 @@ def _get_to_kwargs(self, *args, **kwargs):
}
return kwargs

def fill_defaults(args, n, defaults_tail):
"""
__torch_dispatch__ doesn't guarantee the number of arguments you are
passed (e.g., defaulted arguments are not passed); but usually it is
convenient to pad out the arguments list with defaults. This function
helps you do that.
Args:
args: the list of positional arguments passed to __torch_dispatch__
n: the number of arguments you are expecting to get
defaults_tail: default values for the arguments, starting from the
end of the list
Example:
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
[1, 2, 3, 4, 5]
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
[1, 2, 3, None, None]
"""
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r


## Deprecated, will be deleted in the future
def _torch_version_at_least(min_version):
33 changes: 5 additions & 28 deletions tutorials/developer_api_guide/my_dtype_tensor_subclass.py
@@ -25,36 +25,13 @@
LayoutType,
PlainLayoutType,
)
- from torchao.utils import TorchAOBaseTensor
+ from torchao.utils import (
+     TorchAOBaseTensor,
+     fill_defaults,
+ )

aten = torch.ops.aten

# TODO: move to torchao/utils.py
def fill_defaults(args, n, defaults_tail):
"""
__torch_dispatch__ doesn't guarantee the number of arguments you are
passed (e.g., defaulted arguments are not passed); but usually it is
convenient to pad out the arguments list with defaults. This function
helps you do that.
Args:
args: the list of positional arguments passed to __torch_dispatch__
n: the number of arguments you are expecting to get
defaults_tail: default values for the arguments, starting from the
end of the list
Example:
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
[1, 2, 3, 4, 5]
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
[1, 2, 3, None, None]]
"""
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r


###############################
# Base Layout Tensor Subclass #
###############################
@@ -327,7 +304,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
elif dim == 1:
- return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+ return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type)
else:
raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
elif func is aten.t.default:
3 changes: 2 additions & 1 deletion tutorials/developer_api_guide/tensor_parallel.py
@@ -5,7 +5,8 @@
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
from torch.utils._python_dispatch import return_and_correct_aliasing
- from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
+ from my_dtype_tensor_subclass import MyDTypeTensor
+ from torchao.utils import fill_defaults

# a tensor subclass that supports tensor parallelism with DTensor
class MyDTypeTensorTP(MyDTypeTensor):
Expand Down