Uintx ops - Slice etc... #1026

Draft · wants to merge 6 commits into base: main
202 changes: 202 additions & 0 deletions test/dtypes/test_uintx_ops.py
@@ -0,0 +1,202 @@
import pytest
import torch
from torchao.dtypes.uintx.uintx import UintxTensor, to_uintx, _DTYPE_TO_BIT_WIDTH
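# Note: UintxTensor bit-packs sub-byte unsigned integers; get_plain() unpacks
# them back to a plain uint8 tensor, which is what these tests compare against.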

# Define the dtypes to test
if torch.__version__ >= "2.3":
    dtypes = (torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
else:
    dtypes = ()

devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

def get_bit_width_from_tensor(tensor):
    # Smallest bit width that can represent every value in the tensor:
    # a value v needs v.bit_length() bits; UintxTensor's minimum is 2 bits.
    max_value = int(tensor.max().item())
    return max(2, max_value.bit_length())

def quantize_for_dtype(value, dtype):
    if dtype == torch.uint8:
        return value  # no quantization needed for uint8
    bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
    return min(value, 2**bit_width - 1)
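
# For example, quantize_for_dtype(55, torch.uint4) returns 15 (the 4-bit
# maximum, 2**4 - 1), while quantize_for_dtype(10, torch.uint4) returns 10.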

@pytest.fixture(params=dtypes)
def dtype(request):
    return request.param

@pytest.fixture(params=devices)
def device(request):
    return request.param

@pytest.fixture(params=list(_DTYPE_TO_BIT_WIDTH.keys()))
def uintx_tensor_and_dtype(request):
    dtype = request.param
    original_data = torch.tensor([10, 25, 40, 55, 5, 20, 35, 50], dtype=torch.uint8)
    quantized_data = torch.tensor(
        [quantize_for_dtype(v.item(), dtype) for v in original_data], dtype=torch.uint8
    )
    uintx_tensor = to_uintx(quantized_data, dtype)
    return uintx_tensor, dtype


def test_basic_slicing(uintx_tensor_and_dtype):
    uintx_tensor, dtype = uintx_tensor_and_dtype
    sliced_uintx = uintx_tensor[2:6]
    sliced_data = sliced_uintx.get_plain()
    expected_slice = uintx_tensor.get_plain()[2:6]
    bit_width = get_bit_width_from_tensor(sliced_data)
    assert torch.all(sliced_data == expected_slice), (
        f"Basic slicing failed for {bit_width}-bit tensor\n"
        f"Expected slice: {expected_slice}\n"
        f"Actual slice: {sliced_data}"
    )

def test_step_slicing(uintx_tensor_and_dtype):
    uintx_tensor, dtype = uintx_tensor_and_dtype
    step_sliced_uintx = uintx_tensor[1::2]
    step_sliced_data = step_sliced_uintx.get_plain()

    original_data = uintx_tensor.get_plain()
    expected_step_slice = original_data[1::2].to(step_sliced_data.dtype)

    assert torch.all(step_sliced_data == expected_step_slice), (
        f"Step slicing failed for {uintx_tensor.dtype} on {uintx_tensor.device}\n"
        f"Original tensor: {original_data}\n"
        f"Expected step slice: {expected_step_slice}\n"
        f"Actual step slice: {step_sliced_data}"
    )
    assert step_sliced_data.shape == expected_step_slice.shape, (
        f"Shape mismatch for {uintx_tensor.dtype} on {uintx_tensor.device}\n"
        f"Expected shape: {expected_step_slice.shape}\n"
        f"Actual shape: {step_sliced_data.shape}"
    )

def test_negative_indexing(uintx_tensor_and_dtype):
    uintx_tensor, dtype = uintx_tensor_and_dtype
    negative_sliced_uintx = uintx_tensor[-3:]
    negative_sliced_data = negative_sliced_uintx.get_plain()

    original_data = uintx_tensor.get_plain()
    expected_negative_slice = original_data[-3:].to(negative_sliced_data.dtype)

    assert torch.all(negative_sliced_data == expected_negative_slice), (
        f"Negative indexing failed for {uintx_tensor.dtype} on {uintx_tensor.device}\n"
        f"Original tensor: {original_data}\n"
        f"Expected negative slice: {expected_negative_slice}\n"
        f"Actual negative slice: {negative_sliced_data}"
    )

    assert negative_sliced_data.shape == expected_negative_slice.shape, (
        f"Shape mismatch for {uintx_tensor.dtype} on {uintx_tensor.device}\n"
        f"Expected shape: {expected_negative_slice.shape}\n"
        f"Actual shape: {negative_sliced_data.shape}"
    )

def test_slice_assignment(uintx_tensor_and_dtype):
    uintx_tensor, original_dtype = uintx_tensor_and_dtype
    assert original_dtype in _DTYPE_TO_BIT_WIDTH, f"Unexpected dtype: {original_dtype}"

    original_data = uintx_tensor.get_plain()

    # Data to assign, quantized to the original dtype
    new_data = torch.tensor([1, 2], dtype=torch.uint8, device=uintx_tensor.device)
    quantized_new_data = torch.tensor(
        [quantize_for_dtype(v.item(), original_dtype) for v in new_data],
        dtype=torch.uint8, device=uintx_tensor.device,
    )

    # Assign the quantized data to the slice
    uintx_tensor[3:5] = to_uintx(quantized_new_data, original_dtype)
    modified_data = uintx_tensor.get_plain()

    # Check that the assigned slice has been updated
    assert torch.all(modified_data[3:5] == quantized_new_data), (
        f"Slice assignment failed for {original_dtype} on {uintx_tensor.device}\n"
        f"Expected quantized slice: {quantized_new_data}\n"
        f"Actual slice after assignment: {modified_data[3:5]}"
    )

    # Check that the rest of the tensor remained unchanged
    assert torch.all(modified_data[:3] == original_data[:3]) and torch.all(modified_data[5:] == original_data[5:]), (
        f"Unassigned parts of the tensor changed after slice assignment for {original_dtype} on {uintx_tensor.device}"
    )

    # Assign another quantized tensor to a different slice
    regular_tensor = torch.tensor([3, 1], dtype=torch.uint8, device=uintx_tensor.device)
    quantized_regular_tensor = torch.tensor(
        [quantize_for_dtype(v.item(), original_dtype) for v in regular_tensor],
        dtype=torch.uint8, device=uintx_tensor.device,
    )
    uintx_tensor[5:7] = to_uintx(quantized_regular_tensor, original_dtype)
    modified_data_2 = uintx_tensor.get_plain()

    assert torch.all(modified_data_2[5:7] == quantized_regular_tensor), (
        f"Slice assignment with regular tensor failed for {original_dtype} on {uintx_tensor.device}\n"
        f"Expected quantized slice: {quantized_regular_tensor}\n"
        f"Actual slice after assignment: {modified_data_2[5:7]}"
    )

    # Assign a scalar value to a single index
    scalar_value = 2
    quantized_scalar = quantize_for_dtype(scalar_value, original_dtype)
    uintx_tensor[7] = quantized_scalar
    modified_data_3 = uintx_tensor.get_plain()

    assert modified_data_3[7] == quantized_scalar, (
        f"Scalar assignment failed for {original_dtype} on {uintx_tensor.device}\n"
        f"Expected quantized scalar: {quantized_scalar}\n"
        f"Actual value after assignment: {modified_data_3[7]}"
    )

def test_out_of_bounds_slicing(uintx_tensor_and_dtype):
    uintx_tensor, original_dtype = uintx_tensor_and_dtype
    out_of_bounds_uintx = uintx_tensor[5:10]
    out_of_bounds_data = out_of_bounds_uintx.get_plain()

    original_data = uintx_tensor.get_plain()
    expected_out_of_bounds = original_data[5:]

    assert torch.all(out_of_bounds_data == expected_out_of_bounds), (
        f"Out-of-bounds slicing failed for {original_dtype} on {uintx_tensor.device}\n"
        f"Original tensor: {original_data}\n"
        f"Expected out-of-bounds slice: {expected_out_of_bounds}\n"
        f"Actual out-of-bounds slice: {out_of_bounds_data}"
    )

    assert out_of_bounds_data.shape == expected_out_of_bounds.shape, (
        f"Shape mismatch for out-of-bounds slicing with {original_dtype} on {uintx_tensor.device}\n"
        f"Expected shape: {expected_out_of_bounds.shape}\n"
        f"Actual shape: {out_of_bounds_data.shape}"
    )
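
# Note: PyTorch, like Python, clamps out-of-range slice bounds, so [5:10] on an
# 8-element tensor is expected to behave exactly like [5:8].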

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_device_transfer(uintx_tensor_and_dtype):
    uintx_tensor_cpu, original_dtype = uintx_tensor_and_dtype
    uintx_tensor_cuda = uintx_tensor_cpu.to("cuda")

    assert uintx_tensor_cuda.device.type == "cuda", (
        f"Failed to transfer {original_dtype} tensor to CUDA"
    )

    cpu_data = uintx_tensor_cpu.get_plain()
    cuda_data = uintx_tensor_cuda.cpu().get_plain()

    assert torch.all(cpu_data == cuda_data), (
        f"Data mismatch after device transfer for {original_dtype}\n"
        f"CPU data: {cpu_data}\n"
        f"CUDA data (transferred back to CPU): {cuda_data}"
    )
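
# These tests can be run with pytest, e.g.:
#   pytest test/dtypes/test_uintx_ops.py -v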
150 changes: 150 additions & 0 deletions test/dtypes/test_uintx_parallel.py
@@ -0,0 +1,150 @@
import os
from typing import Sequence

import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard

from torchao.quantization.quant_api import quantize_, uintx_weight_only

class M(torch.nn.Module):
    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(**kwargs)
        self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)

    def forward(self, x):
        return self.linear(x)

def quantize(m: torch.nn.Module, dtype, group_size=32) -> torch.nn.Module:
    """Quantize the model's linear weights to the given uintx dtype."""
    quantize_(m, uintx_weight_only(dtype, group_size=group_size))
    return m

def shard(
    full_tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> DTensor:
    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

    shape, offset = compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    slices = [
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    ]
    local_tensor = full_tensor[slices]
    return DTensor.from_local(local_tensor, device_mesh, placements)
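
# A worked example, assuming the sharded dimension divides evenly: a (2048, 1024)
# weight on a 2-rank mesh with [Shard(0)] gives each rank a (1024, 1024) local
# shard at row offsets 0 and 1024; DTensor.from_local records the placement so
# the shards behave as one logical tensor.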


def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    """
    Shard the linear layer of the model column-wise.
    """
    # Column-wise is with respect to A^T, so for A it is row-wise.
    orig_weight = m.linear.weight
    # Construct a DTensor from the local shard
    dtensor = shard(orig_weight, mesh, [Shard(0)])
    # Replace the parameter in the module
    m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
    return m

def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    """
    Shard the linear layer of the model row-wise.
    """
    # Row-wise is with respect to A^T, so for A it is column-wise.
    orig_weight = m.linear.weight
    # Construct a DTensor from the local shard
    dtensor = shard(orig_weight, mesh, [Shard(1)])
    # Replace the parameter in the module
    m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
    return m
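
# nn.Linear stores its weight as (out_features, in_features), so Shard(0) splits
# output features and Shard(1) splits input features, matching the column-wise
# and row-wise comments above.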

class Linear16(torch.nn.Module):
    def __init__(self, scale, device):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(scale * 2, scale, bias=False, dtype=torch.float16, device=device),
            torch.nn.Linear(scale, scale, bias=False, dtype=torch.float16, device=device),
            torch.nn.Linear(scale, scale // 2, bias=False, dtype=torch.float16, device=device),
        )

    def forward(self, x):
        return self.net(x)
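
# Note: Linear16 is an unquantized float16 baseline; main() below does not
# exercise it.
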
########
# Test #
########
def main():
    # Run on CPU first as a sanity check
    device = torch.device("cpu")
    proj_up = M(1024, 2048).to(device)
    proj_dn = M(2048, 1024).to(device)
    example_input = 100 * torch.randn(128, 1024, device=device)
    y = proj_dn(proj_up(example_input))

    # Quantize the model
    up_quant = quantize(proj_up, torch.uint6)
    dn_quant = quantize(proj_dn, torch.uint6)
    y_q = dn_quant(up_quant(example_input))
    print("Quantization works!")

    # Make sure different ranks create the same module
    torch.manual_seed(5)

    # Get rank and device
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    # Original model
    proj_up = M(1024, 2048).to(device)
    proj_dn = M(2048, 1024).to(device)
    example_input = 100 * torch.randn(128, 1024, device=device)
    y = proj_dn(proj_up(example_input))

    # Quantize the model (quantize() requires the target dtype)
    up_quant = quantize(proj_up, torch.uint6)
    dn_quant = quantize(proj_dn, torch.uint6)
    y_q = dn_quant(up_quant(example_input))
    print("Quantization works!")

    # Create a device mesh
    dist.init_process_group(backend="nccl")
    mesh = dist.init_device_mesh("cuda", (world_size,))

    # Shard the models
    up_dist = colwise_shard(up_quant, mesh)
    dn_dist = rowwise_shard(dn_quant, mesh)

    # Inputs need to be turned into DTensor form as well -- just a format change
    input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

    y_d = dn_dist(up_dist(input_dtensor))
    print("Distributed result:", y_d)
    print("Distributed works!")

    up_compiled = torch.compile(up_dist)
    y_up = up_compiled(input_dtensor)
    dn_compiled = torch.compile(dn_dist)
    y_dn = dn_compiled(y_up)
    print("Compiled result:", y_dn)
    print("torch.compile works!")

    dist.destroy_process_group()
if __name__ == "__main__":
    main()
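
# Launch with a launcher that sets WORLD_SIZE and RANK, e.g. (assuming 2 GPUs):
#   torchrun --nproc_per_node=2 test/dtypes/test_uintx_parallel.py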