diff --git a/.github/workflows/build_wheels_linux.yml b/.github/workflows/build_wheels_linux.yml index 14623be8d5..3c37e0e1e0 100644 --- a/.github/workflows/build_wheels_linux.yml +++ b/.github/workflows/build_wheels_linux.yml @@ -43,6 +43,8 @@ jobs: # triggered daily from main with a schedule repository: pytorch/ao ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} env-var-script: packaging/env_var_script_linux.sh pre-script: packaging/pre_build_script.sh diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml index f90282011e..760beb6319 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/float8_test.yml @@ -29,7 +29,7 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.1" - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 runner: ${{ matrix.runs-on }} @@ -38,8 +38,6 @@ jobs: script: | conda create -n venv python=3.9 -y conda activate venv - echo "::group::Install newer objcopy that supports --set-section-alignment" - yum install -y devtoolset-10-binutils export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} diff --git a/.github/workflows/nightly_smoke_test.yml b/.github/workflows/nightly_smoke_test.yml index 9e2d4fee82..9f3dc3c0fb 100644 --- a/.github/workflows/nightly_smoke_test.yml +++ b/.github/workflows/nightly_smoke_test.yml @@ -26,7 +26,7 @@ jobs: gpu-arch-version: "12.1" - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 975e0470f5..0488e6d922 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -18,6 +18,38 @@ env: HF_TOKEN: ${{ secrets.HF_TOKEN }} jobs: + test-nightly: + strategy: + fail-fast: false + matrix: + include: + - name: CUDA Nightly + runs-on: linux.g5.12xlarge.nvidia.gpu + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124' + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" + - name: CPU Nightly + runs-on: linux.4xlarge + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' + gpu-arch-type: "cpu" + gpu-arch-version: "" + + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 120 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + script: | + conda create -n venv python=3.9 -y + conda activate venv + python -m pip install --upgrade pip + pip install ${{ matrix.torch-spec }} + pip install -r dev-requirements.txt + pip install . + export CONDA=$(dirname $(dirname $(which conda))) + export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH + pytest test --verbose -s test: strategy: fail-fast: false @@ -38,11 +70,6 @@ jobs: torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - name: CUDA Nightly - runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch==2.6.0.dev20241101 --index-url https://download.pytorch.org/whl/nightly/cu121' - gpu-arch-type: "cuda" - gpu-arch-version: "12.1" - name: CPU 2.3 runs-on: linux.4xlarge @@ -59,11 +86,6 @@ jobs: torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" - - name: CPU Nightly - runs-on: linux.4xlarge - torch-spec: '--pre torch==2.6.0.dev20241101 --index-url https://download.pytorch.org/whl/nightly/cpu' - gpu-arch-type: "cpu" - gpu-arch-version: "" uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: diff --git a/.gitignore b/.gitignore index 5c201d1b34..5fa7064cbe 100644 --- a/.gitignore +++ b/.gitignore @@ -371,4 +371,7 @@ venv/ sweep/ # Model checkpoints -checkpoints/ \ No newline at end of file +checkpoints/ + +# Experimental +torchao/experimental/cmake-out diff --git a/README.md b/README.md index 158cfb7562..6ba0e3be4c 100644 --- a/README.md +++ b/README.md @@ -177,8 +177,8 @@ We're also fortunate to be integrated into some of the leading open-source libra 2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) 3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) 4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes -5. [torchchat](https://github.com/pytorch/torchtune) for post training quantization -6. [SGLang](https://github.com/sgl-project/sglang/pull/1341) for LLM inference quantization +5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization +6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341). ## Videos * [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009) @@ -201,8 +201,9 @@ If you find the torchao library useful, please cite it in your work as below. @software{torchao, title = {torchao: PyTorch native quantization and sparsity for training and inference}, author = {torchao maintainers and contributors}, - url = {https//github.com/pytorch/torchao}, + url = {https://github.com/pytorch/torchao}, license = {BSD-3-Clause}, month = oct, year = {2024} +} ``` diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index bd31193892..986cc58b4f 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -4,7 +4,7 @@ # - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git # - DeepSpeed (ZeRO-Offload): # sudo apt install libopenmpi-dev -# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p +# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py # DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir # # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core @@ -31,11 +31,15 @@ import torch.nn.functional as F import wandb from torch.utils.data import DataLoader +from torchao.utils import get_available_devices from torchvision.transforms import v2 from tqdm import tqdm from torchao.prototype import low_bit_optim +_DEVICE = get_available_devices()[-1] +assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)" + OPTIM_MAP = dict( AdamW=partial(torch.optim.AdamW, fused=True), AdamW8bitBnb=bnb.optim.AdamW8bit, @@ -49,7 +53,9 @@ OPTIM_MAP.update( AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True), - AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")), + AdamW4bitRank1Lpmm=partial( + lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1") + ), ) except ImportError: @@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float: if step < self.warmup_steps: return self.lr * step / self.warmup_steps if step < self.total_steps: - progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) - return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi)) + progress = (step - self.warmup_steps) / ( + self.total_steps - self.warmup_steps + ) + return self.final_lr + 0.5 * (self.lr - self.final_lr) * ( + 1 + math.cos(progress * math.pi) + ) return self.final_lr @@ -92,7 +102,9 @@ def get_parser(): parser.add_argument("--weight_decay", type=float, default=0) parser.add_argument("--optim_kwargs", type=json.loads, default=dict()) parser.add_argument("--cosine_lr_scheduler", action="store_true") - parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]) + parser.add_argument( + "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"] + ) parser.add_argument("--project") parser.add_argument("--run_name", default="debug") @@ -110,11 +122,15 @@ def get_dloader(args, training: bool): transforms.extend([v2.Resize(256), v2.CenterCrop(224)]) transforms.append(v2.ToDtype(torch.float32, scale=True)) - transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + transforms.append( + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ) transforms = v2.Compose(transforms) # use dataset from HF so download is fast - ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation") + ds = datasets.load_dataset( + "timm/resisc45", split="train" if training else "validation" + ) ds = ds.select_columns(["image", "label"]) ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"])) @@ -128,9 +144,9 @@ def get_dloader(args, training: bool): ) -def get_amp_ctx(amp): +def get_amp_ctx(amp, device): dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp] - return torch.autocast("cuda", dtype=dtype, enabled=amp != "none") + return torch.autocast(device, dtype=dtype, enabled=amp != "none") @torch.no_grad() @@ -148,8 +164,8 @@ def evaluate_model(model, args): if args.channels_last: batch["image"] = batch["image"].to(memory_format=torch.channels_last) - with get_amp_ctx(args.amp): - all_preds.append(model(batch["image"].cuda()).argmax(1).cpu()) + with get_amp_ctx(args.amp, _DEVICE): + all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu()) all_labels = torch.cat(all_labels, dim=0) all_preds = torch.cat(all_preds, dim=0) @@ -164,8 +180,12 @@ def evaluate_model(model, args): if args.full_bf16: assert args.amp == "none", "When --full_bf16 is set, --amp must be none" if args.optim_cpu_offload == "deepspeed": - assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none" - assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW" + assert ( + args.amp == "none" + ), "When using DeepSpeed ZeRO-Offload, --amp must be none" + assert ( + args.optim == "AdamW" + ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW" if args.profile: args.n_epochs = 1 if args.seed is not None: @@ -185,14 +205,16 @@ def evaluate_model(model, args): dloader = get_dloader(args, True) print(f"Train dataset: {len(dloader.dataset):,} images") - model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs) + model = timm.create_model( + args.model, pretrained=True, num_classes=45, **args.model_kwargs + ) if args.checkpoint_activations: model.set_grad_checkpointing() if args.full_bf16: model.bfloat16() if args.channels_last: model.to(memory_format=torch.channels_last) - model.cuda() # move model to CUDA after optionally convert it to BF16 + model.to(_DEVICE) # move model to DEVICE after optionally convert it to BF16 if args.compile: model.compile(fullgraph=True) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") @@ -227,9 +249,15 @@ def evaluate_model(model, args): optim_cls = OPTIM_MAP[args.optim] if args.optim_cpu_offload == "ao": - optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls) + optim_cls = partial( + low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls + ) elif args.optim_cpu_offload == "ao_offload_grads": - optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True) + optim_cls = partial( + low_bit_optim.CPUOffloadOptimizer, + optimizer_class=optim_cls, + offload_gradients=True, + ) optim = optim_cls( model.parameters(), @@ -239,24 +267,30 @@ def evaluate_model(model, args): ) lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) - grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") + grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16") log_interval = 10 t0 = time.perf_counter() step = 0 for epoch_idx in range(args.n_epochs): model.train() - pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}") + pbar = tqdm( + dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}" + ) with torch.profiler.profile() if args.profile else nullcontext() as prof: for batch in pbar: if args.full_bf16: batch["image"] = batch["image"].bfloat16() if args.channels_last: - batch["image"] = batch["image"].to(memory_format=torch.channels_last) + batch["image"] = batch["image"].to( + memory_format=torch.channels_last + ) - with get_amp_ctx(args.amp): - loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda()) + with get_amp_ctx(args.amp, _DEVICE): + loss = F.cross_entropy( + model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE) + ) if args.optim_cpu_offload == "deepspeed": model.backward(loss) @@ -275,7 +309,9 @@ def evaluate_model(model, args): log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]) if step > 0: t1 = time.perf_counter() - log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0) + log_dict["imgs_per_second"] = ( + args.batch_size * log_interval / (t1 - t0) + ) t0 = t1 logger.log(log_dict, step=step) @@ -296,9 +332,11 @@ def evaluate_model(model, args): else: val_acc = evaluate_model(model, args) - print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") + print( + f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}" + ) logger.log(dict(val_acc=val_acc), step=step) - peak_mem = torch.cuda.max_memory_allocated() / 1e9 + peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9 print(f"Max memory used: {peak_mem:.02f} GB") logger.log(dict(max_memory_allocated=peak_mem)) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index f4f2813a37..e545ea4665 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -6,6 +6,7 @@ import copy import io +import functools import os import random from contextlib import nullcontext, redirect_stdout @@ -22,6 +23,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.checkpoint import ( + checkpoint, + create_selective_checkpoint_contexts, + CheckpointPolicy, +) from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -254,6 +260,22 @@ def profile_function( return prof +# set up AC for max(abs(tensor)) +# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts +ops_to_save = [ + torch.ops.aten.abs.default, + torch.ops.aten.max.default, +] + +def policy_fn(ctx, op, *args, **kwargs): + if op in ops_to_save: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + +context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + + def main( profile_path_prefix: pathlib.Path, compile: bool = True, @@ -265,6 +287,7 @@ def main( dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, enable_sync_amax_history: bool = True, + enable_activation_checkpointing: bool = False, ): assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") @@ -294,6 +317,7 @@ def main( print(f"Compile is set to | {compile}") print(f"model_type is set to | {model_type}") print(f"scaling_repr is set to | {scaling_repr}") + print(f"enable_activation_checkpointing is set to {enable_activation_checkpointing}") device = "cuda" ref_dtype = torch.bfloat16 @@ -338,11 +362,17 @@ def main( convert_to_float8_training(m_float8, config=config) def ref_forw_backward(x): - out = m_ref(x) + if enable_activation_checkpointing: + out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn) + else: + out = m_ref(x) out.sum().backward() def float8_forw(x): - out = m_float8(x) + if enable_activation_checkpointing: + out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn) + else: + out = m_float8(x) return out sync_amax_history = sync_float8_amax_and_scale_history diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md index 5c6cf4fb06..43fc2b2528 100644 --- a/examples/sam2_amg_server/README.md +++ b/examples/sam2_amg_server/README.md @@ -8,7 +8,7 @@ curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output Start the server ``` -python server.py ~/checkpoints/sam2 --port --host --fast +python server.py ~/checkpoints/sam2 large --port --host --fast ``` Collect the rles @@ -24,11 +24,11 @@ Experiments run on H100 and with batch size 1 | mode | mIoU | mask count mismatch | avg. ms per request | max. memory (MiB (%)) | batch size | points per batch | | -------------- | ----------------- | ------------------- | ------------------- | --------------------- | ---------- | ---------------- | | baseline | 1.0 | 0 | 863 | 4013MiB (4%) | 1 | 64 | -| ao | 0.9999980926513672 | 6 | 586 | | 1 | 64 | -| fast | 0.9937329888343811 | 191 | 333 | | 1 | 1024 | -| fast | 0.9937219619750977 | 192 | 324 | | 16 | 1024 | -| fast + furious | 0.9804400205612183 | 292 | 131 | | 1 | 1024 | -| fast + furious | 0.9806423187255859 | 282 | 130 | | 16 | 1024 | +| ao | 0.9999980926513672 | 6 | 586 | 3257MiB (3%) | 1 | 64 | +| fast | 0.993732988834381 | 191 | 326 | 27197MiB (27%) | 1 | 1024 | +| fast | 0.9937511086463928 | 194 | 315 | 27488MiB (28%) | 16 | 1024 | +| fast + furious | 0.9817246198654175 | 266 | 120 | 13616MiB (13%) | 1 | 1024 | +| fast + furious | 0.9794579744338989 | 274 | 122 | 13808MiB (14%) | 16 | 1024 | mask count mismatch counts the number of requests where the number of masks differ from the baseline. For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19. @@ -58,7 +58,7 @@ Make sure you've installed https://github.com/facebookresearch/sam2 Start server ``` -python server.py ~/checkpoints/sam2 --port --host --baseline +python server.py ~/checkpoints/sam2 large --port --host --baseline ``` Generate and save rles (one line per json via `-w "\n"`) @@ -73,7 +73,7 @@ sys 0m4.137s ### 3. Start server with torchao variant of SAM2 Start server ``` -python server.py ~/checkpoints/sam2 --port --host +python server.py ~/checkpoints/sam2 large --port --host ``` Generate and save rles (one line per json via `-w "\n"`) @@ -88,7 +88,7 @@ sys 0m4.350s ### 4. Start server with torchao variant of SAM2 and `--fast` optimizations Start server ``` -python server.py ~/checkpoints/sam2 --port --host --fast +python server.py ~/checkpoints/sam2 large --port --host --fast ``` Generate and save rles (one line per json via `-w "\n"`) @@ -103,7 +103,7 @@ sys 0m4.138s ### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations Start server ``` -python server.py ~/checkpoints/sam2 --port --host --fast --furious +python server.py ~/checkpoints/sam2 large --port --host --fast --furious ``` Generate and save rles (one line per json via `-w "\n"`) diff --git a/examples/sam2_amg_server/cli.py b/examples/sam2_amg_server/cli.py new file mode 100644 index 0000000000..9cf5bdc8f3 --- /dev/null +++ b/examples/sam2_amg_server/cli.py @@ -0,0 +1,74 @@ +import fire +import logging +import matplotlib.pyplot as plt +from server import file_bytes_to_image_tensor +from server import show_anns +from server import model_type_to_paths +from server import MODEL_TYPES_TO_MODEL +from server import set_fast +from server import set_aot_fast +from server import load_aot_fast +from server import set_furious +from torchao._models.sam2.build_sam import build_sam2 +from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from torchao._models.sam2.utils.amg import rle_to_mask +from io import BytesIO + +def main_docstring(): + return f""" + Args: + checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints + model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())} + input_path (str): Path to input image + output_path (str): Path to output image + """ + + +def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""): + device = "cuda" + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) + if verbose: + print(f"Loading model {sam2_checkpoint} with config {model_cfg}") + sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) + mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") + if furious: + set_furious(mask_generator) + if load_fast: + load_aot_fast(mask_generator, load_fast) + if fast: + set_fast(mask_generator, load_fast) + + image_tensor = file_bytes_to_image_tensor(input_bytes) + if verbose: + print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.") + masks = mask_generator.generate(image_tensor) + + if verbose: + print("Generating mask annotations for input image.") + plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100) + plt.imshow(image_tensor) + show_anns(masks, rle_to_mask) + plt.axis('off') + plt.tight_layout() + buf = BytesIO() + plt.savefig(buf, format=output_format) + buf.seek(0) + return buf.getvalue() + +def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""): + input_bytes = bytearray(open(input_path, 'rb').read()) + output_bytes = main_headless(checkpoint_path, + model_type, + input_bytes, + points_per_batch=points_per_batch, + output_format=output_format, + verbose=verbose, + fast=fast, + furious=furious, + load_fast=load_fast) + with open(output_path, "wb") as file: + file.write(output_bytes) + +main.__doc__ = main_docstring() +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/sam2_amg_server/cli_on_modal.py b/examples/sam2_amg_server/cli_on_modal.py new file mode 100644 index 0000000000..3295ede842 --- /dev/null +++ b/examples/sam2_amg_server/cli_on_modal.py @@ -0,0 +1,171 @@ +from pathlib import Path +import json +import fire + +import modal + +TARGET = "/root/" +DOWNLOAD_URL_BASE = "https://raw.githubusercontent.com/pytorch/ao/refs/heads" + +image = ( + modal.Image.debian_slim(python_version="3.12.7") + .pip_install("numpy<3", "tqdm") + .pip_install( + "torch", + pre=True, + index_url="https://download.pytorch.org/whl/nightly/cu124", # tested with torch-2.6.0.dev20241120 + ) + .pip_install( + "torchvision", + pre=True, + index_url="https://download.pytorch.org/whl/nightly/cu124", # tested with torch-2.6.0.dev20241120 + ) + .apt_install("git") + .apt_install("libopencv-dev") + .apt_install("python3-opencv") + .run_commands(["git clone https://github.com/pytorch/ao.git /tmp/ao_src"]) + .run_commands(["cd /tmp/ao_src; python setup.py develop"]) + .pip_install( + "gitpython", + ) + .apt_install("wget") + .run_commands([f"wget https://raw.githubusercontent.com/pytorch/ao/refs/heads/main/examples/sam2_amg_server/requirements.txt"]) + .pip_install_from_requirements( + 'requirements.txt', + ) +) + +app = modal.App("torchao-sam-2-cli", image=image) + +checkpoints = modal.Volume.from_name("torchao-sam-2-cli-checkpoints", create_if_missing=True) +data = modal.Volume.from_name("torchao-sam-2-cli-data", create_if_missing=True) + + +@app.cls( + gpu="H100", + container_idle_timeout=20 * 60, + timeout=20 * 60, + volumes={ + TARGET + "checkpoints": checkpoints, + TARGET + "data": data, + }, +) +class Model: + model_type: str = modal.parameter(default="large") + points_per_batch: int = modal.parameter(default=1024) + fast: int = modal.parameter(default=0) + furious: int = modal.parameter(default=0) + + def calculate_file_hash(self, file_path, hash_algorithm='sha256'): + import hashlib + """Calculate the hash of a file.""" + hash_func = hashlib.new(hash_algorithm) + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_func.update(chunk) + return hash_func.hexdigest() + + def download_file(self, url, filename): + import subprocess + command = f"wget -O {filename} {url}" + subprocess.run(command, shell=True, check=True) + + @modal.build() + @modal.enter() + def build(self): + import os + from torchao._models.sam2.build_sam import build_sam2 + from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + + download_url_branch = "climodal2" + download_url = f"{DOWNLOAD_URL_BASE}/{download_url_branch}/" + download_url += "examples/sam2_amg_server/" + + h = self.calculate_file_hash(TARGET + "data/cli.py") + print("cli.py hash: ", h) + if h != "b38d60cb6fad555ad3c33081672ae981a5e4e744199355dfd24d395d20dfefda": + self.download_file(download_url + "cli.py", TARGET + "data/cli.py") + + h = self.calculate_file_hash(TARGET + "data/server.py") + print("server.py hash: ", h) + if h != "af33fdb9bcfe668b7764cb9c86f5fa9a799c999306e7c7e5b28c988b2616a0ae": + self.download_file(download_url + "server.py", TARGET + "data/server.py") + + os.chdir(Path(TARGET + "data")) + import sys + sys.path.append(".") + + from server import model_type_to_paths + from server import set_fast + from server import set_furious + + + device = "cuda" + checkpoint_path = Path(TARGET) / Path("checkpoints") + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, self.model_type) + sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) + mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=self.points_per_batch, output_mode="uncompressed_rle") + self.mask_generator = mask_generator + if self.fast: + set_fast(mask_generator) + if self.furious: + set_furious(mask_generator) + + @modal.method() + def inference_rle(self, input_bytes) -> dict: + import os + os.chdir(Path(TARGET + "data")) + import sys + sys.path.append(".") + from server import file_bytes_to_image_tensor + from server import masks_to_rle_dict + image_tensor = file_bytes_to_image_tensor(input_bytes) + masks = self.mask_generator.generate(image_tensor) + return masks_to_rle_dict(masks) + + @modal.method() + def inference(self, input_bytes, output_format='png'): + import os + os.chdir(Path(TARGET + "data")) + import sys + sys.path.append(".") + from server import file_bytes_to_image_tensor + from server import show_anns + image_tensor = file_bytes_to_image_tensor(input_bytes) + masks = self.mask_generator.generate(image_tensor) + + import matplotlib.pyplot as plt + from io import BytesIO + from torchao._models.sam2.utils.amg import rle_to_mask + plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100) + plt.imshow(image_tensor) + show_anns(masks, rle_to_mask) + plt.axis('off') + plt.tight_layout() + buf = BytesIO() + plt.savefig(buf, format=output_format) + buf.seek(0) + return buf.getvalue() + + +def main(input_path, output_path, fast=False, furious=False, model_type="large", output_rle=False): + input_bytes = bytearray(open(input_path, 'rb').read()) + try: + model = modal.Cls.lookup("torchao-sam-2-cli", "Model")() + except modal.exception.NotFoundError: + print("Can't find running app. To deploy the app run the following command. Note that this costs money! See https://modal.com/pricing") + print("modal deploy cli_on_modal.py") + return + + if output_rle: + output_dict = model.inference_rle.remote(input_bytes) + with open(output_path, "w") as file: + file.write(json.dumps(output_dict, indent=4)) + else: + output_bytes = model.inference.remote(input_bytes) + with open(output_path, "wb") as file: + file.write(output_bytes) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/sam2_amg_server/requirements.txt b/examples/sam2_amg_server/requirements.txt index a77773d62b..e591e89100 100644 --- a/examples/sam2_amg_server/requirements.txt +++ b/examples/sam2_amg_server/requirements.txt @@ -7,3 +7,4 @@ hydra-core tqdm iopath python-multipart +requests diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index 516df71a9a..060a5ad5dd 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -1,4 +1,5 @@ import itertools +import requests import uvicorn import fire import tempfile @@ -37,6 +38,23 @@ # torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True +def download_file(url, download_dir): + # Create the directory if it doesn't exist + download_dir = Path(download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + # Extract the file name from the URL + file_name = url.split('/')[-1] + # Define the full path for the downloaded file + file_path = download_dir / file_name + # Download the file + response = requests.get(url, stream=True) + response.raise_for_status() # Raise an error for bad responses + # Write the file to the specified directory + print(f"Downloading '{file_name}' to '{download_dir}'") + with open(file_path, 'wb') as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + print(f"Downloaded '{file_name}' to '{download_dir}'") def example_shapes(): return [(848, 480, 3), @@ -149,6 +167,26 @@ def profiler_runner(path, fn, *args, **kwargs): return result +def memory_runner(path, fn, *args, **kwargs): + print("Start memory recording") + torch.cuda.synchronize() + torch.cuda.memory._record_memory_history( + True, + trace_alloc_max_entries=100000, + trace_alloc_record_context=True + ) + result = fn(*args, **kwargs) + torch.cuda.synchronize() + snapshot = torch.cuda.memory._snapshot() + print("Finish memory recording") + import pickle + with open(path, 'wb') as f: + pickle.dump(snapshot, f) + # Use to convert pickle file into html + # python torch/cuda/_memory_viz.py trace_plot .pickle -o .html + return result + + def image_tensor_to_masks(example_image, mask_generator): masks = mask_generator.generate(example_image) return masks @@ -187,7 +225,7 @@ def process_batch(batch, mask_generator): print(f"Processing batch of len {len(batch)} using generate_batch") masks = mask_generator.generate_batch(image_tensors) print(f"Took avg. {(time.time() - t) / len(batch)}s per batch entry") - # max_memory_allocated() + max_memory_allocated() return masks @@ -252,19 +290,273 @@ def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False): print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}") +MODEL_TYPES_TO_CONFIG = { + "tiny": "sam2.1_hiera_t.yaml", + "small": "sam2.1_hiera_s.yaml", + "plus": "sam2.1_hiera_b+.yaml", + "large": "sam2.1_hiera_l.yaml", + } + +MODEL_TYPES_TO_MODEL = { + "tiny": "sam2.1_hiera_tiny.pt", + "small": "sam2.1_hiera_small.pt", + "plus": "sam2.1_hiera_base_plus.pt", + "large": "sam2.1_hiera_large.pt", + } + + +MODEL_TYPES_TO_URL = { + "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + } + + +def main_docstring(): + return f""" + Args: + checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints + model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())} + """ + + +def model_type_to_paths(checkpoint_path, model_type): + if model_type not in MODEL_TYPES_TO_CONFIG.keys(): + raise ValueError(f"Expected model_type to be one of {', '.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}") + sam2_checkpoint = Path(checkpoint_path) / Path(MODEL_TYPES_TO_MODEL[model_type]) + if not sam2_checkpoint.exists(): + print(f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading.") + download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path) + assert sam2_checkpoint.exists(), "Can't find downloaded file. Please open an issue." + model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}" + return sam2_checkpoint, model_cfg + + +def aot_compile(model_directory, name, fn, sample_args): + path = Path(model_directory) / Path(f"{name}.pt2") + print(f"Saving at {path=}") + options = { + "max_autotune": True, + "triton.cudagraphs": True, + } + + exported = torch.export.export_for_inference(fn, sample_args) + output_path = torch._inductor.aoti_compile_and_package( + exported, + package_path=str(path), + inductor_configs=options, + ) + return output_path + + +def aot_load(path): + return torch._export.aot_load(path, "cuda") + +class FunctionModel(torch.nn.Module): + + def __init__(self, module, fn_name): + super().__init__() + self.module = module + self.fn_name = fn_name + + def forward(self, *args): + return getattr(self.module, self.fn_name)(*args) + + +def set_aot_fast(mask_generator, model_directory): + example_input = torch.empty(1, 3, 1024, 1024) + example_input = example_input.to(mask_generator.predictor._image_dtype) + example_input = (example_input.to(mask_generator.predictor.device),) + aot_compile(model_directory, + "sam2_image_encoder", + mask_generator.predictor.model.image_encoder, + example_input) + + # NOTE: THIS DOESN'T WORK YET! + # example_input_0_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_0_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device) + # example_input = ([example_input_0_0, example_input_0_1], + # example_input_1, + # example_input_2, + # example_input_3, + # None, + # None, + # True, + # True, + # -1) + # mask_generator.forward = mask_generator.predictor._predict_masks_with_features + # mask_generator(*example_input) + # aot_compile("sam2__predict_masks_with_features", + # mask_generator, + # example_input) + + # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device) + # aot_compile("sam2_sam_prompt_encoder", + # mask_generator.predictor.model.sam_prompt_encoder, + # ((example_input_2, example_input_3), + # None, + # None)) + + # NOTE: THIS DOESN'T WORK YET! + # example_input_0 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 2, 256, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + + # example_input_4_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_4_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device) + + # example_input = (example_input_0, + # example_input_1, + # example_input_2, + # example_input_3, + # True, + # True, + # [example_input_4_0, example_input_4_1]) + # print("Example") + # mask_generator.predictor.model.sam_mask_decoder(*example_input) + # print("Example done") + # aot_compile("sam2_sam_mask_decoder", + # mask_generator.predictor.model.sam_mask_decoder, + # example_input) + + # example_input_0 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 8, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input = (example_input_0, example_input_1, example_input_2) + + # mask_generator.predictor.model.sam_mask_decoder.transformer(*example_input) + # aot_compile("sam2_sam_mask_decoder_transformer", + # mask_generator.predictor.model.sam_mask_decoder.transformer, + # example_input) + + + + +class LoadedModel(torch.nn.Module): + + def __init__(self, aoti_compiled_model): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + + def forward(self, *args): + return self.aoti_compiled_model(*args) + +class LoadedDecoder(torch.nn.Module): + + def __init__(self, aoti_compiled_model, other): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + self.other = other + + def forward(self, *args): + return self.aoti_compiled_model(*args) + + def get_dense_pe(self, *args, **kwargs) -> torch.Tensor: + return self.other.get_dense_pe(*args, **kwargs) + +def load_aot_fast(mask_generator, model_directory): + t0 = time.time() + path = Path(model_directory) / Path(f"sam2_image_encoder.pt2") + assert path.exists(), f"Expected {path} to exist." + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + mask_generator.predictor.model.image_encoder = pkg_m + + # NOTE: This doesn't work yet! + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2")) + # pkg_m = LoadedModel(pkg) + # mask_generator.predictor._predict_masks_with_features = pkg_m.forward + + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_prompt_encoder.pt2")) + # pkg_m = LoadedDecoder(pkg, mask_generator.predictor.model.sam_prompt_encoder) + # mask_generator.predictor.model.sam_prompt_encoder = pkg_m + + # NOTE: This doesn't work yet! + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder.pt2")) + # pkg_m = LoadedModel(pkg) + # pkg_m.conv_s0 = mask_generator.predictor.model.sam_mask_decoder.conv_s0 + # pkg_m.conv_s1 = mask_generator.predictor.model.sam_mask_decoder.conv_s1 + # mask_generator.predictor.model.sam_mask_decoder = pkg_m + + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder_transformer.pt2")) + # pkg_m = LoadedModel(pkg) + # mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m + + print(f"End load. Took {time.time() - t0}s") + + +def set_fast(mask_generator, load_fast=""): + if load_fast == "": + # TODO: Using CUDA graphs can cause numerical differences? + mask_generator.predictor.model.image_encoder = torch.compile( + mask_generator.predictor.model.image_encoder, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + mask_generator.predictor._predict_masks = torch.compile( + mask_generator.predictor._predict_masks, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + # mask_generator.predictor._predict_masks_postprocess = torch.compile( + # mask_generator.predictor._predict_masks_postprocess, + # fullgraph=True, + # dynamic=True, + # ) + + +def set_furious(mask_generator): + mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16) + # NOTE: Not baseline feature + mask_generator.predictor._image_dtype = torch.float16 + mask_generator.predictor._transforms_device = mask_generator.predictor.device + torch.set_float32_matmul_precision('high') + mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16) + # NOTE: Not baseline feature + mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16 + +def set_autoquant(mask_generator): + from torchao import autoquant + from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + # NOTE: Not baseline feature + mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + mask_generator.predictor._transforms_device = mask_generator.predictor.device + torch.set_float32_matmul_precision('high') + # NOTE: this fails when we run + # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest + # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e + # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + + def main(checkpoint_path, + model_type, baseline=False, fast=False, furious=False, + use_autoquant=False, unittest=False, benchmark=False, profile=None, + memory_profile=None, verbose=False, points_per_batch=64, port=5000, host="127.0.0.1", dry=False, - batch_size=1): + batch_size=1, + load_fast="", + save_fast=""): if verbose: logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', @@ -283,56 +575,34 @@ def main(checkpoint_path, from torchao._models.sam2.build_sam import build_sam2 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from torchao._models.sam2.utils.amg import rle_to_mask - + device = "cuda" - from pathlib import Path - sam2_checkpoint = Path(checkpoint_path) / Path("sam2.1_hiera_large.pt") - model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" - + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) + logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}") sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) - + logging.info(f"Using {points_per_batch} points_per_batch") mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") - if fast: - assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." - # TODO: Using CUDA graphs can cause numerical differences? - mask_generator.predictor.model.image_encoder = torch.compile( - mask_generator.predictor.model.image_encoder, - mode="max-autotune", - fullgraph=True, - dynamic=False, - ) - - mask_generator.predictor.model.sam_prompt_encoder.forward = torch.compile( - mask_generator.predictor.model.sam_prompt_encoder.forward, - mode="max-autotune", - fullgraph=True, - dynamic=False, - ) + if load_fast != "": + load_aot_fast(mask_generator, load_fast) - mask_generator.predictor._predict_masks = torch.compile( - mask_generator.predictor._predict_masks, - mode="max-autotune", - fullgraph=True, - dynamic=False, - ) + if furious: + set_furious(mask_generator) + # since autoquant is replicating what furious mode is doing, don't use these two together + elif use_autoquant: + set_autoquant(mask_generator) - # mask_generator.predictor._predict_masks_postprocess = torch.compile( - # mask_generator.predictor._predict_masks_postprocess, - # fullgraph=True, - # dynamic=True, - # ) + if save_fast != "": + assert load_fast == "", "Can't save compiled models while loading them with --load-fast." + assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." + print(f"Saving compiled models under directory {save_fast}") + set_aot_fast(mask_generator, save_fast) - if furious: - mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16) - # NOTE: Not baseline feature - mask_generator.predictor._image_dtype = torch.float16 - torch.set_float32_matmul_precision('high') - mask_generator.predictor.model.sam_mask_decoder = mask_generator.predictor.model.sam_mask_decoder.to(torch.float16) - # NOTE: Not baseline feature - mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16 + if fast: + assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." + set_fast(mask_generator, load_fast) with open('dog.jpg', 'rb') as f: image_tensor = file_bytes_to_image_tensor(bytearray(f.read())) @@ -363,11 +633,15 @@ def main(checkpoint_path, for i, shapes in enumerate([example_shapes(), example_shapes_2()]): print(f"batch size {batch_size} example shapes {i} benchmark") random_images = [np.random.randint(0, 256, size=size, dtype=np.uint8) for size in shapes] + if batch_size > len(random_images): + num_repeat = (len(random_images) + batch_size) // batch_size + random_images = num_repeat * random_images if batch_size == 1: [benchmark_fn(image_tensor_to_masks, r, mask_generator) for r in random_images] else: random_images = random_images[:batch_size] + print("len(random_images): ", len(random_images)) benchmark_fn(image_tensors_to_masks, random_images, mask_generator) if profile is not None: @@ -377,6 +651,13 @@ def main(checkpoint_path, else: profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator) + if memory_profile is not None: + print(f"Saving memory profile under {memory_profile}") + if batch_size == 1: + memory_runner(memory_profile, image_tensor_to_masks, image_tensor, mask_generator) + else: + memory_runner(memory_profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator) + if dry: return @@ -401,7 +682,7 @@ async def upload_rle(image: UploadFile = File(...)): await request_queue.put((image_tensor, response_future)) masks = await response_future return masks_to_rle_dict(masks) - + @app.post("/upload") async def upload_image(image: UploadFile = File(...)): image_tensor = file_bytes_to_image_tensor(bytearray(await image.read())) @@ -419,10 +700,11 @@ async def upload_image(image: UploadFile = File(...)): plt.savefig(buf, format='png') buf.seek(0) return StreamingResponse(buf, media_type="image/png") - + # uvicorn.run(app, host=host, port=port, log_level="info") uvicorn.run(app, host=host, port=port) +main.__doc__ = main_docstring() if __name__ == "__main__": fire.Fire(main) diff --git a/packaging/post_build_script.sh b/packaging/post_build_script.sh index 70e8d83392..e6cfc8adfe 100644 --- a/packaging/post_build_script.sh +++ b/packaging/post_build_script.sh @@ -13,7 +13,7 @@ if [[ "$CU_VERSION" == cu* ]]; then WHEEL_NAME=$(ls dist/) pushd dist - manylinux_plat=manylinux2014_x86_64 + manylinux_plat=manylinux_2_28_x86_64 auditwheel repair --plat "$manylinux_plat" -w . \ --exclude libtorch.so \ --exclude libtorch_python.so \ diff --git a/ruff.toml b/ruff.toml index 1aed057d25..b20cab030c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -7,13 +7,17 @@ include = [ "torchao/quantization/**/*.py", "torchao/dtypes/**/*.py", "torchao/sparsity/**/*.py", + "torchao/profiler/**/*.py", + "torchao/testing/**/*.py", "torchao/prototype/low_bit_optim/**.py", - "test/quantization/test_observer.py", - "test/dtypes/test_affine_quantized_float.py", - "test/dtypes/test_nf4.py", - "test/prototype/low_bit_optim/**.py", "torchao/utils.py", - + "torchao/ops.py", + "torchao/_executorch_ops.py", + "test/float8/**/*.py", + "test/quantization/**/*.py", + "test/dtypes/**/*.py", + "test/sparsity/**/*.py", + "test/prototype/low_bit_optim/**.py", ] lint.ignore = ["E731"] diff --git a/scripts/prepare.sh b/scripts/prepare.sh index db426e3b11..9cbc8295ee 100644 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B python scripts/download.py --repo_id meta-llama/Llama-3.2-3B +python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B +# neuralmagic doesn't come with tokenizer, so we need to copy it over +mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4 diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index dd7e679e56..43d57b7d12 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -1,28 +1,30 @@ +import tempfile +import unittest + +import torch +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( TestCase, run_tests, ) + +from torchao.dtypes import Int4CPULayout, SemiSparseLayout from torchao.quantization import ( + float8_weight_only, int4_weight_only, - int8_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, - int8_dynamic_activation_int8_semi_sparse_weight, - float8_weight_only, + int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.dtypes import SemiSparseLayout -from torch.testing._internal import common_utils -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -import torch -import unittest -import tempfile - -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_89, +) -def get_quantization_functions(do_sparse: bool, do_int4: bool): +def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"): base_functions = [ int8_weight_only(), int8_dynamic_activation_int4_weight(), @@ -30,12 +32,19 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC), ] if do_int4: - base_functions.append(int4_weight_only(group_size=32)) + if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: + base_functions.append( + int4_weight_only(group_size=32, layout=Int4CPULayout()) + ) + else: + base_functions.append(int4_weight_only(group_size=32)) if do_sparse: - base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())) + base_functions.append( + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) + ) - if is_cuda_8_9: + if is_sm_at_least_89(): base_functions.append(float8_weight_only()) return base_functions @@ -44,11 +53,11 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): class TestAffineQuantized(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_tensor_core_layout_transpose(self): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - t = l.weight + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + t = linear.weight shape = t.shape apply_int4_weight_only_quant = int4_weight_only(group_size=32) - ql = apply_int4_weight_only_quant(l) + ql = apply_int4_weight_only_quant(linear) aqt = ql.weight aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) @@ -64,8 +73,8 @@ def test_tensor_core_layout_transpose(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) def test_weights_only(self, apply_quant): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ql = apply_quant(linear) with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) @@ -78,33 +87,32 @@ def test_weights_only(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) def test_to_device(self, apply_quant): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(linear) ql.to("cuda") - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(linear) ql.to(device="cuda") - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) + ql = apply_quant(linear) ql.cuda() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_register_new_dispatch(self): - from torchao.dtypes.affine_quantized_tensor import ( - register_aqt_quantized_linear_dispatch, + from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx + from torchao.dtypes.affine_quantized_tensor_ops import ( deregister_aqt_quantized_linear_dispatch, + register_aqt_quantized_linear_dispatch, ) - from torchao.dtypes import to_affine_quantized_intx - from torchao.dtypes import AffineQuantizedTensor from torchao.quantization.quant_primitives import MappingType def dispatch_condition(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.quant_min == 0 and - weight_tensor.quant_max == 2**6-1 + isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.quant_min == 0 + and weight_tensor.quant_max == 2**6 - 1 ) def impl(input_tensor, weight_tensor, bias): @@ -115,23 +123,35 @@ def impl(input_tensor, weight_tensor, bias): register_aqt_quantized_linear_dispatch(dispatch_condition, impl) def apply_uint6_weight_only_quant(linear): - linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False) + linear.weight = torch.nn.Parameter( + to_affine_quantized_intx( + linear.weight, + MappingType.ASYMMETRIC, + (1, linear.weight.shape[-1]), + torch.uint8, + 0, + 2**6 - 1, + ), + requires_grad=False, + ) return linear - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - apply_uint6_weight_only_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + apply_uint6_weight_only_quant(linear) example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") - with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"): - l(example_input) + with self.assertRaisesRegex( + AssertionError, "dispatching to my impl for uint6 weight only quant" + ): + linear(example_input) deregister_aqt_quantized_linear_dispatch(dispatch_condition) @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(l) + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ql = apply_quant(linear) assert "AffineQuantizedTensor" in str(ql) @@ -139,23 +159,29 @@ class TestAffineQuantizedBasic(TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.bfloat16] - @common_utils.parametrize("apply_quant", get_quantization_functions(False, True)) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_flatten_unflatten(self, apply_quant, device, dtype): - l = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(l) - lp_tensor = ql.weight - tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict} - outer_size = lp_tensor.size() - outer_stride = lp_tensor.stride() - reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) - example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) - ref = ql(*example_inputs) - ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) - reconstruct_res = ql(*example_inputs) - self.assertEqual(reconstruct_res, ref) + def test_flatten_unflatten(self, device, dtype): + apply_quant_list = get_quantization_functions(False, True, device) + for apply_quant in apply_quant_list: + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + ql = apply_quant(linear) + lp_tensor = ql.weight + tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } + outer_size = lp_tensor.size() + outer_stride = lp_tensor.stride() + reconstructed = type(lp_tensor).__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) + example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) + ref = ql(*example_inputs) + ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) + reconstruct_res = ql(*example_inputs) + self.assertEqual(reconstruct_res, ref) + common_utils.instantiate_parametrized_tests(TestAffineQuantized) common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 1e2ff29796..4d8312b427 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -37,13 +37,14 @@ MappingType, choose_qparams_affine, ) +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) random.seed(0) torch.manual_seed(0) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - class ToyLinearModel(torch.nn.Module): def __init__(self, in_features, out_features): @@ -59,12 +60,14 @@ def forward(self, x): class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize( - "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()] + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] ) # Inputs are (M,..), K, N @common_utils.parametrize( @@ -134,10 +137,16 @@ def test_fp8_linear_variants( compute_error(output_original, output_quantized) > 20 ), f"Quantization error is too high got a SQNR of {error}" + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): float8_dynamic_activation_float8_weight(granularity="invalid") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_mismatched_granularity(self): with pytest.raises( ValueError, @@ -145,6 +154,9 @@ def test_mismatched_granularity(self): ): float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_unsupported_granularity(self): class UnsupportedGranularity: pass @@ -155,7 +167,9 @@ class UnsupportedGranularity: ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_per_row_with_float32(self): with pytest.raises( AssertionError, @@ -167,7 +181,9 @@ def test_per_row_with_float32(self): ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) def test_serialization(self, mode: str): # Create and quantize the model @@ -237,7 +253,9 @@ def test_serialization(self, mode: str): ), f"Scales do not match for {layer_name}" @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_fp8_weight_dimension_warning(self): # Create model with incompatible dimensions (not multiples of 16) model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index af07328407..82d3d2501d 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,28 +1,28 @@ -import torch import unittest -from torch.testing._internal.common_utils import run_tests + +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + from torchao.quantization import ( + float8_dynamic_activation_float8_weight, + float8_weight_only, int4_weight_only, int8_weight_only, - float8_weight_only, - float8_dynamic_activation_float8_weight, ) from torchao.quantization.observer import PerRow, PerTensor -import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, - NUM_DEVICES, -) from torchao.quantization.quant_api import quantize_ -from torchao.dtypes import AffineQuantizedTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 + class TestAffineQuantizedTensorParallel(DTensorTestBase): - """Basic test case for tensor subclasses - """ + """Basic test case for tensor subclasses""" + QUANT_METHOD_FN = staticmethod(int8_weight_only) QUANT_METHOD_KWARGS = {} @@ -40,9 +40,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m @staticmethod @@ -59,9 +57,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m def quantize(self, m: torch.nn.Module) -> torch.nn.Module: @@ -79,7 +75,9 @@ def _test_tp(self, dtype): class M(torch.nn.Module): def __init__(self, in_features, out_features, **kwargs) -> None: super().__init__(**kwargs) - self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + self.linear = torch.nn.Linear( + in_features, out_features, bias=False, device="cuda" + ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -91,11 +89,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: proj_up = M(1024, 2048).to(device).to(dtype) proj_dn = M(2048, 1024).to(device).to(dtype) example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - y = proj_dn(proj_up(example_input)) + proj_dn(proj_up(example_input)) # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) - y_q = dn_quant(up_quant(example_input)) + dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() mesh.device_type = "cuda" @@ -105,11 +103,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist = self.rowwise_shard(dn_quant, mesh) # We need to turn inputs into DTensor form as well -- just a format change - input_dtensor = DTensor.from_local( - example_input, mesh, [Replicate()] - ) + input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()]) - y_d = dn_dist(up_dist(input_dtensor)) + dn_dist(up_dist(input_dtensor)) if not TORCH_VERSION_AT_LEAST_2_6: # Need torch 2.6 to support compiled tensor parallelism @@ -118,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) - y_dn = dn_compiled(y_up) + dn_compiled(y_up) class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): @@ -142,11 +138,13 @@ class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel) def test_tp(self, dtype): return self._test_tp(dtype) + common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel) common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel) # Run only on H100 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): + class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): QUANT_METHOD_FN = staticmethod(float8_weight_only) COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -157,7 +155,9 @@ class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParalle def test_tp(self, dtype): return self._test_tp(dtype) - class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + class TestFloat8dqTensorAffineQuantizedTensorParallel( + TestAffineQuantizedTensorParallel + ): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerTensor()} COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32] @@ -168,7 +168,9 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorP def test_tp(self, dtype): return self._test_tp(dtype) - class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + class TestFloat8dqRowAffineQuantizedTensorParallel( + TestAffineQuantizedTensorParallel + ): QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight) QUANT_METHOD_KWARGS = {"granularity": PerRow()} COMMON_DTYPES = [torch.bfloat16] @@ -179,7 +181,11 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorPara def test_tp(self, dtype): return self._test_tp(dtype) - common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel) - common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel) + common_utils.instantiate_parametrized_tests( + TestFloat8dqTensorAffineQuantizedTensorParallel + ) + common_utils.instantiate_parametrized_tests( + TestFloat8dqRowAffineQuantizedTensorParallel + ) if __name__ == "__main__": run_tests() diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py index 70153cf5ba..e248b04b05 100644 --- a/test/dtypes/test_bitnet.py +++ b/test/dtypes/test_bitnet.py @@ -1,6 +1,7 @@ import pytest import torch import torch.nn as nn + from torchao.prototype.dtypes import BitnetTensor from torchao.prototype.dtypes.uint2 import unpack_uint2 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter @@ -9,6 +10,7 @@ if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) + @pytest.fixture(autouse=True) def run_before_and_after_tests(): # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 @@ -22,34 +24,45 @@ def run_before_and_after_tests(): # avoid dynamo cache limit issues torch._dynamo.reset() + @pytest.fixture def bitnet_tensor(): - input_tensor = torch.randint(0, 15, (4,4), dtype=torch.uint8) + input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) return BitnetTensor.from_unpacked(input_tensor) + def test_copy(bitnet_tensor): copied_tensor = bitnet_tensor.clone() assert torch.equal(bitnet_tensor.elem, copied_tensor.elem) + def test_transpose(bitnet_tensor): transposed_tensor = bitnet_tensor.t() expected_tensor = unpack_uint2(bitnet_tensor.elem).t() assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) + def test_multiply(bitnet_tensor): w_t = torch.randint(0, 15, (4, 16), dtype=torch.uint8) w = BitnetTensor.from_unpacked(w_t) - y = torch.addmm(torch.Tensor([1]), bitnet_tensor, w) + torch.addmm(torch.Tensor([1]), bitnet_tensor, w) + -@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64]) +@pytest.mark.parametrize( + "dtype", + [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], +) def test_conversion(bitnet_tensor, dtype): converted_tensor = bitnet_tensor.to(dtype) expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype) assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) + def _apply_weight_only_uint2_quant(model): def fn(mod): - mod.weight = torch.nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False) + mod.weight = torch.nn.Parameter( + BitnetTensor.from_float(mod.weight), requires_grad=False + ) return mod _replace_with_custom_fn_if_matches_filter( @@ -58,19 +71,21 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -@pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies") + +@pytest.mark.skipif( + TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies" +) @pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) def test_uint2_quant(input_shape): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" x = torch.randn(*input_shape).to(device) m = nn.Sequential(nn.Linear(4, 16)).to(device) y_ref = m(x) _apply_weight_only_uint2_quant(m) y_wo = m(x) assert y_ref.shape == y_wo.shape - y_compiled = torch.compile(m, fullgraph=True)(x) + torch.compile(m, fullgraph=True)(x) if __name__ == "__main__": pytest.main(__file__) - diff --git a/test/dtypes/test_bitpacking.py b/test/dtypes/test_bitpacking.py index 647ead8fd8..262a4d6ca6 100644 --- a/test/dtypes/test_bitpacking.py +++ b/test/dtypes/test_bitpacking.py @@ -1,33 +1,38 @@ -import torch -from torchao.dtypes.uintx.bitpacking import pack, unpack, pack_cpu, unpack_cpu import pytest +import torch from torch.utils._triton import has_triton -bit_widths = (1,2,3,4,5,6,7) +from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu + +bit_widths = (1, 2, 3, 4, 5, 6, 7) dimensions = (0, -1, 1) + @pytest.fixture(autouse=True) def run_before_and_after_tests(): yield - torch._dynamo.reset() # reset cache between tests + torch._dynamo.reset() # reset cache between tests + @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) def test_CPU(bit_width, dim): - test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8, device='cpu') - packed = pack_cpu(test_tensor, bit_width, dim = dim) - unpacked = unpack_cpu(packed, bit_width, dim = dim) - assert(unpacked.allclose(test_tensor)) + test_tensor = torch.randint( + 0, 2**bit_width, (32, 32, 32), dtype=torch.uint8, device="cpu" + ) + packed = pack_cpu(test_tensor, bit_width, dim=dim) + unpacked = unpack_cpu(packed, bit_width, dim=dim) + assert unpacked.allclose(test_tensor) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("bit_width", bit_widths) @pytest.mark.parametrize("dim", dimensions) def test_GPU(bit_width, dim): - test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8).cuda() - packed = pack(test_tensor, bit_width, dim = dim) - unpacked = unpack(packed, bit_width, dim = dim) - assert(unpacked.allclose(test_tensor)) + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + packed = pack(test_tensor, bit_width, dim=dim) + unpacked = unpack(packed, bit_width, dim=dim) + assert unpacked.allclose(test_tensor) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -36,27 +41,33 @@ def test_GPU(bit_width, dim): @pytest.mark.parametrize("dim", dimensions) def test_compile(bit_width, dim): torch._dynamo.config.specialize_int = True - pack_compile = torch.compile(pack, fullgraph=True) - unpack_compile = torch.compile(unpack, fullgraph=True) - test_tensor = torch.randint(0, 2**bit_width, (32,32,32), dtype=torch.uint8).cuda() - packed = pack(test_tensor, bit_width, dim = dim) - unpacked = unpack(packed, bit_width, dim = dim) - assert(unpacked.allclose(test_tensor)) + torch.compile(pack, fullgraph=True) + torch.compile(unpack, fullgraph=True) + test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda() + packed = pack(test_tensor, bit_width, dim=dim) + unpacked = unpack(packed, bit_width, dim=dim) + assert unpacked.allclose(test_tensor) + # these test cases are for the example pack walk through in the bitpacking.py file @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_pack_example(): - test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8).cuda() - shard_4,shard_2 = pack(test_tensor, 6) + test_tensor = torch.tensor( + [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8 + ).cuda() + shard_4, shard_2 = pack(test_tensor, 6) print(shard_4, shard_2) assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4) assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2) unpacked = unpack([shard_4, shard_2], 6) assert unpacked.allclose(test_tensor) + def test_pack_example_CPU(): - test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8) - shard_4,shard_2 = pack(test_tensor, 6) + test_tensor = torch.tensor( + [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8 + ) + shard_4, shard_2 = pack(test_tensor, 6) print(shard_4, shard_2) assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4) assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2) diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 3e65ea6ab8..8bb39b2cc8 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -1,7 +1,6 @@ import copy - -import pytest import unittest + import torch from torch.testing._internal.common_utils import ( TestCase, @@ -9,22 +8,27 @@ parametrize, run_tests, ) + from torchao.dtypes.floatx import ( - FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout, - to_scaled_tc_floatx, from_scaled_tc_floatx, + to_scaled_tc_floatx, +) +from torchao.dtypes.floatx.floatx_tensor_core_layout import ( + FloatxTensorCoreAQTTensorImpl, + _pack_tc_floatx, + _pack_tc_fp6, +) +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, ) -from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6 -from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32 from torchao.quantization import ( - quantize_, fpx_weight_only, + quantize_, ) - from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode - _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -53,7 +57,9 @@ def test_from_tc_floatx_correctness(self, ebits, mbits, device): x = torch.randn(256, 64, device=device) * 100 # quantize and dequantize so that the values are exactly representable in Floatx - x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits) + x = _floatx_unpacked_to_f32( + _f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits + ) tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits) actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale) @@ -64,11 +70,15 @@ def test_from_tc_floatx_correctness(self, ebits, mbits, device): def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device): M, N = 256, 64 nbits = 1 + ebits + mbits - x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device) + x = torch.randint( + 256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device + ) scale = torch.randn(M, device=device) expected = from_scaled_tc_floatx(x, ebits, mbits, scale) - actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale) + actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)( + x, ebits, mbits, scale + ) torch.testing.assert_close(actual, expected) @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") @@ -83,13 +93,18 @@ def test_to_copy_device(self, ebits, mbits): scale = choose_qparams_affine_floatx(x, ebits, mbits) x = quantize_affine_floatx(x, scale, ebits, mbits) _layout = FloatxTensorCoreLayout(ebits, mbits) - floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda() + floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain( + x, scale, None, _layout + ).cuda() assert floatx_tensor_impl.device.type == "cuda" floatx_tensor_impl = floatx_tensor_impl.cpu() assert floatx_tensor_impl.device.type == "cpu" @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, + reason="quantization only works with torch.compile for 2.5+", + ) @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 4da6b95391..caa1a6c7bd 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -504,17 +504,22 @@ def test_to_cuda(self): self.assertEqual(nf4_tensor.device.type, "cpu") nf4_tensor = nf4_tensor.to("cuda", non_blocking=True) self.assertEqual(nf4_tensor.device.type, "cuda") + self.assertEqual(type(nf4_tensor), NF4Tensor) + nf4_tensor.get_original_weight() # make sure we can dequantize nf4_tensor = to_nf4(torch.randn(512 * 512)) self.assertEqual(nf4_tensor.device.type, "cpu") nf4_tensor = nf4_tensor.to("cuda") self.assertEqual(nf4_tensor.device.type, "cuda") + self.assertEqual(type(nf4_tensor), NF4Tensor) + nf4_tensor.get_original_weight() nf4_tensor = to_nf4(torch.randn(512 * 512)) self.assertEqual(nf4_tensor.device.type, "cpu") nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16) self.assertEqual(nf4_tensor.device.type, "cuda") self.assertEqual(nf4_tensor.dtype, torch.bfloat16) + self.assertEqual(type(nf4_tensor), torch.Tensor) # dequantized @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_to_cpu(self): @@ -524,6 +529,37 @@ def test_to_cpu(self): for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: inner_tensor = getattr(nf4_tensor, attr) self.assertEqual(inner_tensor.device.type, "cpu") + nf4_tensor.get_original_weight() # make sure we can dequantize + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_to_module(self): + linear = nn.Linear(512, 512, bias=False) + linear.weight = nn.Parameter( + to_nf4(linear.weight.detach()), requires_grad=False + ) + linear.cuda() + self.assertEqual(linear.weight.device.type, "cuda") + weight = linear.weight.get_original_weight() + self.assertEqual(weight.device.type, "cuda") + + linear.cpu() + self.assertEqual(linear.weight.device.type, "cpu") + weight = linear.weight.get_original_weight() + self.assertEqual(weight.device.type, "cpu") + + linear = nn.Linear(512, 512, bias=False) + linear.weight = nn.Parameter( + to_nf4(linear.weight.detach()), requires_grad=False + ) + linear.to("cuda") + self.assertEqual(linear.weight.device.type, "cuda") + weight = linear.weight.get_original_weight() + self.assertEqual(weight.device.type, "cuda") + + linear.to("cpu") + self.assertEqual(linear.weight.device.type, "cpu") + weight = linear.weight.get_original_weight() + self.assertEqual(weight.device.type, "cpu") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py index b017c47dd4..f6faaea10d 100644 --- a/test/dtypes/test_uint2.py +++ b/test/dtypes/test_uint2.py @@ -1,6 +1,6 @@ import pytest import torch -import torch.nn as nn + from torchao.prototype.dtypes import UInt2Tensor from torchao.prototype.dtypes.uint2 import unpack_uint2 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 @@ -8,26 +8,33 @@ if not TORCH_VERSION_AT_LEAST_2_4: pytest.skip("Unsupported PyTorch version", allow_module_level=True) + @pytest.fixture def uint2_tensor(): - input_tensor = torch.randint(0, 15, (4,4), dtype = torch.uint8) + input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) return UInt2Tensor(input_tensor) + def test_copy(uint2_tensor): copied_tensor = uint2_tensor.clone() assert torch.equal(uint2_tensor.elem, copied_tensor.elem) + def test_transpose(uint2_tensor): transposed_tensor = uint2_tensor.t() expected_tensor = unpack_uint2(uint2_tensor.elem).t() assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) -@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64]) + +@pytest.mark.parametrize( + "dtype", + [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], +) def test_conversion(uint2_tensor, dtype): converted_tensor = uint2_tensor.to(dtype) expected_tensor = unpack_uint2(uint2_tensor.elem).to(dtype) assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) -if __name__ == '__main__': + +if __name__ == "__main__": pytest.main(__file__) - diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index 98fb523d33..e148d68abb 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -1,35 +1,42 @@ -import torch -from torchao.dtypes.uint4 import ( - UInt4Tensor, - PerChannelSymmetricWeightUInt4Tensor, -) +import copy import unittest -from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +import torch +from torch import nn +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torch.fx import ( + GraphModule, + Node, +) from torch.testing._internal.common_quantization import ( NodeSpec as ns, +) +from torch.testing._internal.common_quantization import ( QuantizationTestCase, ) + +from torchao.dtypes.uintx.uint4_layout import ( + PerChannelSymmetricWeightUInt4Tensor, + UInt4Tensor, +) from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, ) -from torch.ao.quantization.observer import ObserverBase -from torch import nn -from torch.fx import ( - Node, - GraphModule, -) -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, -) -import copy from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def _apply_weight_only_uint4_quant(model): def fn(mod): - mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) + mod.weight = torch.nn.Parameter( + PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), + requires_grad=False, + ) return mod _replace_with_custom_fn_if_matches_filter( @@ -38,28 +45,46 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -@unittest.skip("FAILED test/dtypes/test_uint4.py::TestUInt4::test_basic_tensor_ops - AttributeError: module 'torch' has no attribute 'uint4'") + +@unittest.skip( + "FAILED test/dtypes/test_uint4.py::TestUInt4::test_basic_tensor_ops - AttributeError: module 'torch' has no attribute 'uint4'" +) class TestUInt4(QuantizationTestCase): def test_basic_tensor_ops(self): - x = UInt4Tensor(torch.tensor([ - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + x = UInt4Tensor( + torch.tensor( + [ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], + dtype=torch.uint8, + ) + ) self.assertEqual(x.shape, (3, 16)) # TODO: make sure this returns torch.uint4 self.assertIs(x.dtype, torch.uint4) # making sure these works x.to(torch.uint8) - expected = UInt4Tensor(torch.tensor([ - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + expected = UInt4Tensor( + torch.tensor( + [ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], + dtype=torch.uint8, + ) + ) self.assertEqual(x[0:1, :], expected) - expected = UInt4Tensor(torch.tensor([ - [0x23, 0x45], - [0x23, 0x45], - [0x23, 0x45], - ], dtype=torch.uint8)) + expected = UInt4Tensor( + torch.tensor( + [ + [0x23, 0x45], + [0x23, 0x45], + [0x23, 0x45], + ], + dtype=torch.uint8, + ) + ) self.assertEqual(x[:, 2:6], expected) torch.save(x, "uint4_tensor.pt") x = torch.load("uint4_tensor.pt") @@ -71,9 +96,9 @@ def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: x = torch.randn(*x_shape) m = nn.Sequential(nn.Linear(4, 16)) - y_ref = m(x) + m(x) # checking if it runs _apply_weight_only_uint4_quant(m) - y_wo = m(x) + m(x) # checking if it runs # sqnr = compute_error(y_ref, y_wo) opt = torch.compile(m, fullgraph=True, mode="max-autotune") # make sure it runs @@ -81,9 +106,9 @@ def test_gpu_quant(self): def test_pt2e_quant(self): from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - OP_TO_ANNOTATOR, QuantizationConfig, ) + class Uint4Observer(ObserverBase): def __init__(self, *args, **kwargs): # just faking a dtype here @@ -99,9 +124,15 @@ def calculate_qparams(self, **kwargs): def convert(self, model: GraphModule, observer_node: Node): with model.graph.inserting_before(observer_node): q_node = model.graph.call_function( - torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {}) + torch.ops.qtensors.quantize_per_tensor_uint4, + (observer_node.args[0], 1.0, 0), + {}, + ) dq_node = model.graph.call_function( - torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {}) + torch.ops.qtensors.dequantize_per_tensor_uint4, + (q_node, 1.0, 0), + {}, + ) observer_node.replace_all_uses_with(dq_node) model.graph.erase_node(observer_node) @@ -160,10 +191,12 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if _is_annotated(partition): continue - linear_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, + linear_node.meta["quantization_annotation"] = ( + QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) ) _mark_nodes_as_annotated(partition) @@ -197,7 +230,6 @@ def forward(self, x): # _test_quantizer in PT2EQuantizationTestCase # resetting dynamo cache - export_with_dynamic_shape = False torch._dynamo.reset() m_eager = M().eval() @@ -210,23 +242,22 @@ def forward(self, x): ).module() else: m = torch._export.capture_pre_autograd_graph( - m, - example_inputs, - ).module() + m, + example_inputs, + ).module() m = prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) m = convert_pt2e(m, fold_quantize=False) - pt2_quant_output = m(*example_inputs) + m(*example_inputs) - node_occurrence = { - ns.call_function(k): v for k, v in node_occurrence.items() - } + node_occurrence = {ns.call_function(k): v for k, v in node_occurrence.items()} node_list = [ns.call_function(n) for n in node_list] self.checkGraphModuleNodes( m, expected_node_occurrence=node_occurrence, expected_node_list=node_list ) + if __name__ == "__main__": unittest.main() diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index bb754135db..da43253678 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -1,85 +1,109 @@ -from math import log -from copy import deepcopy import pytest - import torch -from torchao.dtypes.uintx import to_uintx +from torchao.dtypes.uintx.uintx_layout import to_uintx from torchao.quantization.quant_api import quantize_, uintx_weight_only -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_5, -) - from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, choose_qparams_affine, - quantize_affine, dequantize_affine, + quantize_affine, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, ) # torch.uintx dtypes are introduced in 2.3 if TORCH_VERSION_AT_LEAST_2_3: - dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7) + dtypes = ( + torch.uint1, + torch.uint2, + torch.uint3, + torch.uint4, + torch.uint5, + torch.uint6, + torch.uint7, + ) else: dtypes = () group_sizes = [32, 64, 128] devices = ["cpu", "cuda"] + + @pytest.fixture(autouse=True) def run_before_and_after_tests(): yield - torch._dynamo.reset() # reset cache between tests + torch._dynamo.reset() # reset cache between tests + class Linear16(torch.nn.Module): def __init__(self, scale, device): super().__init__() self.net = torch.nn.Sequential( - torch.nn.Linear(scale * 2, scale, bias=False, dtype=torch.float16, device=device), - torch.nn.Linear(scale, scale, bias=False, dtype=torch.float16, device=device), - torch.nn.Linear(scale, scale//2, bias=False, dtype=torch.float16, device=device), + torch.nn.Linear( + scale * 2, scale, bias=False, dtype=torch.float16, device=device + ), + torch.nn.Linear( + scale, scale, bias=False, dtype=torch.float16, device=device + ), + torch.nn.Linear( + scale, scale // 2, bias=False, dtype=torch.float16, device=device + ), ) def forward(self, x): return self.net(x) + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" +) def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size): scale = 512 fp16_mod_on_cpu = Linear16(scale, "cpu") quantize_(fp16_mod_on_cpu, uintx_weight_only(dtype, group_size=group_size)) - test_input_on_cpu = torch.randn(scale*2, dtype=torch.float16, device="cpu") + test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu") output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu) fp16_mod_on_cuda = fp16_mod_on_cpu.to("cuda") test_input_on_cuda = test_input_on_cpu.to("cuda") output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda) - assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), "The output of the model on CPU and CUDA should be close" + assert torch.allclose( + output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3 + ), "The output of the model on CPU and CUDA should be close" + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" +) def test_uintx_weight_only_model_quant(dtype, group_size, device): scale = 512 fp16 = Linear16(scale, device) quantize_(fp16, uintx_weight_only(dtype, group_size=group_size)) uintx = torch.compile(fp16, fullgraph=True) - test_input = torch.randn(scale*2, dtype=torch.float16, device=device) + test_input = torch.randn(scale * 2, dtype=torch.float16, device=device) output = uintx.forward(test_input) - assert output != None, "model quantization failed" + assert output is not None, "model quantization failed" + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("group_size", group_sizes) @pytest.mark.parametrize("device", devices) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="only works with fix in the nightly build" +) def test_uintx_weight_only_quant(dtype, group_size, device): - input_float = torch.randn((1, 256), dtype=torch.float16, device = device) + input_float = torch.randn((1, 256), dtype=torch.float16, device=device) mapping_type = MappingType.SYMMETRIC eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int32 @@ -87,73 +111,91 @@ def test_uintx_weight_only_quant(dtype, group_size, device): block_size = (1, group_size) scale, zero_point = choose_qparams_affine( - input_float, mapping_type, block_size, - dtype, eps=eps, scale_dtype=torch.float32, - zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain + input_float, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=torch.float32, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=zero_point_domain, ) aqt = quantize_affine( - input_float, block_size, scale, - zero_point, dtype, - zero_point_domain=zero_point_domain + input_float, + block_size, + scale, + zero_point, + dtype, + zero_point_domain=zero_point_domain, ) # Note: output will be uint8 tensor for sub byte tensors for now - q = to_uintx(aqt, dtype, -1) - assert q != None, "quantization failed" + q = to_uintx(aqt, dtype, -1) + assert q is not None, "quantization failed" deqaunt = dequantize_affine( - q, block_size, scale, - zero_point, dtype, - zero_point_domain=zero_point_domain + q, block_size, scale, zero_point, dtype, zero_point_domain=zero_point_domain ) - assert deqaunt != None, "deqauntization failed" + assert deqaunt is not None, "deqauntization failed" @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" +) def test_uintx_target_dtype(dtype): from torchao.quantization.quant_api import uintx_weight_only - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(l) - l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + uintx_weight_only(dtype)(linear) + linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, + reason="torch.compile without unwrap_tensor_subclass requires torch 2.5+", +) def test_uintx_target_dtype_compile(dtype): from torchao.quantization.quant_api import uintx_weight_only - l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(l) - l = torch.compile(l) - l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) + uintx_weight_only(dtype)(linear) + linear = torch.compile(linear) + linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+") +@pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="sub byte dtype requires torch 2.3+" +) def test_uintx_model_size(dtype): from torchao.quantization.quant_api import uintx_weight_only from torchao.utils import get_model_size_in_bytes + # scale size = 1/64 * 2 bytes = 1/32 bytes # zero_point size = 1/64 * 4 bytes = 1/16 bytes # dtype data size = 1 * bit_width/8 = bit_width/8 bytes _dtype_to_ratio = { - torch.uint1: (1/8 + 1/16 + 1/32) / 2, - torch.uint2: (2/8 + 1/16 + 1/32) / 2, - torch.uint3: (3/8 + 1/16 + 1/32) / 2, - torch.uint4: (4/8 + 1/16 + 1/32) / 2, - torch.uint5: (5/8 + 1/16 + 1/32) / 2, - torch.uint6: (6/8 + 1/16 + 1/32) / 2, - torch.uint7: (7/8 + 1/16 + 1/32) / 2, + torch.uint1: (1 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint2: (2 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint3: (3 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint4: (4 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint5: (5 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint6: (6 / 8 + 1 / 16 + 1 / 32) / 2, + torch.uint7: (7 / 8 + 1 / 16 + 1 / 32) / 2, } - l = torch.nn.Sequential( + linear = torch.nn.Sequential( torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda") ) - bf16_size = get_model_size_in_bytes(l) + bf16_size = get_model_size_in_bytes(linear) # make sure it runs - uintx_weight_only(dtype)(l[0]) - quantized_size = get_model_size_in_bytes(l) + uintx_weight_only(dtype)(linear[0]) + quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 0c55c9c38a..e5f64abf57 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -4,20 +4,21 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import copy -import io import itertools import random import re import unittest import warnings -from typing import List, Tuple import pytest - import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -29,9 +30,9 @@ e5m2_dtype, Float8LinearConfig, Float8LinearRecipeName, - recipe_name_to_linear_config, ScalingGranularity, ScalingType, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -47,14 +48,14 @@ from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, - hp_tensor_and_scale_to_float8, LinearMMConfig, ScaledMMConfig, + hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( + FP8_TYPES, compute_error, fp8_tensor_statistics, - FP8_TYPES, tensor_to_scale, ) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -63,10 +64,6 @@ torch.manual_seed(0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - - def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -144,6 +141,25 @@ def test_copy_(self): fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) + def test_transpose(self): + a = torch.rand((16, 16), dtype=torch.bfloat16) + for axiswise_dim in (None, 0, -1): + scale_a = tensor_to_scale(a, e4m3_dtype) + fp8_a = hp_tensor_and_scale_to_float8( + a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim + ) + fp8_b = hp_tensor_and_scale_to_float8( + a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim + ) + + fp8_a_transposed = fp8_a.transpose(0, 1) + fp8_b_t = fp8_b.t() + + torch.testing.assert_close( + (fp8_a_transposed._data, fp8_a_transposed._scale), + (fp8_b_t._data, fp8_b_t._scale), + ) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) def test_axiswise_dynamic_cast(self, shape, axiswise_dim): @@ -186,7 +202,7 @@ def test_axiswise_reshape(self): rtol=0, ) with pytest.raises(RuntimeError): - a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7) + a_fp8_d0.reshape(-1, 7) # if we scale across dim2, we can only reshape to [-1, 7] a_fp8_d2 = hp_tensor_to_float8_dynamic( @@ -210,7 +226,7 @@ def test_axiswise_reshape(self): rtol=0, ) with pytest.raises(RuntimeError): - a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1) + a_fp8_d2.reshape(3, -1) @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)]) @pytest.mark.parametrize( @@ -222,7 +238,7 @@ def test_axiswise_reshape(self): ], ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") @@ -261,17 +277,25 @@ def _test_linear_impl( x, m_ref, config: Float8LinearConfig, + use_ac: bool = False, ): m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), config, ) for _ in range(2): + if use_ac: + y_fp8 = torch.utils.checkpoint.checkpoint(m_fp8, x, use_reentrant=False) + else: + y_fp8 = m_fp8(x) + y_fp8.sum().backward() if linear_requires_sync(config): sync_float8_amax_and_scale_history(m_fp8) - y_fp8 = m_fp8(x) - y_fp8.sum().backward() - y_ref = m_ref(x) + + if use_ac: + y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False) + else: + y_ref = m_ref(x) y_ref.sum().backward() assert y_ref.shape == y_fp8.shape @@ -328,7 +352,9 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", @@ -344,6 +370,7 @@ def _test_linear_impl( ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) + @pytest.mark.parametrize("use_ac", [False, True]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_linear_from_config_params( self, @@ -354,6 +381,7 @@ def test_linear_from_config_params( scaling_type_grad_output: ScalingType, linear_dtype: torch.dtype, linear_bias: bool, + use_ac: bool, ): x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) @@ -369,6 +397,7 @@ def test_linear_from_config_params( x, m_ref, config, + use_ac, ) # Note: there are now too many config combinations to test all of @@ -407,7 +436,9 @@ def test_linear_from_recipe( config, ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -417,33 +448,36 @@ def test_autocast_outputs( emulate: bool, linear_dtype: torch.dtype, ): - m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) + m_ref = nn.Sequential( + nn.Linear(32, 32, device="cuda", dtype=linear_dtype), + nn.Linear(32, 32, device="cuda", dtype=linear_dtype), + ) config = Float8LinearConfig( cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) - m = Float8Linear.from_float(copy.deepcopy(m_ref), config) + m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) + y = m(x) if linear_requires_sync(config): sync_float8_amax_and_scale_history(m) - y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): + y = m(x) if linear_requires_sync(config): sync_float8_amax_and_scale_history(m) - y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): + y = m(x) if linear_requires_sync(config): sync_float8_amax_and_scale_history(m) - y = m(x) assert ( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -451,7 +485,9 @@ def test_autocast_outputs( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) @@ -512,18 +548,33 @@ def test_repr(self): s = m.__repr__() assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s - @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): x = torch.randn(32, 32, device="cuda") m = nn.Sequential(nn.Linear(32, 32)).cuda() m = convert_to_float8_training(m) with torch.inference_mode(mode=True): - y = m(x) + m(x) + + @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available") + def test_quantize(self): + x = torch.randn(32, 32, device="cuda") + m = nn.Sequential(nn.Linear(32, 32)).cuda() + m = convert_to_float8_training(m) + assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" + from torchao.quantization.quant_api import float8_weight_only, quantize_ + + quantize_(m, float8_weight_only()) + assert ( + m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn + ), "Post quantization dtype should be torch.float8_e4m3fn" + with torch.no_grad(): + m(x) class TestScaledMM: @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -565,10 +616,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): if base_dtype in {torch.bfloat16, torch.float16}: atol, rtol = 7e-2, 7e-2 else: - atol, rtol = 2e-3, 2e-3 + atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_cuda_8_9, "CUDA not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available") def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") @@ -604,7 +655,7 @@ def test_different_configs_error(self): a @ b @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ce9935ca79..57362d6990 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -5,26 +5,32 @@ # LICENSE file in the root directory of this source tree. import copy import random -from typing import List, Tuple import sys import unittest from io import StringIO import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) import torch import torch.nn as nn +from torch._dynamo.test_case import TestCase as DynamoTestCase +from torch._dynamo.testing import CompileCounterWithBackend + from torchao.float8.config import ( CastConfig, e4m3_dtype, Float8LinearConfig, - ScalingType, Float8LinearRecipeName, + ScalingType, recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -38,18 +44,12 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( - LinearMMConfig, GemmInputRole, + LinearMMConfig, ScaledMMConfig, ) from torchao.testing.float8.test_utils import get_test_float8_linear_config -from torch._dynamo.test_case import TestCase as DynamoTestCase -from torch._dynamo.testing import CompileCounterWithBackend - -# TODO(future PR): standardize IS_H100 with the rest of the codebase -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def _test_compile_base( backend: str, @@ -92,12 +92,14 @@ def _test_compile_base( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( @@ -124,15 +126,17 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -165,12 +169,17 @@ def test_aot_eager( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @pytest.mark.parametrize( - "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_weight", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @pytest.mark.parametrize( - "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] + "scaling_type_grad_output", + [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], +) +@unittest.skipIf( + not torch.cuda.is_available() or not is_sm_at_least_89(), + "CUDA with float8 support not available", ) -@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_inductor_from_config_params( fullgraph, @@ -194,15 +203,21 @@ def test_inductor_from_config_params( dtype, ) + # Note: there are now too many config combinations to test all of # them, so this function factors out some of the recipes which are annoying # to combine with the main testing function. # TODO(future PR): make this cleaner. @pytest.mark.parametrize( "recipe_name", - [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + [ + Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + ], +) +@unittest.skipIf( + not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" ) -@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() config = recipe_name_to_linear_config(recipe_name) @@ -239,7 +254,10 @@ def forward(self, x): return x_fp8 # TODO(future): figure out why the test below fails on CUDA capability 8.9 - @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available") + @unittest.skipIf( + not torch.cuda.is_available() or not is_sm_at_least_90(), + "CUDA with capability 9.0 or greater not available", + ) def test_float8_with_graph_break_in_the_middle(self): """Test that having Float8Tensor object at the boundary of a subgraph""" cnts = CompileCounterWithBackend("inductor") @@ -252,7 +270,10 @@ def test_float8_with_graph_break_in_the_middle(self): self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!") torch.testing.assert_close(y_eager, y_compiled) - @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") + @unittest.skipIf( + not torch.cuda.is_available() or not is_sm_at_least_89(), + "CUDA with float8 support not available", + ) def test_float8_graph_input(self): """Test that having Float8Tensor object as a graph input""" @@ -273,7 +294,10 @@ def to_float(x): ) torch.testing.assert_close(y2_eager, y2_compiled) - @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") + @unittest.skipIf( + not torch.cuda.is_available() or not is_sm_at_least_89(), + "CUDA with float8 support not available", + ) def test_float8_graph_output(self): """Test that having Float8Tensor object as a graph output works""" cnts = CompileCounterWithBackend("inductor") @@ -300,7 +324,10 @@ def test_float8_graph_output(self): ) -@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") +@unittest.skipIf( + not torch.cuda.is_available() or not is_sm_at_least_89(), + "CUDA with float8 support not available", +) def test_sync_amax_func(): torch._dynamo.reset() cnts = CompileCounterWithBackend("inductor") @@ -338,7 +365,10 @@ def __exit__(self, *args): sys.stderr = self.sys_stderr -@unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available") +@unittest.skipIf( + not torch.cuda.is_available() or not is_sm_at_least_89(), + "CUDA with float8 support not available", +) def test_sync_amax_func_cuda_graph_success(): torch._dynamo.reset() with capture_stderr() as stderr: @@ -368,9 +398,9 @@ def test_sync_amax_func_cuda_graph_success(): @unittest.skipIf( - not is_cuda_8_9, - "CUDA not available", - ) + not is_sm_at_least_89(), + "CUDA not available", +) @pytest.mark.parametrize( "dtype", [ diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 92143e62b3..e0de749d0b 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -13,27 +13,32 @@ import copy import os -import torch -import torch.nn as nn -import torch.nn.functional as F - import pytest +import torch from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from torchao.float8 import Float8LinearConfig -from torchao.float8.float8_linear_utils import convert_to_float8_training +from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.parallel import parallelize_module +from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, +) +from tqdm import tqdm +from torchao.float8 import Float8LinearConfig from torchao.float8.config import CastConfig, e4m3_dtype, ScalingType +from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, - hp_tensor_and_scale_to_float8, LinearMMConfig, + hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_tensor_parallel import ( Float8ColwiseParallel, @@ -41,15 +46,8 @@ PrepareFloat8ModuleInput, ) from torchao.float8.float8_utils import tensor_to_scale -from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.tensor.parallel import parallelize_module from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torch.testing._internal.distributed._tensor.common_dtensor import ( - ModelArgs, - Transformer, -) -from tqdm import tqdm +from torchao.testing.float8.dtensor_utils import ToyModel def setup_distributed(): @@ -60,28 +58,6 @@ def setup_distributed(): return device_mesh -class FeedForward(nn.Module): - """MLP based model""" - - def __init__(self): - super(FeedForward, self).__init__() - self.w1 = nn.Linear(16, 32, bias=False) - self.w2 = nn.Linear(16, 32, bias=False) - self.out_proj = nn.Linear(32, 16, bias=False) - - def forward(self, x): - return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) - - -class ToyModel(nn.Module): - def __init__(self): - super(ToyModel, self).__init__() - self.ffn = FeedForward() - - def forward(self, x): - return self.ffn(x) - - def _test_scaled_mm(mesh: DeviceMesh, size=16): device = mesh.device_type fp8_dtype = e4m3_dtype @@ -325,9 +301,7 @@ def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh): ) assert ( isinstance(colwise_param, DTensor) - and isinstance( - colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor - ) + and isinstance(colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor) ), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}" # test Float8RowwiseParallel rowwise_param = distribute_tensor( @@ -335,9 +309,7 @@ def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh): ) assert ( isinstance(rowwise_param, DTensor) - and isinstance( - rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor - ) + and isinstance(rowwise_param._local_tensor, WeightWithDynamicFloat8CastTensor) ), f"expect DTensor(local_tensor={WeightWithDynamicFloat8CastTensor}) but got {colwise_param}" diff --git a/test/float8/test_dtensor.sh b/test/float8/test_dtensor.sh index 2e38feffec..585a9014b1 100755 --- a/test/float8/test_dtensor.sh +++ b/test/float8/test_dtensor.sh @@ -8,4 +8,8 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; exit fi +# integration tests for TP/SP NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py + +# integration smoke tests for FSDP2 + TP +NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 2ba33bba08..863256dc35 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -13,10 +13,10 @@ import copy import os -import pytest import warnings import fire +import pytest from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -27,6 +27,14 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from torch.distributed.fsdp import ( + FullStateDictConfig, + StateDictType, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) + from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -34,11 +42,6 @@ sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import compute_error -from torch.distributed.fsdp import ( - FullStateDictConfig, - FullyShardedDataParallel as FSDP, - StateDictType, -) torch.manual_seed(0) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 27336252c3..fbe5c9b508 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -1,11 +1,12 @@ import copy import itertools -import pytest import threading import unittest from typing import Any, List, Optional -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +import pytest + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -15,20 +16,14 @@ import torch._dynamo.testing import torch.distributed as dist import torch.nn as nn -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType -from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard from torch.distributed._tensor import DTensor, init_device_mesh -from torchao.float8.float8_tensor import GemmInputRole from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + MLP, FSDPTest, FSDPTestMultiThread, - MLP, patch_all_gather, ) from torch.testing._internal.common_utils import run_tests @@ -38,10 +33,17 @@ TransformerBlock, ) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic +from torchao.float8.float8_tensor import GemmInputRole +from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor +from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp + +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) + class TestFloat8Common: def broadcast_module(self, module: nn.Module) -> None: # Broadcast for multi-threaded process group tests since seed is per @@ -61,7 +63,9 @@ def init_multi_module(self) -> nn.Module: self.broadcast_module(module) return module - def init_transformer(self, weight_tying: bool, dtype: Optional[torch.dtype] = None) -> nn.Module: + def init_transformer( + self, weight_tying: bool, dtype: Optional[torch.dtype] = None + ) -> nn.Module: torch.manual_seed(42) args = ModelArgs( n_layers=3, @@ -303,16 +307,17 @@ def world_size(self) -> int: def test_amax_allreduce_device_mesh(self): dp_size = 2 pp_size = self.world_size // dp_size - global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp")) - dp_mesh = global_mesh["dp"] - pp_mesh = global_mesh["pp"] + global_mesh = init_device_mesh( + "cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp") + ) + dp_mesh = global_mesh["dp"] if self.rank in [0, 1]: # rank 0 and 1 are the 1st stage in the pipeline # rank 2 and 4 are doing nothing but waiting for the 1st stage torch.manual_seed(42 + self.rank) hp_tensor = torch.randn(768, 32, device="cuda") - float8_tensor = hp_tensor_to_float8_dynamic( + hp_tensor_to_float8_dynamic( hp_tensor, torch.float8_e4m3fn, Float8LinearConfig( @@ -320,9 +325,10 @@ def test_amax_allreduce_device_mesh(self): ), gemm_input_role=GemmInputRole.WEIGHT, reduce_amax=True, - device_mesh=dp_mesh + device_mesh=dp_mesh, ) + class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common): @property def world_size(self) -> int: @@ -459,7 +465,6 @@ def test_fp32_fp8_single_module_parity(self): [ScalingType.DYNAMIC, ScalingType.DELAYED, ScalingType.STATIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: - if scaling_type_weight is ScalingType.STATIC: cast_config_weight = CastConfig( scaling_type=scaling_type_weight, diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py index ed79c65b6c..d2e9a51c7f 100644 --- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py +++ b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py @@ -1,8 +1,9 @@ import copy -import pytest from typing import Optional -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +import pytest + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -11,15 +12,6 @@ import torch._dynamo.testing import torch.distributed as dist import torch.nn as nn -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType -from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic -from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed._tensor import DTensor, init_device_mesh -from torchao.float8.float8_tensor import GemmInputRole -from torch.testing._internal.common_cuda import TEST_CUDA from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest @@ -28,6 +20,7 @@ ModelArgs, Transformer, ) + from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -37,8 +30,7 @@ from torchao.float8.float8_tensor import GemmInputRole from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) @@ -47,7 +39,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: fp8_param = hp_tensor_to_float8_dynamic( self.weight, torch.float8_e4m3fn, - None, # mm_linear_config, + None, # mm_linear_config, reduce_amax=False, gemm_input_role=GemmInputRole.WEIGHT, ) @@ -56,7 +48,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.bias is not None: output = output + self.bias.to(output.dtype) return output - + @classmethod def from_float( cls, @@ -92,7 +84,9 @@ def broadcast_module(self, module: nn.Module) -> None: for param in module.parameters(): dist.broadcast(param, src=0) - def init_transformer(self, weight_tying: bool, dtype: Optional[torch.dtype] = None) -> nn.Module: + def init_transformer( + self, weight_tying: bool, dtype: Optional[torch.dtype] = None + ) -> nn.Module: torch.manual_seed(42) args = ModelArgs( n_layers=3, @@ -114,7 +108,6 @@ class TestFloat8MultiProcess(FSDPTest, TestFloat8Common): def world_size(self) -> int: return min(torch.cuda.device_count(), 2) - @skip_if_lt_x_gpu(2) def test_transformer_parity(self): self.run_subtests( @@ -166,7 +159,7 @@ def _test_transformer_parity( fully_shard(transformer_block) module.layers.register_module(layer_id, transformer_block) fully_shard(module) - + ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) @@ -183,6 +176,5 @@ def _test_transformer_parity( ) - if __name__ == "__main__": run_tests() diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py new file mode 100644 index 0000000000..fa3d30410b --- /dev/null +++ b/test/float8/test_fsdp2_tp.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +Test numerics of manually defined float16 TP vs float8 TP of toy models + +Note: for now, this does not run in CI. +TODO(future): make this run in CI +""" + +import copy +import os + +import pytest +import torch + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.parallel import parallelize_module +from tqdm import tqdm + +from torchao.float8 import Float8LinearConfig +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, +) +from torchao.testing.float8.dtensor_utils import ToyModel + + +def setup_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", -1)) + + # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html + device_mesh = init_device_mesh( + "cuda", + (world_size // 2, 2), + mesh_dim_names=("dp", "tp"), + ) + # seed must be the same in all processes + torch.manual_seed(1) + return device_mesh + + +def _test_fp8_mlp_tensor_parallelism_base( + mesh: DeviceMesh, size=16, compile: bool = False +): + device = mesh.device_type + + config = Float8LinearConfig( + emulate=True, + enable_fsdp_float8_all_gather=True, + ) + + toy_model = ToyModel().to(device) + + tp_model = copy.deepcopy(toy_model) + tp_model = convert_to_float8_training(tp_model, config=config) + + # apply TP + tp_model = parallelize_module( + tp_model, + mesh["tp"], + { + "ffn.w1": Float8ColwiseParallel(), + "ffn.w2": Float8ColwiseParallel(), + "ffn.out_proj": Float8RowwiseParallel(), + }, + ) + + if compile: + tp_model = torch.compile(tp_model) + + # apply FSDP + fsdp_config = {"mesh": mesh["dp"]} + tp_model = fully_shard(tp_model, **fsdp_config) + + x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) + x_fp32_tp_input = x_fp32.clone() + + tp_out = tp_model(x_fp32_tp_input) + tp_out.sum().backward() + torch.cuda.synchronize() + + # TODO(future PR): test numerics, and add more cases + + +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) + + +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) + + +if __name__ == "__main__": + # float8 only works on CUDA H100 so we only test cuda and we follow + # other test files to not use TestCase but instead just add the test + # cases in the main func. + device_mesh = setup_distributed() + + tests = [ + _test_fp8_mlp_tensor_parallelism_eager, + _test_fp8_mlp_tensor_parallelism_compile, + ] + + for test in tqdm(tests, desc="Running tests"): + try: + test(device_mesh) + except Exception as e: + print(f"Test {test.__name__} failed with error: {e}") + raise e + + torch.distributed.destroy_process_group() diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index b481c14e30..1d95801f67 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -12,7 +12,6 @@ import warnings import fire - import pytest from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -24,13 +23,14 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torchao.float8 import Float8LinearConfig from torchao.float8.config import CastConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, sync_float8_amax_and_scale_history, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP torch.manual_seed(0) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index a91b784c85..311964d831 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -19,12 +23,11 @@ import torch import torch.nn as nn import torch.nn.functional as F + from torchao.float8.config import ( - CastConfig, - Float8LinearConfig, - ScalingType, - ScalingGranularity, + Float8LinearConfig, Float8LinearRecipeName, + ScalingType, recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( @@ -32,12 +35,9 @@ linear_requires_sync, sync_float8_amax_and_scale_history, ) -from torchao.float8.float8_utils import compute_error, IS_ROCM +from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - torch.manual_seed(0) @@ -87,7 +87,6 @@ def init_weights(self, init_std: float): class TestFloat8NumericsIntegrationTest: - def _test_impl(self, config: Float8LinearConfig) -> None: data_dtype = torch.bfloat16 # LLaMa 3 70B shapes @@ -167,18 +166,20 @@ def _test_impl(self, config: Float8LinearConfig) -> None: assert sqnr > grad_sqnr_threshold @pytest.mark.parametrize( - "scaling_type_input", + "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @pytest.mark.parametrize( - "scaling_type_weight", + "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) - @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_89(), reason="requires SM89 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_config_params( self, @@ -196,9 +197,14 @@ def test_encoder_fw_bw_from_config_params( @pytest.mark.parametrize( "recipe_name", - [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP], + [ + Float8LinearRecipeName.ALL_AXISWISE, + Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + ], + ) + @pytest.mark.skipif( + not is_sm_at_least_90(), reason="requires SM90 compatible machine" ) - @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_recipe( self, diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 7eda0ab5de..2f231fbb31 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -1,12 +1,7 @@ import unittest import torch -from torchao.dtypes.affine_quantized_tensor import ( - to_affine_quantized_intx, +from torchao.quantization import ( ZeroPointDomain, - PlainAQTTensorImpl, - PlainLayout, - TensorCoreTiledAQTTensorImpl, - TensorCoreTiledLayout, MappingType, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 92d2dcd5c2..10f2d157f9 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -19,7 +19,7 @@ from torchao.quantization.dynamic_quant import ( DynamicallyPerAxisQuantizedLinear, ) -from torchao.dtypes import TensorCoreTiledLayout +from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout from torchao.quantization.quant_api import ( int4_weight_only, int8_weight_only, @@ -91,8 +91,10 @@ TORCH_VERSION_AT_LEAST_2_6, unwrap_tensor_subclass, is_fbcode, - benchmark_model + benchmark_model, + is_sm_at_least_90, ) +from torchao.dtypes.utils import is_device logger = logging.getLogger("INFO") @@ -104,7 +106,6 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -133,7 +134,10 @@ def _int8da_int8w_api(mod): change_linear_weights_to_int8_dqtensors(mod) def _int4wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: + if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False) + unwrap_tensor_subclass(mod) + elif TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int4_weight_only(), set_inductor_config=False) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) @@ -662,6 +666,8 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): + if device == "cpu": + self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 8)] if device=='cuda' else [])): @@ -673,6 +679,8 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): + if device == "cpu": + self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest("Currently only supports bfloat16.") m_shapes = [16, 256] + ([1] if device=="cuda" else []) @@ -771,15 +779,27 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype ) + def test_autoquantizable_flatten_unflatten(self): + from torchao.quantization import DEFAULT_AUTOQUANT_CLASS_LIST + weight = torch.randn(16, 32) + qtensor_class_list = DEFAULT_AUTOQUANT_CLASS_LIST + aqw = AutoQuantizableLinearWeight.from_float(weight, qtensor_class_list) + tensor_data_name_dict, tensor_attributes = aqw.__tensor_flatten__() + tensor_data_dict = {name: getattr(aqw, name) for name in tensor_data_name_dict} + outer_size = aqw.size() + outer_stride = aqw.stride() + reconstructed = type(aqw).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): if dtype != torch.bfloat16: with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"): @@ -793,7 +813,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype @@ -803,6 +823,8 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass(self, device, dtype): + if device == "cpu": + self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 8)] if device=='cuda' else [])): @@ -896,6 +918,8 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api(self, device, dtype): + if device == "cpu": + self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])): @@ -911,12 +935,20 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): + if device == "cpu": + self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") + layout_list = [] + if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6: + layout_list.append(Int4CPULayout()) + else: + for inner_k_tiles in [4, 2]: + layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)) for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])): for groupsize in [64, 32]: - for inner_k_tiles in [4, 2]: - kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)} + for layout in layout_list: + kwargs = {"groupsize": groupsize, "layout": layout} def api(mod): kwargs_copy = kwargs.copy() @@ -1492,6 +1524,23 @@ def forward(self, x): assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight) model(x_in) + @parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES))) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_autoquant_min_sqnr(self, device, dtype): + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + out = model(example_input) + torchao.autoquant(model, min_sqnr=60) + out2 = model(example_input) + sqnr = SQNR(out, out2) + # without setting min_sqnr to 60, we get around 45-50 final sqnr + # setting min_sqnr for individual linear to be 60 allows us to achieve >= 50 final sqnr + self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 4ed0974172..3e8c9b0a04 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -13,10 +13,10 @@ import pytest import torch from parameterized import parameterized +from torchao.utils import is_sm_at_least_90 logging.basicConfig(level=logging.INFO) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) class TestQuantFlow(unittest.TestCase): @@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype): ("cuda", torch.float16), ] ) - @unittest.skipIf(not is_H100, "Needs H100") + @unittest.skipIf(not is_sm_at_least_90(), "Needs H100") def test_int_mm_float8(self, device, dtype): from torchao.kernel import intmm diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bc9b02deb5..4cac940313 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -20,11 +20,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -102,7 +99,7 @@ def test_linear_compile(elem_dtype, bias): Verify that compile does not change numerics of MX linear fw + bw """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") input_shape = (2, 4) grad_shape = (2, 6) @@ -173,7 +170,7 @@ def test_inference_compile_simple(elem_dtype): Smoke test for inference compile """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 964a575411..522785ae6f 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -24,11 +24,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -225,7 +222,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): Verifies that compile does not change numerics of MX casts """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): # separate ifs because flake8 is outsmarting me pytest.skip("CUDA capability >= 8.9 required for float8 in triton") diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index eccf8db8f6..3663e027c7 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -40,6 +40,7 @@ def run_before_and_after_tests(): @pytest.mark.parametrize("qdtype", qdtypes) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5,reason="requires nightly pytorch") +@pytest.mark.skip("Temporarily skipping to unpin nightiles") def test_awq_loading(device, qdtype): if qdtype == torch.uint4 and device == "cpu": pytest.skip("uint4 not supported on cpu") @@ -126,4 +127,4 @@ def test_save_weights_only(): assert awq_out is not None assert awq_save_load_out is not None - assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) \ No newline at end of file + assert torch.allclose(awq_out, awq_save_load_out, atol = 1e-2) diff --git a/test/profiler/test_device_spec.py b/test/prototype/test_device_spec.py similarity index 97% rename from test/profiler/test_device_spec.py rename to test/prototype/test_device_spec.py index 1ede428fe0..dd159f5336 100644 --- a/test/profiler/test_device_spec.py +++ b/test/prototype/test_device_spec.py @@ -8,7 +8,7 @@ import torch from utils import patch_device -from torchao.profiler.device_spec import ( +from torchao.prototype.profiler.device_spec import ( _AVAILABLE_GPU_SPECS, CUDADeviceSpec, get_chip_name, diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index a97d1cffdd..4ba13db1fb 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -7,6 +7,7 @@ import torch from packaging.version import Version from torch import nn +from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( @@ -26,6 +27,7 @@ from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8 from torchao.utils import ( + get_available_devices, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, @@ -42,7 +44,7 @@ lpmm = None -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_DEVICES = get_available_devices() class TestQuantize(TestCase): @@ -94,7 +96,9 @@ def test_bf16_stochastic_round(self, device, compile): x = torch.rand(32, device=device) * 100 x_rep = x.view(-1, 1).repeat(1, 100_000) - func = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile) + func = torch.compile( + _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile + ) x_rep_bf16 = func(x_rep) assert x_rep_bf16.dtype is torch.bfloat16 @@ -169,8 +173,13 @@ def test_subclass_slice(self, subclass, shape, device): tensor = subclass.zeros(shape, device=device) offset = shape[0] // 2 - torch.testing.assert_close(tensor.dequantize()[:offset], tensor[:offset].dequantize()) - torch.testing.assert_close(tensor.dequantize()[offset:offset*2], tensor[offset:offset*2].dequantize()) + torch.testing.assert_close( + tensor.dequantize()[:offset], tensor[:offset].dequantize() + ) + torch.testing.assert_close( + tensor.dequantize()[offset : offset * 2], + tensor[offset : offset * 2].dequantize(), + ) @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available") @pytest.mark.skipif( @@ -188,7 +197,9 @@ def test_optim_8bit_correctness(self, optim_name): block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048 optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size) + optim2 = getattr(low_bit_optim, optim_name)( + model2.parameters(), block_size=block_size + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -244,11 +255,12 @@ def test_optim_4bit_correctness(self, optim_name): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) @pytest.mark.skipif( - not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + not torch.cuda.is_available() and not torch.xpu.is_available(), + reason="optim CPU offload requires CUDA or XPU", ) @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)]) def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): - device = "cuda" + device = _DEVICES[-1] model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) model1.to(device) @@ -279,13 +291,16 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): torch.testing.assert_close(p2, p1) @pytest.mark.skipif( - not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + not torch.cuda.is_available() and not torch.xpu.is_available(), + reason="optim CPU offload requires CUDA or XPU", ) def test_optim_cpu_offload_save_load(self): - device = "cuda" + device = _DEVICES[-1] model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) model1.to(device) - optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW) + optim1 = low_bit_optim.CPUOffloadOptimizer( + model1.parameters(), torch.optim.AdamW + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -300,7 +315,9 @@ def test_optim_cpu_offload_save_load(self): # resume training model2 = copy.deepcopy(model1) - optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW) + optim2 = low_bit_optim.CPUOffloadOptimizer( + model2.parameters(), torch.optim.AdamW + ) optim2.load_state_dict(state_dict) for _ in range(2): @@ -365,9 +382,9 @@ def world_size(self) -> int: return _FSDP_WORLD_SIZE @pytest.mark.skipif( - not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required." + not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." ) - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) def test_fsdp2(self): optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit] if torch.cuda.get_device_capability() >= (8, 9): @@ -382,9 +399,12 @@ def _test_fsdp2(self, optim_cls): import torch.distributed as dist import torch.distributed.checkpoint as dcp import torch.utils._pytree as pytree - from torch.distributed._composable.fsdp import fully_shard from torch.distributed.tensor import DTensor - from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer, TransformerBlock + from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, + TransformerBlock, + ) batch_size = 3 vocab_size = 1024 @@ -392,7 +412,7 @@ def _test_fsdp2(self, optim_cls): model_args = ModelArgs( n_layers=3, n_heads=4, - dim=1024, + dim=512, vocab_size=vocab_size, max_seq_len=seq_len, dropout_p=0, @@ -457,7 +477,10 @@ def _test_fsdp2(self, optim_cls): subclasses = (OptimState4bit, OptimState8bit, OptimStateFp8) - for v1, v2 in zip(pytree.tree_iter(resumed_fsdp_optim.state_dict()), pytree.tree_iter(fsdp_optim.state_dict())): + for v1, v2 in zip( + pytree.tree_iter(resumed_fsdp_optim.state_dict()), + pytree.tree_iter(fsdp_optim.state_dict()), + ): assert v1.__class__ == v2.__class__, (v1.__class__, v2.__class__) if isinstance(v1, DTensor): v1 = v1.to_local() @@ -468,6 +491,29 @@ def _test_fsdp2(self, optim_cls): v2 = v2.dequantize() self.assertEqual(v1, v2) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." + ) + @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + def test_uneven_shard(self): + in_dim = 512 + out_dim = _FSDP_WORLD_SIZE * 16 + 1 + + # 1st dim of linear weight will not be divisible by WORLD_SIZE + model = nn.Linear(in_dim, out_dim, device="cuda") + assert model.weight.shape[0] % _FSDP_WORLD_SIZE != 0 + fully_shard(model) + + # currently all of our low-bit Adam/AdamW share the same implementation. + # thus, we only need to test for 1 optimizer class. + optim = low_bit_optim.AdamW8bit(model.parameters()) + + for _ in range(2): + inputs = torch.randn(2, in_dim, device="cuda") + model(inputs).sum().backward() + optim.step() + optim.zero_grad() + instantiate_parametrized_tests(TestQuantize) instantiate_parametrized_tests(TestOptim) diff --git a/test/profiler/test_performance_counter.py b/test/prototype/test_performance_counter.py similarity index 99% rename from test/profiler/test_performance_counter.py rename to test/prototype/test_performance_counter.py index 2cd1a33581..6ece2c6398 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/prototype/test_performance_counter.py @@ -30,8 +30,8 @@ qkv_proj_io_check, ) -from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec -from torchao.profiler.performance_counter import ( +from torchao.prototype.profiler.device_spec import CUDADeviceSpec, DeviceSpec +from torchao.prototype.profiler.performance_counter import ( CUDAPerformanceTimer, PerformanceCounterMode, PerformanceStats, diff --git a/test/prototype/test_sparse_api.py b/test/prototype/test_sparse_api.py index baf224e169..f3cdbe8386 100644 --- a/test/prototype/test_sparse_api.py +++ b/test/prototype/test_sparse_api.py @@ -31,6 +31,7 @@ class TestSemiStructuredSparse(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skip("Temporarily skipping to unpin nightlies") def test_sparse(self): input = torch.rand((128, 128)).half().cuda() model = ( @@ -49,6 +50,9 @@ def test_sparse(self): sparsify_(model, semi_sparse_weight()) sparse_result = model(input) + if compile: + model = torch.compile(model) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) @@ -56,11 +60,14 @@ class TestQuantSemiSparse(common_utils.TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize("compile", [False]) def test_quant_semi_sparse(self, compile): if not torch.backends.cusparselt.is_available(): self.skipTest("Need cuSPARSELt") + # compile True failed with CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit(... + # https://github.com/pytorch/ao/actions/runs/11978863581/job/33402892517?pr=1330 + torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False input = torch.rand((128, 128)).half().cuda() @@ -194,7 +201,7 @@ def test_sparse(self, compile): quantize_(model_copy, int8_dynamic_activation_int8_weight()) reference = model_copy(input) - from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout + from torchao.dtypes import BlockSparseLayout quantize_( model, diff --git a/test/profiler/utils.py b/test/prototype/utils.py similarity index 98% rename from test/profiler/utils.py rename to test/prototype/utils.py index 7b2b999809..8c402b8114 100644 --- a/test/profiler/utils.py +++ b/test/prototype/utils.py @@ -5,7 +5,7 @@ import torch -from torchao.profiler import PerformanceTimer +from torchao.prototype.profiler import PerformanceTimer @contextmanager diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 1eabf479ce..3eb9b0a2c5 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -3,13 +3,16 @@ import pytest # Skip entire test if triton is not available, otherwise CI failure -try: - import triton -except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) - -import bitsandbytes.functional as F +try: # noqa: F401 + import triton # noqa: F401 +except ImportError: # noqa: F401 + pytest.skip("triton is not installed", allow_module_level=True) # noqa: F401 import torch +from bitsandbytes.functional import ( + create_dynamic_map, + dequantize_blockwise, + quantize_blockwise, +) from torchao.prototype.galore.kernels import ( triton_dequant_blockwise, @@ -36,9 +39,9 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 - qmap = F.create_dynamic_map(signed).to(g.device) + qmap = create_dynamic_map(signed).to(g.device) - ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) + ref_bnb, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize) bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape) tt_q, tt_norm, tt_absmax = triton_quantize_blockwise( @@ -82,10 +85,10 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 - qmap = F.create_dynamic_map(signed).to(g.device) + qmap = create_dynamic_map(signed).to(g.device) - q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) + q, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize) - dq_ref = F.dequantize_blockwise(q, qstate) + dq_ref = dequantize_blockwise(q, qstate) dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize) assert torch.allclose(dq, dq_ref) diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index c020b958f1..ebdf2281e0 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -1,4 +1,5 @@ import copy +import unittest import pytest import torch @@ -18,10 +19,14 @@ MappingType, choose_qparams_and_quantize_affine_qqq, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode -class MarlinQQQ(TestCase): +@unittest.skipIf( + is_fbcode(), + "Skipping the test in fbcode since we don't have TARGET file for kernels", +) +class TestMarlinQQQ(TestCase): def setUp(self): super().setUp() torch.manual_seed(0) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 29f833c9ab..3a998635aa 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -13,9 +13,8 @@ import torch import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 -from torchao.dtypes import ( - TensorCoreTiledLayout, -) + +from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4 from torchao.quantization.granularity import ( PerAxis, PerGroup, @@ -26,33 +25,26 @@ ComposableQATQuantizer, FakeQuantizeConfig, ) -from torchao.quantization.qat.fake_quantizer import ( - FakeQuantizer, -) from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) from torchao.quantization.qat.linear import ( FakeQuantizedLinear, + Int4WeightOnlyQATLinear, Int8DynActInt4WeightQATLinear, - Int4WeightOnlyQATLinear ) from torchao.quantization.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, - _get_qmin_qmax, _GenericFakeQuantize, -) -from torchao.quantization.quant_api import ( - int4_weight_only, - quantize_, + _get_qmin_qmax, ) from torchao.quantization.quant_primitives import ( - fake_quantize_affine, MappingType, TorchAODType, ZeroPointDomain, + fake_quantize_affine, ) from torchao.quantization.unified import ( TwoStepQuantizer, @@ -65,17 +57,12 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, -) - -from torchao.quantization.GPTQ import ( - _replace_linear_8da4w, - _replace_linear_int4 ) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() + class Sub(torch.nn.Module): def __init__(self): super().__init__() @@ -87,6 +74,7 @@ def example_inputs(self): def forward(self, x): return self.linear(x) + class M(torch.nn.Module): def __init__(self): super().__init__() @@ -103,6 +91,7 @@ def forward(self, x): x = self.linear2(x) return x + class M2(torch.nn.Module): def __init__(self): super().__init__() @@ -118,7 +107,9 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -132,20 +123,40 @@ def test_fake_quantize_per_channel_group(self): # fake quant op out = _fake_quantize_per_channel_group( - x, s, zp, qmin, qmax, group_size, + x, + s, + zp, + qmin, + qmax, + group_size, ) out.sum().backward() # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group( - x2, s, zp, qmin, qmax, torch.int8, group_size, + x2, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -161,10 +172,21 @@ def test_fake_quantize_per_token(self): # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_token( - x2, s, zp, qmin, qmax, torch.int8, + x2, + s, + zp, + qmin, + qmax, + torch.int8, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_token( - out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) @@ -182,9 +204,10 @@ def _set_ptq_weight( WeightOnlyInt4Linear, ) from torchao.quantization.qat.linear import ( - Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, + Int8DynActInt4WeightQATLinear, ) + n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = qat_linear.weight_fake_quantizer.config.group_size @@ -193,7 +216,13 @@ def _set_ptq_weight( fp32_weight = qat_linear.weight (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + fp32_weight, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) ptq_linear.weight = q_weight ptq_linear.scales = s @@ -201,28 +230,39 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, group_size, + qat_linear.weight, + n_bit, + group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to("cuda"), qat_linear.inner_k_tiles, + q_weight.to("cuda"), + qat_linear.inner_k_tiles, ) ptq_linear.weight = q_weight ptq_linear.scales_and_zeros = scales_and_zeros else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_linear(self): - from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear group_size = 128 torch.manual_seed(self.SEED) qat_linear = Int8DynActInt4WeightQATLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) ptq_linear = Int8DynActInt4WeightLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) # Force the weights to be the same @@ -236,10 +276,12 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer(self): - from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) @@ -268,9 +310,13 @@ def test_qat_8da4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -282,7 +328,9 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -341,7 +389,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. @@ -363,8 +413,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) # Simulate training for both models - optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) - optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + optimizer1 = torch.optim.SGD( + nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) + optimizer2 = torch.optim.SGD( + qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) loss_fn1 = torch.nn.CrossEntropyLoss() loss_fn2 = torch.nn.CrossEntropyLoss() example_inputs = nn_model.example_inputs() @@ -382,9 +436,15 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) + torch.testing.assert_close( + nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0 + ) + torch.testing.assert_close( + nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0 + ) + torch.testing.assert_close( + nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0 + ) def _test_qat_quantized_gradients(self, quantizer): """ @@ -394,7 +454,9 @@ def _test_qat_quantized_gradients(self, quantizer): torch.manual_seed(self.SEED) m = M() model = quantizer.prepare(m) - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + optimizer = torch.optim.SGD( + model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) loss_fn = torch.nn.CrossEntropyLoss() # Simulate training @@ -426,13 +488,18 @@ def _test_qat_quantized_gradients(self, quantizer): optimizer.step() current_step += 1 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_gradients(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_generic_fake_quantize(self): """ Test that the generic fake quantize used in 8da4w QAT matches @@ -443,7 +510,9 @@ def test_qat_generic_fake_quantize(self): py_input = torch.randn(16, 64).float().requires_grad_() py_s = torch.randn(16).float() py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) - py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax) + py_out = torch.fake_quantize_per_channel_affine( + py_input, py_s, py_zp, 0, qmin, qmax + ) py_out.sum().backward() ao_input = copy.deepcopy(py_input) @@ -451,7 +520,9 @@ def test_qat_generic_fake_quantize(self): block_size = (1, ao_input.shape[-1]) ao_s = copy.deepcopy(py_s) ao_zp = copy.deepcopy(py_zp) - ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax) + ao_out = _GenericFakeQuantize.apply( + ao_input, block_size, ao_s, ao_zp, qmin, qmax + ) ao_out.sum().backward() torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) @@ -485,10 +556,14 @@ def test_qat_4w_primitives(self): # PTQ (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(device), inner_k_tiles, + q_weight.to(device), + inner_k_tiles, ) ptq_out = torch.ops.aten._weight_int4pack_mm( x, q_weight, group_size, scales_and_zeros @@ -497,9 +572,12 @@ def test_qat_4w_primitives(self): # QAT block_size = (1, group_size) quant_min = 0 - quant_max = 2 ** n_bit - 1 + quant_max = 2**n_bit - 1 scales, zero_points = get_groupwise_affine_qparams( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) w_fq = fake_quantize_affine( weight, @@ -509,27 +587,37 @@ def test_qat_4w_primitives(self): torch.int32, quant_min, quant_max, - zero_point_domain = ZeroPointDomain.FLOAT, + zero_point_domain=ZeroPointDomain.FLOAT, ) qat_out = torch.nn.functional.linear(x, w_fq) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear + from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 device = torch.device("cuda") dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) ptq_linear = WeightOnlyInt4Linear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) # Force the weights to be the same @@ -543,17 +631,22 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer + quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): - from torchao.quantization.qat import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -563,10 +656,12 @@ def test_qat_4w_quantizer(self): m = M().to(device).to(dtype) m2 = copy.deepcopy(m) qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) ptq_quantizer = Int4WeightOnlyQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) qat_model = qat_quantizer.prepare(m) ptq_model = ptq_quantizer.quantize(m2) @@ -589,13 +684,16 @@ def test_qat_4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) class _MyQATQuantizer(TwoStepQuantizer): """ Dummy quantizer that attaches a certain value to each nn.Linear's `_temp_quantizer_values` attribute. """ + ATTR_NAME = "_temp_quantizer_values" def __init__(self, value: str): @@ -626,19 +724,24 @@ def test_composable_qat_quantizer(self): self.assertEqual(values_list, ["quantizer1", "quantizer2"]) composable_quantizer.convert(model) values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) - self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"]) + self.assertEqual( + values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_embedding(self): from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer + model = M2() x = model.example_inputs() - out = model(*x) + model(*x) quantizer = Int4WeightOnlyEmbeddingQATQuantizer() prepared = quantizer.prepare(model) - prepared_out = prepared(*x) + prepared(*x) converted = quantizer.convert(model) - converted_out = converted(*x) + converted(*x) def test_fake_quantize_config_granularity(self): """ @@ -685,7 +788,9 @@ def test_fake_quantize_config_granularity_error_cases(self): Test incorrect settings of `FakeQuantizeConfig`'s granularity. """ # no granularity provided - with self.assertRaisesRegex(ValueError, "`granularity` or `group_size` must be set"): + with self.assertRaisesRegex( + ValueError, "`granularity` or `group_size` must be set" + ): FakeQuantizeConfig(torch.int8) # group_size with conflicting granularity @@ -718,8 +823,12 @@ def test_fake_quantize_config_mapping_type(self): """ # symmetric symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") - symmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=True) - symmetric_config3 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC) + symmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=True + ) + symmetric_config3 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC + ) self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config2.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config3.mapping_type, MappingType.SYMMETRIC) @@ -728,8 +837,12 @@ def test_fake_quantize_config_mapping_type(self): self.assertTrue(symmetric_config3.is_symmetric) # asymmetric - asymmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - asymmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.ASYMMETRIC) + asymmetric_config1 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ) + asymmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.ASYMMETRIC + ) self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) self.assertEqual(asymmetric_config2.mapping_type, MappingType.ASYMMETRIC) self.assertFalse(asymmetric_config1.is_symmetric) @@ -743,11 +856,15 @@ def test_fake_quantize_config_mapping_type(self): # bad config1: both mapping_type and is_symmetric are set msg = "Cannot set both `mapping_type` and `is_symmetric`" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False + ) # bad config2: not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR + ) def test_fake_quantize_config_dtype(self): """ @@ -781,7 +898,9 @@ def test_fake_quantize_config_dtype(self): FakeQuantizeConfig(TorchAODType.INT7, "per_token") FakeQuantizeConfig(torch.int8, "per_token") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. @@ -792,7 +911,9 @@ def test_fake_quantized_linear_8da4w(self): 256, 688, bias=False, - activation_config=FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False), + activation_config=FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ), weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), ) @@ -801,7 +922,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. """ # activations - (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (s, zp) = _choose_qparams_per_token_asymmetric( + x, torch.float32, torch.int32 + ) (qmin, qmax) = _get_qmin_qmax(8) x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) @@ -809,7 +932,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.linear(x_fq, w_fq) # Compare linear values @@ -820,7 +945,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. @@ -849,7 +976,13 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) w_fq = _fake_quantize_per_channel_group( - weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + weight, + s, + zp, + qmin, + qmax, + group_size, + zero_point_domain=ZeroPointDomain.FLOAT, ) return F.linear(x, w_fq) @@ -860,50 +993,78 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: fq_out = fq_linear(x) baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_8da4w(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(isinstance(module[0], Int8DynActInt4WeightQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int8DynActInt4WeightQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_int4(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(isinstance(module[0], Int4WeightOnlyQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int4WeightOnlyQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -926,7 +1087,9 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.embedding(x, w_fq) # Compare embedding values @@ -937,59 +1100,15 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. We will remove this test in the near future when we actually break BC. """ - from torchao.quantization.prototype.qat import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - ComposableQATQuantizer, - Int8DynActInt4WeightQATLinear, - Int4WeightOnlyEmbeddingQATQuantizer, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, - ) - from torchao.quantization.prototype.qat._module_swap_api import ( - disable_4w_fake_quant_module_swap, - enable_4w_fake_quant_module_swap, - disable_8da4w_fake_quant_module_swap, - enable_8da4w_fake_quant_module_swap, - Int4WeightOnlyQATQuantizerModuleSwap, - Int8DynActInt4WeightQATQuantizerModuleSwap, - ) - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - to_affine_fake_quantized, - ) - from torchao.quantization.prototype.qat.api import ( - ComposableQATQuantizer, - FakeQuantizeConfig, - ) - from torchao.quantization.prototype.qat.embedding import ( - FakeQuantizedEmbedding, - Int4WeightOnlyEmbeddingQATQuantizer, - Int4WeightOnlyEmbedding, - Int4WeightOnlyQATEmbedding, - ) - from torchao.quantization.prototype.qat.fake_quantizer import ( - FakeQuantizer, - ) - from torchao.quantization.prototype.qat.linear import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - FakeQuantizedLinear, - Int4WeightOnlyQATLinear, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATLinear, - Int8DynActInt4WeightQATQuantizer, - ) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 458cd07810..eb5f1337d1 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -6,81 +6,86 @@ # mypy: ignore-errors # This test takes a long time to run +import copy +import gc +import tempfile import unittest +from pathlib import Path + import torch -import os from torch.ao.quantization.quantize_pt2e import ( - prepare_pt2e, convert_pt2e, + prepare_pt2e, ) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase -import torchao +from torchao import quantize_ +from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import ( AffineQuantizedTensor, ) from torchao.quantization import ( LinearActivationQuantizedTensor, ) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) -from torchao.quantization.subclass import ( - Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight, -) -from torchao import quantize_ from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, Quantizer, TwoStepQuantizer, - int8_dynamic_activation_int4_weight, + _replace_with_custom_fn_if_matches_filter, int4_weight_only, - int8_weight_only, + int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, + int8_weight_only, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) +from torchao.quantization.subclass import ( + Int4WeightOnlyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + unwrap_tensor_subclass, ) -from pathlib import Path -from torchao._models.llama.tokenizer import get_tokenizer -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao.utils import unwrap_tensor_subclass -import copy -import tempfile -import gc -from torch.testing._internal.common_utils import TestCase -from torch.testing._internal import common_utils def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs).module() - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) return m + def capture_and_prepare(model, example_inputs): m = torch.export.export(model, example_inputs) - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this m(*example_inputs) return m -class XNNPackDynamicQuantizer(TwoStepQuantizer): +class XNNPackDynamicQuantizer(TwoStepQuantizer): def prepare(self, model: torch.nn.Module) -> torch.nn.Module: _replace_with_custom_fn_if_matches_filter( model, - lambda linear_mod: capture_and_prepare(linear_mod, (torch.randn(1, linear_mod.in_features))), + lambda linear_mod: capture_and_prepare( + linear_mod, (torch.randn(1, linear_mod.in_features)) + ), lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) return model @@ -93,11 +98,13 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: ) return model + class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: quantize_(model, int8_dynamic_activation_int8_weight()) return model + class ToyLinearModel(torch.nn.Module): def __init__(self, m=64, n=32, k=64): super().__init__() @@ -105,7 +112,11 @@ def __init__(self, m=64, n=32, k=64): self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float) def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) def forward(self, x): x = self.linear1(x) @@ -118,9 +129,11 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs The deprecated implementation for int8 dynamic quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _in_features_greater_than_16 - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import ( + _get_subclass_inserter, + _in_features_greater_than_16, + _is_linear, + ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight if filter_fn is None: @@ -129,37 +142,49 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs ) _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + model, + _get_subclass_inserter( + Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs + ), + filter_fn, ) + def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ The deprecated implementation for weight only quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs), + _get_subclass_inserter( + deprecated_tenosr_subclass, enable_parametrization=True, **kwargs + ), filter_fn, ) return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) + +_ref_change_linear_weights_to_int8_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) +) +_ref_change_linear_weights_to_int4_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) +) + class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() quantize_(m, int8_dynamic_activation_int8_weight()) - quantized = m(*example_inputs) + m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) # m = torch.compile(m, mode="max-autotune") @@ -182,7 +207,9 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!") + @unittest.skip( + "FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!" + ) def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): quantizer = TorchCompileDynamicQuantizer() m = ToyLinearModel().eval() @@ -196,10 +223,8 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): - from torchao.quantization.quant_api import ( - change_linear_weights_to_int8_woqtensors, - ) m = ToyLinearModel().eval().cpu() + def api(model): quantize_(model, int8_weight_only()) unwrap_tensor_subclass(model) @@ -223,10 +248,12 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" + ) def test_8da4w_quantizer(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) m = ToyLinearModel().eval() @@ -242,8 +269,9 @@ def test_8da4w_quantizer(self): # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_8da4w_gptq_quantizer(self): - from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer from torchao._models._eval import InputRecorder, TransformerEvalWrapper + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer + # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -268,16 +296,20 @@ def test_8da4w_gptq_quantizer(self): input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int8DynActInt4WeightGPTQQuantizer( blocksize, @@ -287,7 +319,7 @@ def test_8da4w_gptq_quantizer(self): ) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -298,15 +330,17 @@ def test_8da4w_gptq_quantizer(self): 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.88, ( - f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.88 + ), f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower" + ) def test_8da4w_quantizer_eval(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer precision = torch.bfloat16 device = "cpu" @@ -325,7 +359,7 @@ def test_8da4w_quantizer_eval(self): quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision) q_model = quantizer.quantize(model) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( q_model, tokenizer, q_model.config.block_size, @@ -335,14 +369,18 @@ def test_8da4w_quantizer_eval(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_int4_weight_only(self): + from torchao._models._eval import ( + MultiTensorInputRecorder, + TransformerEvalWrapper, + ) from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer - from torchao._models._eval import MultiTensorInputRecorder, TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -367,18 +405,21 @@ def test_gptq_quantizer_int4_weight_only(self): calibration_seq_length = 100 input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = MultiTensorInputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - device="cpu", - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() - + inputs = ( + MultiTensorInputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + device="cpu", + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int4WeightOnlyGPTQQuantizer( blocksize, @@ -398,14 +439,15 @@ def test_gptq_quantizer_int4_weight_only(self): ["wikitext"], None, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_quantizer_int4_weight_only(self): - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -435,13 +477,14 @@ def test_quantizer_int4_weight_only(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -456,7 +499,7 @@ def test_eval_wrapper(self): tokenizer_path, "Llama-2-7b-chat-hf", ) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -466,17 +509,20 @@ def test_eval_wrapper(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none']<7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper_llama3(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" - checkpoint_path = Path(".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth") + checkpoint_path = Path( + ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth" + ) model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) @@ -498,30 +544,43 @@ def test_eval_wrapper_llama3(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" # TODO: move to a separate test file @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @common_utils.parametrize("mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]) + @common_utils.parametrize( + "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] + ) def test_quantized_tensor_subclass_8da4w(self, mapping_type): group_size = 32 m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size, mapping_type=mapping_type)) + quantize_( + m, + int8_dynamic_activation_int4_weight( + group_size=group_size, mapping_type=mapping_type + ), + ) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer - quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=mapping_type) + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=group_size, mapping_type=mapping_type + ) m_copy = quantizer.quantize(m_copy) assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear) assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear) @@ -552,7 +611,6 @@ def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): @@ -568,13 +626,11 @@ def test_quantized_tensor_subclass_int8_wo(self): # reference _ref_change_linear_weights_to_int8_woqtensors(m_copy) - res = m(*example_inputs) ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") @@ -583,13 +639,19 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") + example_inputs = m.example_inputs( + batch_size=20, dtype=torch.bfloat16, device="cuda" + ) quantize_(m, int8_dynamic_activation_int8_weight()) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference _ref_change_linear_weights_to_int8_dqtensors(m_copy) @@ -601,6 +663,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): # workaround for export path from torchao.utils import unwrap_tensor_subclass + m_unwrapped = unwrap_tensor_subclass(m) m = torch.export.export(m_unwrapped, example_inputs).module() @@ -630,12 +693,10 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") quantize_(m, int8_weight_only()) @@ -654,7 +715,6 @@ def test_int4wo_quantized_model_to_device(self): devices = ["cuda", "cuda:0"] for device in devices: m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) quantize_(m, int4_weight_only()) @@ -678,7 +738,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self): f.seek(0) state_dict = torch.load(f.name, map_location="cpu", mmap=True) - with torch.device('meta'): + with torch.device("meta"): m_copy = ToyLinearModel().eval() m_copy.load_state_dict(state_dict, assign=True) @@ -710,12 +770,13 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) -class TestMultiTensorFlow(TestCase): +class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.randn(3, 3) tensor2 = torch.randn(3, 3) mt = MultiTensor(tensor1) @@ -728,6 +789,7 @@ def test_multitensor_add_tensors(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.randn(3, 3) mt = MultiTensor(tensor1) mt.pad_to_length(3) @@ -739,14 +801,13 @@ def test_multitensor_pad_unpad(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.ones(3, 3) mt = MultiTensor(tensor1) mt += 1 # In-place addition self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2))) - - common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 4e0663eb87..a3fef29fea 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -7,25 +7,27 @@ # mypy: ignore-errors # This test takes a long time to run import unittest + import torch + +from torchao.dtypes.utils import is_device from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + choose_qparams_affine, + dequantize_affine, fake_quantize_affine, fake_quantize_affine_cachemask, quantize_affine, - dequantize_affine, - choose_qparams_affine, - MappingType, - ZeroPointDomain, ) + # TODO: remove test for utils? from torchao.quantization.utils import ( get_group_qparams_symmetric, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor_from_qparams, groupwise_affine_dequantize_tensor_from_qparams, + groupwise_affine_quantize_tensor_from_qparams, quantize_activation_per_token_absmax, ) - from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -37,6 +39,7 @@ _SEED = 1234 torch.manual_seed(_SEED) + # Helper function to run a function twice # and verify that the result is the same. # Adds some verification to avoid side effects. @@ -47,9 +50,12 @@ def check_idempotent(self, fn, *args, **kwargs): output0 = fn(*args, **kwargs) assert torch.is_tensor(output0) output1 = fn(*args, **kwargs) - self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.") + self.assertTrue( + torch.equal(output0, output1), f"Expected given function {fn} to be idempotent." + ) return output1 + # Legacy tinygemm ops def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): if groupsize > w.shape[-1]: @@ -70,6 +76,7 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1 dtype=dtype ).reshape(w.shape[0], -1) + def _groupwise_affine_quantize_tensor_from_qparams( w, scales, @@ -102,10 +109,12 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) if TORCH_VERSION_AT_LEAST_2_5: - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) + if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) return w_int4x8 + def _groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, scales, @@ -136,7 +145,9 @@ def _groupwise_affine_dequantize_tensor_from_qparams( class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as @@ -145,7 +156,6 @@ def test_get_group_qparams_symmetric(self): n_bit = 4 qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1 - eps = torch.finfo(torch.float32).eps groupsize = 256 torch.manual_seed(self.SEED) weight = torch.randn(100, 256).to(torch.float16) @@ -158,14 +168,16 @@ def test_get_group_qparams_symmetric(self): quant_max=qmax, # This is needed to ensure `min_val` and `max_val` are fp16, # otherwise they default to fp32 and the qparams will be slightly off - factory_kwargs={"dtype": torch.float16} + factory_kwargs={"dtype": torch.float16}, ) obs(weight) (scale_obs, _) = obs.calculate_qparams() scale_obs = scale_obs.reshape(weight.shape[0], -1) # assert that scales are identical - (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16) + (scale_ao, _) = get_group_qparams_symmetric( + weight, n_bit, groupsize, precision=torch.float16 + ) torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0) def test_choose_qparams_group_sym(self): @@ -178,9 +190,19 @@ def test_choose_qparams_group_sym(self): block_size = (1, 2) eps = torch.finfo(torch.float32).eps precision = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=precision, + zero_point_dtype=precision, + ) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type) + scale_ref, zp_ref = get_group_qparams_symmetric( + input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) @@ -195,13 +217,26 @@ def test_choose_qparams_group_sym_no_clipping_err(self): block_size = (1, 2) eps = torch.finfo(torch.float32).eps precision = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=precision, + zero_point_dtype=precision, + ) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type) + scale_ref, zp_ref = get_group_qparams_symmetric( + input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) @@ -209,11 +244,29 @@ def test_choose_qparams_token_asym(self): dtype = torch.int8 block_size = (1, 10) if TORCH_VERSION_AT_LEAST_2_6: - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float64, + zero_point_dtype=torch.int64, + ) else: - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + ) - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype) + scale_ref, zp_ref = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + input, dtype + ) + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() @@ -227,12 +280,15 @@ def test_choose_qparams_tensor_asym(self): dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) - + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=eps + ) quant_min = -128 quant_max = 127 - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype) + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams( + input, quant_min, quant_max, eps, dtype + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() @@ -246,18 +302,24 @@ def test_choose_qparams_tensor_sym(self): dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=eps + ) quant_min = -128 quant_max = 127 - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype) + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric( + input, quant_min, quant_max, eps, dtype + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) @@ -270,21 +332,35 @@ def test_quantize_activation_per_token_abs_max(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(scale, scale_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max_zero_input(self): input = torch.zeros(10, 10) # make sure it still works quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max_dtype(self): input = torch.zeros(10, 10, dtype=torch.bfloat16) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) @@ -298,18 +374,30 @@ def test_quantize_activation_per_token_abs_max_dtype(self): quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC dtype = torch.int8 block_size = (1, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) group_size = 2 quant_min = -128 @@ -318,23 +406,43 @@ def test_quantize_dequantize_group_sym(self): input, scale, zero_point, quant_min, quant_max, torch.int8, group_size ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32 + quantized_ref, + scale, + zero_point, + quant_min, + quant_max, + torch.int8, + group_size, + output_dtype=torch.float32, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 1) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) output_dtype = torch.float32 quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=output_dtype, + ) axis = 1 quant_min = -128 @@ -343,12 +451,21 @@ def test_quantize_dequantize_channel_asym(self): input, scale, zero_point, axis, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( - quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype + quantized_ref, + scale, + zero_point, + axis, + quant_min, + quant_max, + torch.int8, + out_dtype=output_dtype, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) @@ -356,32 +473,61 @@ def test_quantize_dequantize_tensor_asym(self): dtype = torch.int8 block_size = (10, 10) output_dtype = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=output_dtype, + ) - axis = 1 quant_min = -128 quant_max = 127 quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor( input, scale, zero_point, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype + quantized_ref, + scale, + zero_point, + quant_min, + quant_max, + torch.int8, + out_dtype=output_dtype, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (3, 3, 1, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) axis = 2 quant_min = -128 @@ -390,20 +536,40 @@ def test_quantize_dequantize_channel_asym_4d(self): input, scale, zero_point, axis, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( - quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32 + quantized_ref, + scale, + zero_point, + axis, + quant_min, + quant_max, + torch.int8, + out_dtype=torch.float32, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (3, 3, 2, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02) @@ -412,11 +578,15 @@ def test_choose_qparams_tensor_asym_eps(self): mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype + ) eps = torch.finfo(torch.float32).eps self.assertEqual(scale, eps) - @unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available") + @unittest.skipIf( + not torch.cuda.is_available(), "skipping when cuda is not available" + ) def test_get_group_qparams_symmetric_memory(self): """Check the memory usage of the op""" weight = torch.randn(1024, 1024).to(device="cuda") @@ -428,18 +598,20 @@ def test_get_group_qparams_symmetric_memory(self): self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use) def test_raises(self): - """Make sure some errors are raised when user requested an unsupported type of quantization - """ + """Make sure some errors are raised when user requested an unsupported type of quantization""" input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) - + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype + ) # make sure we can't quantize int32 tensors: with self.assertRaisesRegex(AssertionError, "Unsupported input dtype:"): - _ = quantize_affine(input.to(torch.int32), block_size, scale, zero_point, dtype) + _ = quantize_affine( + input.to(torch.int32), block_size, scale, zero_point, dtype + ) # block_size and scale/zero_point shape mismatch block_size = (1, 1) @@ -458,7 +630,10 @@ def test_not_preserve_zero_not_supported(self): eps = 1e-6 scale_dtype = torch.bfloat16 zero_point_dtype = torch.bfloat16 - with self.assertRaisesRegex(ValueError, "preserve_zero == False is not supported for symmetric quantization"): + with self.assertRaisesRegex( + ValueError, + "preserve_zero == False is not supported for symmetric quantization", + ): choose_qparams_affine( input, mapping_type, @@ -472,11 +647,12 @@ def test_not_preserve_zero_not_supported(self): preserve_zero=False, ) - def test_get_groupwise_affine_qparams(self): input = torch.randn(10, 256) n_bit = 4 - scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) + scale_ref, zero_point_ref = _get_groupwise_affine_qparams( + input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16 + ) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 @@ -486,20 +662,19 @@ def test_get_groupwise_affine_qparams(self): eps = 1e-6 scale_dtype = torch.bfloat16 zero_point_dtype = torch.bfloat16 - scale, zero_point = \ - choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - quant_min, - quant_max, - eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT, - ) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=False, + zero_point_domain=ZeroPointDomain.FLOAT, + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zero_point_ref)) @@ -511,8 +686,12 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self): n_bit = 4 groupsize = 128 - w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) - w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_int4x8 = groupwise_affine_quantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) + w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref)) @@ -524,15 +703,25 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): groupsize = 128 if TORCH_VERSION_AT_LEAST_2_5: - input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize) + input_tmp = input + if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input_tmp, scales, zeros, n_bit, groupsize + ) else: - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) - w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) + w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_affine(self): input = torch.randn(10, 10) @@ -544,14 +733,31 @@ def test_fake_quantize_affine(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) - dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) - fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + dequantized = dequantize_affine( + quantized, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + fake_quantized = fake_quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) torch.testing.assert_close(dequantized, fake_quantized) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) @@ -563,16 +769,36 @@ def test_fake_quantize_affine_cachemask(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) - dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + dequantized = dequantize_affine( + quantized, block_size, scale, zero_point, dtype, quant_min, quant_max + ) (fake_quantized, mask) = fake_quantize_affine_cachemask( - input, block_size, scale, zero_point, dtype, quant_min, quant_max, + input, + block_size, + scale, + zero_point, + dtype, + quant_min, + quant_max, ) expected_mask = torch.full(input.shape, True) torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 2779d37293..e3f5626d49 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -1,19 +1,18 @@ -import logging -import unittest import copy +import unittest import torch -import torch.nn.functional as F from torch import nn from torch.testing._internal.common_utils import TestCase from torchao.sparsity.training import ( + SemiSparseLinear, swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, - SemiSparseLinear ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode + class ToyModel(nn.Module): def __init__(self): super().__init__() @@ -26,14 +25,16 @@ def forward(self, x): x = self.linear2(x) return x -class TestRuntimeSemiStructuredSparsity(TestCase): +class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") + @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() model = ToyModel().half().cuda() @@ -41,7 +42,9 @@ def test_runtime_weight_sparsification(self): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() + sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + mod.weight.detach() + ).to_dense() mod.weight = nn.Parameter(sparse) dense_result = model(input) @@ -61,8 +64,12 @@ def test_runtime_weight_sparsification(self): sparse_result.backward(grad) # check grad - assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1) - assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1) + assert torch.allclose( + model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1 + ) + assert torch.allclose( + model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1 + ) # check that swap back works swap_semi_sparse_linear_with_linear(model_c) @@ -72,9 +79,11 @@ def test_runtime_weight_sparsification(self): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") + @unittest.skip("Temporarily skipping to unpin nightlies") def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() model = ToyModel().half().cuda() @@ -82,7 +91,9 @@ def test_runtime_weight_sparsification_compile(self): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() + sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + mod.weight.detach() + ).to_dense() mod.weight = nn.Parameter(sparse) model = torch.compile(model, fullgraph=True) @@ -104,8 +115,12 @@ def test_runtime_weight_sparsification_compile(self): sparse_result.backward(grad) # check grad - assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1) - assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1) + assert torch.allclose( + model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1 + ) + assert torch.allclose( + model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1 + ) # check that swap back works swap_semi_sparse_linear_with_linear(model_c) diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 173afd7dab..4da7304a24 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -1,28 +1,24 @@ -import torch import copy -import pytest +import pytest +import torch from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + from torchao.dtypes import MarlinSparseLayout -from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.quantization.quant_api import int4_weight_only, quantize_ -from torchao.sparsity.marlin import ( - pack_to_marlin_24, - unpack_from_marlin_24, - inject_24 -) from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, choose_qparams_affine, quantize_affine, - ZeroPointDomain, - MappingType, ) +from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 +from torchao.sparsity.sparse_api import apply_fake_sparsity +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): - def setUp(self): super().setUp() torch.manual_seed(0) @@ -53,7 +49,9 @@ def test_quant_sparse_marlin_layout_eager(self): quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + assert torch.allclose( + dense_result, sparse_result, atol=3e-1 + ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @@ -71,7 +69,9 @@ def test_quant_sparse_marlin_layout_compile(self): self.model.forward = torch.compile(self.model.forward, fullgraph=True) sparse_result = self.model(self.input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + assert torch.allclose( + dense_result, sparse_result, atol=3e-1 + ), "Results are not close" @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_pack_unpack_equivalence(self): @@ -94,9 +94,30 @@ def test_pack_unpack_equivalence(self): # Inject 2:4 sparsity mask w_24, _ = inject_24(w, *w.shape) - # Quantize weights - scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) - w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain) + # Quantize weights + scales, zeros = choose_qparams_affine( + w_24, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + w_q_24 = quantize_affine( + w_24, + block_size, + scales, + zeros, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) scales = scales.reshape(-1, w_q_24.shape[1]) # Test pack/unpack equivalence @@ -107,8 +128,12 @@ def test_pack_unpack_equivalence(self): q_w_comp, packed_scales, meta, shape, group_size, num_bits ) - assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights" - assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales" + assert torch.equal( + w_q_24, unpacked_q_w + ), "Unpacked weights do not match original weights" + assert torch.equal( + scales, unpacked_scales + ), "Unpacked scales do not match original scales" if __name__ == "__main__": diff --git a/test/sparsity/test_wanda.py b/test/sparsity/test_wanda.py index fcb94036aa..e02ea9822a 100644 --- a/test/sparsity/test_wanda.py +++ b/test/sparsity/test_wanda.py @@ -3,12 +3,13 @@ import torch from torch import nn -from torchao.sparsity import WandaSparsifier from torch.ao.pruning import FakeSparsity from torch.nn.utils.parametrize import is_parametrized from torch.testing._internal.common_pruning import SimpleLinear from torch.testing._internal.common_utils import TestCase +from torchao.sparsity import WandaSparsifier + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -29,7 +30,9 @@ def test_prepare(self): assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") - assert type(module.parametrizations.weight[0]) == FakeSparsity + assert isinstance( + module.parametrizations.weight[0], FakeSparsity + ), "FakeSparsity not found" # check activation observer is present assert hasattr(module, "activation_post_process") @@ -110,5 +113,6 @@ def test_two_layer_mlp_unstructured(self): sparsifier.squash_mask() + if __name__ == "__main__": unittest.main() diff --git a/test/test_ops.py b/test/test_ops.py index 4d8104c25b..c5821eed44 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -463,6 +463,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto MARLIN_TEST_PARAMS, ids=str, ) +@pytest.mark.skip(reason="test outputs nan after cuda is upgraded to 12.4") def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): int8_traits = torch.iinfo(torch.int8) m_factor, n_factor, k_factor = mnk_factors diff --git a/torchao/__init__.py b/torchao/__init__.py index b910af3d7e..dd3ddbc813 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -28,7 +28,7 @@ torch.ops.load_library(so_files[0]) from . import ops except: - logging.info("Skipping import of cpp extensions") + logging.debug("Skipping import of cpp extensions") from torchao.quantization import ( autoquant, diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 6a1a66ab77..3cf94ee53d 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -10,9 +10,14 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.quantize_per_channel_group( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs): @@ -24,9 +29,14 @@ def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **k in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): @@ -38,9 +48,14 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.dequantize_per_channel_group( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): @@ -52,9 +67,12 @@ def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.") + raise ImportError( + "Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): @@ -66,6 +84,9 @@ def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.") + raise ImportError( + "Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later." + ) diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 63733c736d..c8cd4bf39c 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -52,7 +52,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt @@ -62,7 +62,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt @@ -79,3 +79,20 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128 + +# TTFT benchmarks +export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured + +# 2:4 sparse model +export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 25b65cd1ec..43667487d8 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -10,27 +10,22 @@ from generate import ( _load_model, device_sync, - ) -from torchao.quantization.quant_api import ( +from torchao.quantization import ( quantize_, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, fpx_weight_only, uintx_weight_only, - unwrap_tensor_subclass, float8_weight_only, float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, ) -from torchao._models._eval import TransformerEvalWrapper, InputRecorder from torchao._models.llama.model import prepare_inputs_for_model -from torchao.quantization.granularity import PerRow, PerTensor - +from torchao.quantization import PerRow, PerTensor from tokenizer import get_tokenizer import time -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass def run_evaluation( checkpoint_path: Path, diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 1efa6b04b3..065cc9c56d 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -17,15 +17,40 @@ from torchao.quantization.quant_primitives import MappingType from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + +class HostEvent: + def __init__(self): + self.event_time = None + + def record(self): + self.event_time = time.perf_counter() + + def elapsed_time(self, other_event): + if self.event_time is None: + raise ValueError("Event not recorded!") + # return ms to match cuda event + return abs(other_event.event_time - self.event_time) * 1000 + +def device_timer(device): + if "cuda" in device: + return torch.cuda.Event(enable_timing=True) + elif ("cpu" in device) or ("mps" in device): + return HostEvent() + else: + print(f"device={device} is not yet suppported") + def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) + elif "xpu" in device: + torch.xpu.synchronize(device) elif ("cpu" in device) or ("mps" in device): pass else: print(f"device={device} is not yet suppported") -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' +default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu' # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -67,7 +92,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): new_tokens, new_probs = [], [] for i in range(num_new_tokens): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) @@ -96,6 +121,10 @@ def generate( kv_cache_quantization: bool = False, cache_size: Optional[int] = None, linear_causal_mask: bool=False, + prefill_start_event: Optional[torch.cuda.Event]=None, + prefill_end_event: Optional[torch.cuda.Event]=None, + decode_start_event: Optional[torch.cuda.Event]=None, + decode_end_event: Optional[torch.cuda.Event]=None, **sampling_kwargs ) -> torch.Tensor: """ @@ -126,12 +155,21 @@ def generate( model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T) # execute prefill + if prefill_start_event is not None: + prefill_start_event.record() next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone() seq[:, T] = next_token.squeeze() + if prefill_end_event is not None: + prefill_end_event.record() + # execute token generation + if decode_start_event is not None: + decode_start_event.record() input_pos = torch.tensor([T], device=device, dtype=torch.int) generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + if decode_end_event is not None: + decode_end_event.record() return seq @@ -155,6 +193,7 @@ def _load_model(checkpoint_path, device, precision): B_INST, E_INST = "[INST]", "[/INST]" def main( + prefill_size: Optional[int] = None, prompt: str = "Hello, my name is", interactive: bool = False, num_samples: int = 5, @@ -164,8 +203,7 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, - calibration_limit: int = 10, - calibration_seq_length: int = 256, + sparsity: Optional[str] = None, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, linear_causal_mask: bool=False, @@ -181,6 +219,10 @@ def main( """Generates text samples based on a pre-trained Transformer model and tokenizer. """ + if prefill_size is not None and prefill_size > 0: + # create prompt of prefill size + prompt = "prompt " * (int(prefill_size)-3) + torchao.quantization.utils.recommended_inductor_config_setter() assert checkpoint_path.is_file(), checkpoint_path @@ -205,29 +247,44 @@ def main( torch.manual_seed(1234) + def ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + + def not_ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) + + def ffn_or_attn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn) if quantization: - from torchao.quantization.quant_api import ( + from torchao.quantization import ( quantize_, + autoquant, int8_weight_only, int8_dynamic_activation_int8_weight, int4_weight_only, int8_dynamic_activation_int4_weight, fpx_weight_only, uintx_weight_only, - autoquant, - unwrap_tensor_subclass, float8_weight_only, float8_dynamic_activation_float8_weight, ) + from torchao.utils import unwrap_tensor_subclass + from torchao.quantization.granularity import PerTensor, PerRow + from torchao.utils import unwrap_tensor_subclass if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant apply_spinquant(model) if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) + if sparsity and "semi" in sparsity: + from torchao.dtypes import SemiSparseLayout + quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only) + quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only) + else: + quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: if "hqq" in quantization: use_hqq=True @@ -248,14 +305,14 @@ def main( layout=MarlinQQQLayout(), ), ) - else: + elif "semi" in sparsity: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) - if "embed-int8wo" in quantization: + elif "embed-int8wo" in quantization: quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding)) - if quantization.startswith("awq"): + elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 from torchao.prototype.awq.example import get_calib_dataset @@ -268,21 +325,21 @@ def main( quant_dtype = getattr(torch, quant_dtype, torch.uint8) model=model.to(device) # get calibration data - insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size) + insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size) TransformerEvalWrapper( model=model.to(device), tokenizer=tokenizer, - max_seq_length=calibration_seq_length, + max_seq_length=256, input_prep_func=prepare_inputs_for_model, device=device, ).run_eval( - tasks=['wikitext'], - limit=calibration_limit, + tasks=['wikitext'], + limit=1, ) is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) use_hqq = "hqq" in quantization quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) - if "uintx" in quantization: + elif "uintx" in quantization: # uintx-nbits-group_size, e.g. "uintx-2-64" if "hqq" in quantization: # uintx-nbits-group_size-hqq @@ -296,9 +353,32 @@ def main( dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) - if "float8wo" in quantization: + elif "int8_dynamic_activation_intx_weight" in quantization: + from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight + assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision" + + # Build kernels in temp location, and load them in torch + # This requires an ARM CPU + from torchao.experimental.temp_build import temp_build_and_load_torchao_ops + temp_build_and_load_torchao_ops(cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental") + + # Quantize model + _quant_args = quantization.split("-") + nbit = int(_quant_args[1]) + assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8" + group_size = int(_quant_args[2]) + has_weight_zeros = bool(_quant_args[3]) + quantize_( + model, + int8_dynamic_activation_intx_weight( + group_size=group_size, + nbit=nbit, + has_weight_zeros=has_weight_zeros, + ), + ) + elif "float8wo" in quantization: quantize_(model, float8_weight_only()) - if "float8dq" in quantization: + elif "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) if granularity=="tensor": granularity = PerTensor() @@ -307,13 +387,96 @@ def main( else: granularity = PerTensor() quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity)) - if "autoquant" in quantization: + elif "autoquant_v2" in quantization: + from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 + from torchao._models._eval import InputRecorder + from torchao._models.llama.model import prepare_inputs_for_model + + calibration_seq_length = 256 + calibration_limit = 1 + inputs = InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda" + ).record_inputs( + ["wikitext"], + 1, + ).get_inputs()[0].values[0] + inputs = prepare_inputs_for_model(inputs) + with torch.device("cuda"): + model.setup_caches( + max_batch_size=1, max_seq_length=calibration_seq_length + ) + + if "autoquant_v2-int4" == quantization: + model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + elif "autoquant_v2-float8" == quantization: + model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + elif "autoquant_v2-fp" == quantization: + model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + elif "autoquant_v2-all" == quantization: + all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + if torchao.utils.is_sm_89(): + # this is fp8 related subclasses, should rename + all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST + model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length) + else: + model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length) + + print("running generate") + generate( + model, + encode_tokens(tokenizer, prompt, bos=True, device=device), + max_new_tokens, + batch_size, + interactive=False, + temperature=temperature, + top_k=top_k, + ) + + print("running finalize autoquant") + # do autoquantization + model.finalize_autoquant() + elif "autoquant" in quantization: + from torchao._models._eval import InputRecorder + from torchao._models.llama.model import prepare_inputs_for_model + + calibration_seq_length = 256 + calibration_limit = 1 + inputs = InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda" + ).record_inputs( + ["wikitext"], + 1, + ).get_inputs()[0].values[0] + inputs = prepare_inputs_for_model(inputs) + with torch.device("cuda"): + model.setup_caches( + max_batch_size=1, max_seq_length=calibration_seq_length + ) + if "autoquant-int4" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) elif "autoquant-float8" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST) + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + if "autoquant-fp" == quantization: + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs) + if "autoquant-all" == quantization: + all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + if torchao.utils.is_sm_89(): + # this is fp8 related subclasses, should rename + all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST + model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs) else: - model = autoquant(model, manual=True) + model = autoquant(model, manual=True, example_input=inputs) generate( model, @@ -327,10 +490,18 @@ def main( # do autoquantization model.finalize_autoquant() + else: if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) + # standalone sparsity + elif sparsity: + from torchao.sparsity import semi_sparse_weight, sparsify_ + if "semi" in sparsity: + #TODO there is a bug here, need to fix + sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 if save: @@ -347,15 +518,27 @@ def main( prefill = torch.compile(prefill, fullgraph=True, dynamic=True) if memory_profile: - torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + if device == "cuda": + torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + elif device == "xpu": + torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + else: + print("Memory profiling only works on CUDA or XPU devices") + aggregate_metrics = { 'tokens_per_sec': [], + 'time': [], + 'decode_tokens_per_sec': [], + 'prefill_time': [], } start = -1 if compile else 0 for i in range(start, num_samples): if i==0: - torch.cuda.reset_peak_memory_stats() + if device == "cuda": + torch.cuda.reset_peak_memory_stats() # MKG + elif device == "xpu": + torch.xpu.reset_peak_memory_stats() # MKG device_sync(device=device) # MKG if i >= 0 and interactive: prompt = input("What is your prompt? ") @@ -381,6 +564,8 @@ def callback(x): else: callback = lambda x : x t0 = time.perf_counter() + prefill_start_event, prefill_end_event = device_timer(device), device_timer(device) + decode_start_event, decode_end_event = device_timer(device), device_timer(device) import contextlib if (i != num_samples - 1 or not profile): prof = contextlib.nullcontext() @@ -400,6 +585,10 @@ def callback(x): kv_cache_quantization=kv_cache_quantization, cache_size=cache_size, linear_causal_mask=linear_causal_mask, + prefill_start_event=prefill_start_event, + prefill_end_event=prefill_end_event, + decode_start_event=decode_start_event, + decode_end_event=decode_end_event, ) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") @@ -409,7 +598,7 @@ def callback(x): device_sync(device=device) # MKG t = time.perf_counter() - t0 - if not interactive: + if not interactive and prefill_size is None: tok_list = y[0].tolist() # truncate text after end of string token tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())] @@ -419,11 +608,24 @@ def callback(x): tokens_generated = (y.size(-1) - prompt_length) tokens_sec = tokens_generated / t aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + aggregate_metrics['time'].append(t) + decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000 + decode_tokens_sec = tokens_generated / decode_time + aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec) + prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000 + aggregate_metrics['prefill_time'].append(prefill_time) + print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", + f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec") print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") if memory_profile and i==0: - snapshot = torch.cuda.memory._snapshot() + if device == "cuda": + snapshot = torch.cuda.memory._snapshot() + elif device == "xpu": + snapshot = torch.xpu.memory._snapshot() + else: + print("Memory profiling only works on CUDA or XPU devices") + with open(f"{memory_profile}.pickle", 'wb') as f: from pickle import dump dump(snapshot, f) @@ -432,12 +634,21 @@ def callback(x): "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html" ) break - print("==========") + #ignore first sample for warmup tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item() + decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item() bandwidth = model_size * tokpersec mem = torch.cuda.max_memory_reserved() /1e9 + print(f"Average overall tokens/sec: {tokpersec:.2f}") + print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s") + print(f"Average TTFT: {ttft:.04f} s") + if device == "cuda": + mem = torch.cuda.max_memory_reserved() /1e9 + elif device == "xpu": + mem = torch.xpu.max_memory_reserved() /1e9 print(f"Average tokens/sec: {tokpersec:.2f}") if batch_size > 1: print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") @@ -445,15 +656,17 @@ def callback(x): print(f"Peak Memory Usage: {mem:.02f} GB") print(f"Model Size: {model_size:.02f} GB") if write_result: - result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " + result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " result_txt += f"repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" + result_txt += f"--sparsity {sparsity} " if sparsity else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " result_txt += f"--compile " if compile else "" result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += f"--prefill_size {prefill_size}" if prefill_size else "" result_txt += f"--profile {profile} " if profile else "" result_txt += f"--profile {memory_profile} " if memory_profile else "" result_txt += f"--interactive " if interactive else "" @@ -475,7 +688,7 @@ def callback(x): if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Your CLI description.') - + parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode') parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') @@ -484,15 +697,18 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, + parser.add_argument('-q', '--quantization', type=str, help=( 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, ' +'embed-int8wo, marlin_qqq' ) ) - parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples") - parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration") + parser.add_argument('-s', '--sparsity', type=str, + help=( + 'Which sparsity techniques to apply: semi-structured' + ) + ) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') @@ -507,6 +723,6 @@ def callback(x): args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, + args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) diff --git a/torchao/_models/llama/perf_profile.py b/torchao/_models/llama/perf_profile.py index 1a0d4e36c0..f613982221 100644 --- a/torchao/_models/llama/perf_profile.py +++ b/torchao/_models/llama/perf_profile.py @@ -2,9 +2,9 @@ ## Performance Profiling Example -An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. +An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.prototype.profiler.TransformerPerformanceCounter`. - Outputs from gpt-fast are prefixed with GPT-Fast -- Outputs from `torchao.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. +- Outputs from `torchao.prototype.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. ## Usage ```python @@ -118,7 +118,7 @@ from torchao._models.llama.model import Transformer from torchao._models.llama.tokenizer import get_tokenizer -from torchao.profiler import ( +from torchao.prototype.profiler import ( CUDADeviceSpec, TransformerPerformanceCounter, total_model_params, diff --git a/torchao/_models/sam/README.md b/torchao/_models/sam/README.md index 426d7fe6a8..0039d7f4d6 100644 --- a/torchao/_models/sam/README.md +++ b/torchao/_models/sam/README.md @@ -17,5 +17,7 @@ sh setup.sh Finally, you can run benchmarks with ``` -sh benchmark_sam.sh +sh benchmark.sh ``` + +You can check out the result in results.csv diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index cb3f1afb9b..9c05d00b26 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -9,7 +9,14 @@ import time import resource -from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only +import torchao +from torchao.quantization import ( + quantize_, + int8_dynamic_activation_int8_weight, + int4_weight_only, + autoquant, +) +from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout from torchao.utils import unwrap_tensor_subclass @@ -336,6 +343,29 @@ def mlp_only(mod, name): mlp_lin2_only) if not TORCH_VERSION_AT_LEAST_2_5: predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) + + elif compress is not None and "autoquant_v2" in compress: + example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device) + if "autoquant_v2-int4" == compress: + autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + elif "autoquant_v2-float8" == compress: + autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST) + else: + autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True) + + predictor.model.image_encoder(example_input) + predictor.model.image_encoder.finalize_autoquant() + + elif compress is not None and "autoquant" in compress: + example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device) + if "autoquant-int4" == compress: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) + elif "autoquant-float8" == compress: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST) + else: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True) + predictor.model.image_encoder(example_input) + predictor.model.image_encoder.finalize_autoquant() else: assert compress is None, f"Unsupported compress mode {compress}" diff --git a/torchao/_models/sam2/automatic_mask_generator.py b/torchao/_models/sam2/automatic_mask_generator.py index db5a14635d..891a2602ba 100644 --- a/torchao/_models/sam2/automatic_mask_generator.py +++ b/torchao/_models/sam2/automatic_mask_generator.py @@ -36,7 +36,7 @@ ) -class SAM2AutomaticMaskGenerator: +class SAM2AutomaticMaskGenerator(torch.nn.Module): def __init__( self, model: SAM2Base, @@ -105,7 +105,7 @@ def __init__( use_m2m (bool): Whether to add a one step refinement using previous mask predictions. multimask_output (bool): Whether to output multimask at each point of the grid. """ - + super().__init__() assert (points_per_side is None) != ( point_grids is None ), "Exactly one of points_per_side or point_grid must be provided." @@ -381,7 +381,6 @@ def _process_crop_batch( i = 0 batch_features = self.predictor._features all_crop_data = [] - all_all_batch_iterator_data = [] for (cropped_im, crop_box, layer_idx, orig_size) in zip(all_cropped_im, all_crop_box, all_layer_idx, all_orig_size): cropped_im_size = cropped_im.shape[:2] self.predictor.reset_predictor() @@ -425,9 +424,6 @@ def _process_crop_batch( data = self._process_batch_fullgraph(points, im_size, crop_box, crop_box_torch, orig_size, normalize, orig_box_torch) all_batch_iterator_data.append(data) self.predictor.reset_predictor() - all_all_batch_iterator_data.append(all_batch_iterator_data) - - for all_batch_iterator_data in all_all_batch_iterator_data: result_data = None with torch.autograd.profiler.record_function("all mask_to_rle_pytorch_2"): diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml index cbee3cf9b3..42cd897c67 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +++ b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -2,18 +2,18 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -24,17 +24,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -45,7 +45,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -57,23 +57,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml index 8e803dfea5..898898b158 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +++ b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml index 983c2ea031..c6318f843b 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +++ b/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/modeling/sam/mask_decoder.py b/torchao/_models/sam2/modeling/sam/mask_decoder.py index cbd6bb40b5..cba29e414b 100644 --- a/torchao/_models/sam2/modeling/sam/mask_decoder.py +++ b/torchao/_models/sam2/modeling/sam/mask_decoder.py @@ -219,6 +219,10 @@ def predict_masks( # TODO: Not specifying scale kwarg in SDPA will cause NaN here # print("hs.isnan().any(): ", hs.isnan().any().item()) + # TODO: These outputs are being immediately indexed. + # Is there something to remove? + # TODO: The fact that there's a crop box and we try to find stuff at the + # boundary later and there's generally cropping going on smells of padding. iou_token_out = hs[:, s, :] mask_tokens_out = hs[:, s + 1: (s + 1 + self.num_mask_tokens), :] diff --git a/torchao/_models/sam2/sam2_image_predictor.py b/torchao/_models/sam2/sam2_image_predictor.py index e451236624..f404fe00e4 100644 --- a/torchao/_models/sam2/sam2_image_predictor.py +++ b/torchao/_models/sam2/sam2_image_predictor.py @@ -17,7 +17,7 @@ from torchao._models.sam2.utils.transforms import SAM2Transforms -class SAM2ImagePredictor: +class SAM2ImagePredictor(torch.nn.Module): def __init__( self, sam_model: SAM2Base, @@ -66,6 +66,7 @@ def __init__( ] self._image_dtype = torch.float32 + self._transforms_device = "cpu" @classmethod def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": @@ -110,10 +111,10 @@ def set_image( raise NotImplementedError("Image format not supported") input_image = self._transforms.to_tensor(image) - # TODO: Doing these transforms on the GPU changes the numerics + # NOTE: Doing these transforms on the GPU changes the numerics + input_image = input_image.to(device=self._transforms_device) input_image = self._transforms.transforms(input_image) input_image = input_image.to(device=self.device) - # TODO: Doing this here instead causes masks to not match reference exactly # input_image = self._transforms.transforms(input_image) input_image = input_image[None, ...].to(dtype=self._image_dtype) @@ -167,8 +168,10 @@ def set_image_batch( len(img_batch.shape) == 4 and img_batch.shape[1] == 3 ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" logging.info("Computing image embeddings for the provided images...") - backbone_out = self.model.forward_image(img_batch) - _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + with torch.autograd.profiler.record_function("forward_image"): + backbone_out = self.model.forward_image(img_batch) + with torch.autograd.profiler.record_function("_prepare_backbone_features"): + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos if self.model.directly_add_no_mem_embed: vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed @@ -462,11 +465,11 @@ def _predict_masks_postprocess(self, low_res_masks, img_idx, return_logits, chan # Upscale the masks to the original image resolution if channel_1: masks = self._transforms.postprocess_masks_1_channel( - low_res_masks, self._orig_hw[img_idx] + low_res_masks, self._orig_hw[img_idx], self._image_dtype ) else: masks = self._transforms.postprocess_masks( - low_res_masks, self._orig_hw[img_idx] + low_res_masks, self._orig_hw[img_idx], self._image_dtype ) low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) if not return_logits: diff --git a/torchao/_models/sam2/utils/amg.py b/torchao/_models/sam2/utils/amg.py index 640f811f48..cf52cae327 100644 --- a/torchao/_models/sam2/utils/amg.py +++ b/torchao/_models/sam2/utils/amg.py @@ -215,7 +215,8 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: return mask.transpose() # Put in C order -def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): +@torch.compile(fullgraph=True, dynamic=True) +def _mask_to_rle_pytorch_2_0_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tensor): """ Encodes masks to an uncompressed RLE, in the format expected by pycoco tools. @@ -224,36 +225,57 @@ def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tenso b, h, w = tensor.shape tensor = tensor.permute(0, 2, 1).flatten(1) - with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: change indices"): - # Compute change indices - diff = tensor[:, 1:] ^ tensor[:, :-1] - a = torch.tensor([[True]]) - if diff.is_cuda: - a = a.pin_memory().cuda() - # a = a.to(diff.device) - a = a.expand_as(diff.narrow(1, 0, 1)) - diff = torch.cat([a, diff, a], dim=1) + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + # a = torch.tensor([[True]]) + a = torch.ones((1, 1), dtype=bool, device=diff.device) + # if diff.is_cuda: + # a = a.pin_memory().cuda() + # # a = a.to(diff.device) + a = a.expand_as(diff.narrow(1, 0, 1)) + diff = torch.cat([a, diff, a], dim=1) + return diff + + +@torch.compile(fullgraph=True, dynamic=True) +def _mask_to_rle_pytorch_2_0_1(tensor: torch.Tensor, diff: torch.Tensor, change_indices: torch.Tensor) -> (torch.Tensor, torch.Tensor): + tensor = tensor.permute(0, 2, 1).flatten(1) + + alt_lens = diff.sum(dim=1) + + all_cur_idx = change_indices[:, 1] + if all_cur_idx.numel() == 0: + all_cur_idx_0 = all_cur_idx + all_cur_idx_1 = all_cur_idx + else: + all_cur_idx_0 = all_cur_idx.narrow(0, 1, all_cur_idx.size(0) - 1) + all_cur_idx_1 = all_cur_idx.narrow(0, 0, 1) + all_btw_idx = torch.cat([all_cur_idx_0, all_cur_idx_1]) + all_btw_idx = all_btw_idx - all_cur_idx + + alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens) + # Encode run length + counts_init = (tensor[:, 0] == 0) + return alt_lens_nt, counts_init + + +def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> RLEData: + b, h, w = tensor.shape + with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_0"): + diff = _mask_to_rle_pytorch_2_0_0(tensor) + with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: nonzero"): if diff.numel() > 2147483646: num_chunks = (diff.numel() + 2147483646) // 2147483646 change_indices = torch.cat([d.nonzero() for d in diff.chunk(num_chunks)]) else: change_indices = diff.nonzero() - - with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: all_btw_idx"): - alt_lens = diff.sum(dim=1) - - all_cur_idx = change_indices[:, 1] - all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx - - with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: Encode run length"): - alt_lens_nt = torch.nested.nested_tensor_from_jagged(all_btw_idx, lengths=alt_lens) - # Encode run length - counts_init = (tensor[:, 0] == 0) - return RLEData(alt_lens_nt=alt_lens_nt, - counts_init=counts_init, - b=b, - h=h, - w=w) + with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_1"): + alt_lens_nt, counts_init = _mask_to_rle_pytorch_2_0_1(tensor, diff, change_indices) + return RLEData(alt_lens_nt=alt_lens_nt, + counts_init=counts_init, + b=b, + h=h, + w=w) def _mask_to_rle_pytorch_2_1(rle_data: RLEData): @@ -276,7 +298,8 @@ def _mask_to_rle_pytorch_2_1(rle_data: RLEData): def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]: - return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor)) + with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2"): + return _mask_to_rle_pytorch_2_1(_mask_to_rle_pytorch_2_0(tensor)) def area_from_rle(rle: Dict[str, Any]) -> int: diff --git a/torchao/_models/sam2/utils/transforms.py b/torchao/_models/sam2/utils/transforms.py index 421b6edb0b..8b90c4477c 100644 --- a/torchao/_models/sam2/utils/transforms.py +++ b/torchao/_models/sam2/utils/transforms.py @@ -73,7 +73,7 @@ def transform_boxes( boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) return boxes - def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + def postprocess_masks(self, masks: torch.Tensor, orig_hw, output_dtype) -> torch.Tensor: """ Perform PostProcessing on output masks. """ @@ -114,10 +114,11 @@ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: ) masks = input_masks + masks = masks.to(output_dtype) masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) return masks - def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw, output_dtype) -> torch.Tensor: """ Perform PostProcessing on output masks. """ @@ -161,5 +162,6 @@ def postprocess_masks_1_channel(self, masks: torch.Tensor, orig_hw) -> torch.Ten ) masks = input_masks + masks = masks.to(output_dtype) masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) return masks diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index b639832648..00305db348 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,13 +1,7 @@ +from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, - Float8AQTTensorImpl, - Float8Layout, - Layout, - MarlinQQQLayout, - MarlinSparseLayout, - PlainLayout, - SemiSparseLayout, - TensorCoreTiledLayout, + MarlinQQQTensor, to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future @@ -16,15 +10,27 @@ to_affine_quantized_intx_static, to_marlinqqq_quantized_intx, ) +from .floatx import ( + Float8Layout, +) from .nf4tensor import NF4Tensor, to_nf4 - -# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor -from .uint4 import UInt4Tensor +from .uintx import ( + BlockSparseLayout, + Int4CPULayout, + MarlinQQQLayout, + MarlinSparseLayout, + SemiSparseLayout, + TensorCoreTiledLayout, + UintxLayout, +) +from .utils import ( + Layout, + PlainLayout, +) __all__ = [ "NF4Tensor", "to_nf4", - "UInt4Tensor", "AffineQuantizedTensor", "to_affine_quantized_intx", "to_affine_quantized_intx_static", @@ -37,7 +43,11 @@ "SemiSparseLayout", "TensorCoreTiledLayout", "Float8Layout", - "Float8AQTTensorImpl", "MarlinSparseLayout", + "affine_quantized_tensor_ops", + "BlockSparseLayout", + "UintxLayout", + "MarlinQQQTensor", "MarlinQQQLayout", + "Int4CPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 5c36e4e4e0..93d2766d1e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,34 +1,18 @@ import logging import math -from dataclasses import dataclass from typing import Optional, Tuple, Union import torch -from torch.utils._python_dispatch import ( - is_traceable_wrapper_subclass, - return_and_correct_aliasing, -) from torchao.dtypes.utils import ( + AQTTensorImpl, Layout, PlainLayout, - get_out_shape, - is_device, -) -from torchao.float8.inference import ( - Float8MMConfig, - _is_rowwise_scaled, - addmm_float8_unwrapped_inference, - preprocess_data, -) -from torchao.kernel import ( - int_scaled_matmul, ) from torchao.quantization.quant_primitives import ( FP8_TYPES, MappingType, ZeroPointDomain, - _get_reduction_params, choose_qparams_affine, choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, @@ -39,102 +23,30 @@ quantize_affine, quantize_affine_floatx, ) -from torchao.quantization.utils import ( - pack_tinygemm_scales_and_zeros, -) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, - _is_float8_type, - fill_defaults, - find_multiple, ) logger = logging.getLogger(__name__) - - aten = torch.ops.aten - -############################### -# Base Tensor Impl Subclass # -############################### -class AQTTensorImpl(TorchAOBaseTensor): - """ - Base class for the tensor impl for `AffineQuantizedTensor` - - Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct - the underlying implementation of a AQT based on layout - """ - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Get the plain (unpacked) Tensor for the tensor impl - - Returns data, scale and zero_point - Can be overwritten if other types of AQTTensorImpl has different numbers of plain tensors - """ - pass - - def get_layout(self) -> Layout: - pass - - @classmethod - def from_plain( - cls, - data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - """Construct a TensorImpl from data, scale, zero_point and the _layout""" - pass - - def __repr__(self): - data, scale, zero_point = self.get_plain() - _layout = self.get_layout() - return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})" +__all__ = [ + "AffineQuantizedTensor", + "MarlinQQQTensor", + "register_layout", + "to_affine_quantized_intx", + "to_affine_quantized_floatx", + "to_affine_quantized_intx_static", + "to_affine_quantized_floatx_static", + "to_affine_quantized_fpx", + "to_marlinqqq_quantized_intx", +] ############################## # Tensor Subclass Definition # ############################## - - -class QuantizedLinearNotImplementedError(NotImplementedError): - """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" - - pass - - -_AQT_QLINEAR_DISPATCH_TABLE = {} - - -def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): - """Register a dispatch for quantized linear op with dispatch_condition function and impl function - both takes three arguments: - input_tensor: dimension is (M1, M2, ..., in_features) - weight_tensor: dimension is (out_features, in_features) - bias: dimension is (out_features,) - so that these can be shared by F.linear, aten.mm, aten.addmm dispatches - - Args: - `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]: the dispatch - condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight - `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: the specialized - quantized linear implementation - """ - _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl - - -def deregister_aqt_quantized_linear_dispatch(dispatch_condition): - if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: - del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] - else: - logger.warn( - f"Attempting to remove non-existant dispatch condition {dispatch_condition}" - ) - - class AffineQuantizedTensor(TorchAOBaseTensor): """ Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: @@ -242,6 +154,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor self.zero_point_domain, output_dtype=output_dtype, ) + from torchao.dtypes.uintx import TensorCoreTiledLayout + if isinstance(self._layout, TensorCoreTiledLayout): # need to return to original shape if tensor was padded # in preprocessing @@ -251,15 +165,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor dq = dq.narrow(dim, 0, dim_size) return dq - @staticmethod - def _quantized_linear_op(input_tensor, weight_tensor, bias): - for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): - if dispatch_condition(input_tensor, weight_tensor, bias): - return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError( - "No specialized dispatch found for quantized linear op" - ) - def __tensor_flatten__(self): return ["tensor_impl"], [ self.block_size, @@ -539,7 +444,7 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) - # following are the comments for __torch_function__/__torch_dispatch__, we can clean this up + # following are the comments for __torch_function__/__torch_dispatch__, -> this is defined in affine_quantized_tensor_ops.py # a bit later # Note: we only added cpu path here for 8da4w, this is for executorch, in the future # 1. we'll add cpu/cuda version (int4mm etc.) @@ -582,7 +487,7 @@ def from_hp_to_intx( quant_min: Optional[int] = None, quant_max: Optional[int] = None, zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, - _layout: Layout = None, + _layout: Optional[Layout] = None, ): original_shape = input_float.shape input_float = _layout.pre_process(input_float) @@ -611,2068 +516,6 @@ def from_hp_to_intx( register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor - -@dataclass(frozen=True) -class SemiSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - # prune to 2:4 if not already - temp = input.detach() - pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] - temp.view(-1, 4).scatter_(1, pruning_inds, value=0) - return temp - - -@dataclass(frozen=True) -class BlockSparseLayout(Layout): - blocksize: int = 64 - - -@dataclass(frozen=True) -class TensorCoreTiledLayout(Layout): - """ - inner_k_tiles is an internal argument for packing function of tensor core tiled layout - that can affect the performance of the matmul kernel - """ - - inner_k_tiles: int = 8 - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input = torch.nn.functional.pad( - input, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - return input - - def pre_process_static( - self, - input: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - block_size: Tuple[int, ...], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - input = self.pre_process(input) - orig_qparam_shape = scale.shape - new_qparam_shape, reduction_dims = _get_reduction_params( - block_size, input.size() - ) - for dim in reduction_dims: - new_qparam_shape.pop(dim) - change_in_qparam_shape = [ - new_dim_size - orig_dim_size - for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape) - ] - padding_changes = [] - for dim_change in change_in_qparam_shape: - padding_changes = [0, dim_change] + padding_changes - scale = torch.nn.functional.pad(scale, padding_changes) - zero_point = torch.nn.functional.pad(zero_point, padding_changes) - return input, scale, zero_point - - def post_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input = torch.nn.functional.pad( - input, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - return input - - def extra_repr(self): - return f"inner_k_tiles={self.inner_k_tiles}" - - -@dataclass(frozen=True) -class Float8Layout(Layout): - mm_config: Optional[Float8MMConfig] = None - - -@dataclass(frozen=True) -class MarlinSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. - - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format - - 2º: tensor is injected with 2:4 sparsity - - 3º: transposes it again because the quantization process will compute the scales for dim=-1 - - Args: - input (torch.Tensor): the input tensor to preprocess - - Returns: - torch.Tensor: the preprocessed tensor - """ - from torchao.sparsity.marlin import inject_24 # avoid circular import - - input_t = input.t() - w_24, _ = inject_24(input_t, *input_t.shape) - return w_24.t() - - -@dataclass(frozen=True) -class MarlinQQQLayout(Layout): - pass - - -@register_layout(PlainLayout) -class PlainAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point - tensors directly as plain tensors. - - fields: - int_data (torch.Tensor): the quantized integer data Tensor - scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor - zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data, scale, zero_point = ( - tensor_data_dict["int_data"], - tensor_data_dict["scale"], - tensor_data_dict["zero_point"], - ) - (_layout,) = tensor_attributes - return cls(int_data, scale, zero_point, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), - self._layout, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.scale), - fn(self.zero_point), - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - elif func is aten.t.default: - tensor = args[0] - new = tensor.__class__( - tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - elif func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: aten.slice.Tensor(x, dim, start, end, step) - ), - ) - elif dim == 1: - assert ( - len(self.scale.shape) == 1 - ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl( - aten.slice.Tensor(self.int_data, dim, start, end, step), - self.scale.view(-1), - self.zero_point.view(-1), - self._layout, - ) - else: - raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.int_data, self.scale, self.zero_point - - def get_layout(self) -> Layout: - return self._layout - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, PlainLayout) - return cls(int_data, scale, zero_point, _layout) - - -@register_layout(SemiSparseLayout) -class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): - """ - TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor - """ - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def get_plain(self): - # Currently we don't have cuSPARSELt expansion routines, so we matmul by - # the identity matrix to get the original dense matrix. This is slow though. - cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm( - self.int_data, - torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), - ) - return int_data_expanded, self.scale, self.zero_point - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, SemiSparseLayout) - int_data_compressed = torch._cslt_compress(int_data) - return cls(int_data_compressed, scale, zero_point, _layout) - - -@register_layout(BlockSparseLayout) -class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): - bsr_crow_indices: Optional[torch.Tensor] - bsr_col_indices: Optional[torch.Tensor] - bsr_values: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - zero_point: Optional[torch.Tensor] - - __slots__ = [ - "bsr_crow_indices", - "bsr_col_indices", - "bsr_values", - "scale", - "zero_point", - ] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - if bsr_values is None: - raise ValueError("bsr values must be provided!") - else: - previous_tensor = bsr_values - - kwargs = { - "device": previous_tensor.device, - "dtype": previous_tensor.dtype, - "layout": previous_tensor.layout, - "requires_grad": requires_grad, - } - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( # noqa: PYI034 - self, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - self.bsr_crow_indices = bsr_crow_indices - self.bsr_col_indices = bsr_col_indices - self.bsr_values = bsr_values - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self._layout, self.requires_grad) - return inner_tensors, tensor_meta - - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta: Tuple[torch.Size, bool], - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, _layout, requires_grad = tensor_meta - return cls( - shape=shape, - bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), - bsr_col_indices=inner_tensors.get("bsr_col_indices", None), - bsr_values=inner_tensors.get("bsr_values", None), - scale=inner_tensors.get("scale", None), - zero_point=inner_tensors.get("zero_point", None), - _layout=_layout, - requires_grad=requires_grad, - ) - - @classmethod - def from_plain(cls, int_data, scale, zero_point, _layout): - bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) - return cls( - shape=int_data.shape, - bsr_crow_indices=bsr_tensor.crow_indices(), - bsr_col_indices=bsr_tensor.col_indices(), - bsr_values=bsr_tensor.values(), - scale=scale, - zero_point=zero_point, - _layout=_layout, - requires_grad=False, - ) - - def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense( - self.crow_indices(), - self.col_indices(), - self.values(), - self.shape[0], - self.shape[1], - ) - return int_data_expanded, self.scale, self.zero_point - - def _apply_fn_to_data(self, func): - return self.__class__( - shape=self.shape, - bsr_crow_indices=func(self.bsr_crow_indices), - bsr_col_indices=func(self.bsr_col_indices), - bsr_values=func(self.bsr_values), - scale=self.scale, - zero_point=self.zero_point, - _layout=self._layout, - requires_grad=self.requires_grad, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - # Need the following for bsr specific functions - if func is aten.crow_indices.default: - return args[0].bsr_crow_indices.detach() - - if func is aten.col_indices.default: - return args[0].bsr_col_indices.detach() - - if func is aten.values.default: - return args[0].bsr_values.detach() - - if func is aten._nnz.default: - return args[0].bsr_values.shape[0] - - raise NotImplementedError( - f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - -@register_layout(MarlinSparseLayout) -class MarlinSparseAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for sparse_marlin_24 layout for affine quantized tensor. - - Can be used with 4 bits and 8 bits quantization. - - Original marlin documentation and information: - https://github.com/IST-DASLab/marlin/tree/master - - Sparse marlin documentation and information: - https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file - - fields: - original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape - group_size (int): the group size used to pack the tensor - num_bits (int): the number of bits used to quantize the tensor - """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - meta: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - meta: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self.meta = meta - self._layout = _layout - self.original_shape = original_shape - self.group_size = group_size - self.num_bits = num_bits - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [ - self._layout, - self.original_shape, - self.group_size, - self.num_bits, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - scale = tensor_data_dict["scale"] - zero_point = tensor_data_dict["zero_point"] - meta = tensor_data_dict["meta"] - _layout, original_shape, group_size, num_bits = tensor_attributes - return cls( - int_data, - scale, - zero_point, - meta, - _layout, - original_shape, - group_size, - num_bits, - ) - - def get_plain(self): - from torchao.sparsity.marlin import ( - unpack_from_marlin_24, - ) # avoid circular import - - int_data_expanded, scales_expanded = unpack_from_marlin_24( - self.int_data, - self.scale, - self.meta, - self.original_shape, - self.group_size, - self.num_bits, - ) - int_data_expanded_t = int_data_expanded.t() - scales_expanded_t = scales_expanded.t() - return int_data_expanded_t, scales_expanded_t, self.zero_point - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - from torchao.sparsity.marlin import ( - const, - pack_to_marlin_24, - ) # avoid circular import - - assert isinstance(_layout, MarlinSparseLayout) - - # Linear layers are (in_features, out_features) but the int_data that is reaching this point - # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. - q_w_24 = int_data.t() - scale_t = scale.t() - - if not torch.cuda.get_device_capability()[0] >= 8: - raise ValueError( - f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." - ) - - if q_w_24.dtype != torch.int32: - raise ValueError("Only `torch.int32` weights are supported.") - - in_features, out_features = q_w_24.shape - if in_features % 128 != 0 or out_features != 256 == 0: - raise ValueError( - "`in_features` must be divisible by 64 and `out_features` by 256." - ) - - # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8 - # will require a bit more work to get our current quantization flow to work with it. - # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main - num_bits = 4 if torch.max(q_w_24) < 16 else -1 - if num_bits not in [4]: - raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") - - group_size = in_features // scale_t.shape[0] - if group_size == 0: - group_size = in_features - assert ( - group_size <= in_features - ), "Group size must be less than or equal to in_features." - - if group_size not in const.SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." - ) - - # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( - q_w_24, scale_t, num_bits, group_size - ) - - return cls( - marlin_24_q_w_comp, - marlin_24_s, - zero_point, - meta, - _layout, - q_w_24.shape, - group_size, - num_bits, - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.scale = fn(self.scale) - self.zero_point = fn(self.zero_point) - self.meta = fn(self.meta) - return self - - -@register_layout(MarlinQQQLayout) -class MarlinQQQAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl storage class for sparse_qqq layout for affine quantized tensor. - - Can only be used with 4 bits quantization for now. - - Original marlin documentation and information: - https://github.com/IST-DASLab/marlin/tree/master - - Marlin qqq information: - https://github.com/HandH1998/QQQ/tree/main - https://arxiv.org/pdf/2406.09904 - - fields: - original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape - group_size (int): the group size used to pack the tensor - num_bits (int): the number of bits used to quantize the tensor - """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - self.int_data = int_data - self.s_group = s_group - self.s_channel = s_channel - self._layout = _layout - self.original_shape = original_shape - self.group_size = group_size - self.num_bits = num_bits - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "s_group", "s_channel"], [ - self._layout, - self.original_shape, - self.group_size, - self.num_bits, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - s_group = tensor_data_dict["s_group"] - s_channel = tensor_data_dict["s_channel"] - _layout, original_shape, group_size, num_bits = tensor_attributes - return cls( - int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits - ) - - def get_plain(self): - from torchao.quantization.marlin_qqq import ( - unpack_from_marlin_qqq, - ) # avoid circular import - - int_data_expanded, s_group_expanded, s_channel_expanded = ( - unpack_from_marlin_qqq( - self.int_data, - self.s_group, - self.s_channel, - self.original_shape, - self.num_bits, - self.group_size, - ) - ) - int_data_expanded_t = int_data_expanded.t() - s_group_expanded_t = s_group_expanded.t() - s_channel_expanded_t = s_channel_expanded.t() - return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - ): - from torchao.quantization.marlin_qqq import ( - const, - pack_to_marlin_qqq, - ) # avoid circular import - - assert isinstance(_layout, MarlinQQQLayout) - - # Linear layers are (in_features, out_features) but the int_data that is reaching this point - # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. - q_w = int_data.t() - s_group_t = s_group.t() - s_channel_t = s_channel.t() - - if not torch.cuda.get_device_capability()[0] >= 8: - raise ValueError( - f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." - ) - - if q_w.dtype != torch.int32: - raise ValueError("Only `torch.int32` weights are supported.") - - in_features, out_features = q_w.shape - # (thread_k, thread_n) - thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] - if not any( - [ - in_features % thread_k == 0 and out_features % thread_n == 0 - for thread_k, thread_n in thread_config - ] - ): - raise ValueError( - "Not supported `in_features`: {} and `out_features`: {}.".format( - in_features, out_features - ) - ) - - num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 - if num_bits not in [4]: - raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") - - if s_group.numel() == 0: - group_size = -1 - else: - group_size = in_features // s_group_t.shape[0] - assert ( - group_size <= in_features - ), "Group size must be less than or equal to in_features." - - if group_size not in const.SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." - ) - - # Compress quantized weight to marlin format - marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( - q_w, s_group_t, s_channel_t, num_bits, group_size - ) - - return cls( - marlin_qqq_q_w, - marlin_qqq_s_group, - marlin_qqq_s_channel, - _layout, - q_w.shape, - group_size, - num_bits, - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.s_group = fn(self.s_group) - self.s_channel = fn(self.s_channel) - return self - - -@register_layout(Float8Layout) -class Float8AQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for float8 layout affine quantized tensor - - Note: technically we should not create a new layout for float8 we should merge this into - plain layout - """ - - float8_data: torch.Tensor - scale: torch.Tensor - transposed: bool - - def __new__( - cls, - float8_data: torch.Tensor, - scale: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = float8_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout - ) - kwargs["dtype"] = float8_data.dtype - kwargs["requires_grad"] = False - shape = float8_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - float8_data: torch.Tensor, - scale: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.float8_data = float8_data - self.scale = scale - self.transposed = transposed - self._layout = _layout - - def _apply_fn_to_data(self, fn): - """Applys a fn to all tensor components stored on this class""" - return self.__class__( - fn(self.float8_data), - fn(self.scale), - self.transposed, - self._layout, - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.float8_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.transposed, - self._layout, - ) - - def __tensor_flatten__(self): - return ["float8_data", "scale"], [self.transposed, self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - ( - transposed, - _layout, - ) = tensor_attributes - return cls(float8_data, scale, transposed, _layout) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - elif func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - elif func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - args[0].transposed = not args[0].transposed - return return_and_correct_aliasing(func, args, kwargs, args[0]) - elif func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - # TODO: scale replecation should be dependent on block size - if self.scale.ndim == 1: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: aten.slice.Tensor(x, dim, start, end, step) - ), - ) - elif self.scale.ndim == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - Float8AQTTensorImpl( - aten.slice.Tensor(self.float8_data, dim, start, end, step), - self.scale, - None, - self._layout, - ), - ) - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" - ) - elif dim == 1: - return return_and_correct_aliasing( - func, - args, - kwargs, - Float8AQTTensorImpl( - aten.slice.Tensor( - self.float8_data, dim, start, end, step - ).contiguous(), - self.scale, - None, - self._layout, - ), - ) - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return self.float8_data, self.scale, None - - def get_layout(self) -> Layout: - return self._layout - - @classmethod - def from_plain( - cls, - data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - """Main entrypoint for constructing Float8TensorImpl""" - assert _is_float8_type( - data.dtype - ), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance( - _layout, Float8Layout - ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" - return cls(data, scale, False, _layout) - - def __repr__(self): - float8_data, scale, _ = self.get_plain() - _layout = self.get_layout() - return ( - f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"_layout={_layout})" - ) - - -@register_layout(TensorCoreTiledLayout) -class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, - used by tinygemm kernels `_weight_int4pack_mm` - - It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of - dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] - (unpacked Tensor shape is n * k) - where inner_k_tiles is an internal argument for packing function of tensor core tiled layout - that can affect the performance of the matmul kernel (defaults to 8) - - Note: we also pack scale and zero point together here for tinygemm kernel - - Note: technically tensor core tiled layout should be the layout for the underlying packed weight - (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used - in plain layout, we just created a layout for AQT right now, this could be improved if we split out - int4 aqt into a separate tensor subclass - - fields: - packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout - scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor - """ - - def __new__( - cls, - packed_weight: torch.Tensor, - scale_and_zero: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scale_and_zero: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scale_and_zero = scale_and_zero - self.transposed = False - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scale_and_zero = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scale_and_zero"], - ) - ( - transposed, - _layout, - ) = tensor_attributes - return cls(packed_weight, scale_and_zero, transposed, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, TensorCoreTiledLayout) - - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert ( - int_data.dtype == torch.uint8 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" - else: - assert ( - int_data.dtype == torch.int32 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, _layout.inner_k_tiles - ) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) - - scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) - return cls(packed_weight, scale_and_zero, False, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] - # tensor core tiled layout supports both cpu and cuda but does not support the conversion - # between these two devices, in the future we should not use the same layout for - # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 - if not is_device(torch.device(self.device).type, device): - raise ValueError( - f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}" - ) - return self.__class__( - self.packed_weight.to(device), - self.scale_and_zero.to(device), - self.transposed, - self._layout, - ) - - def _apply_fn_to_data(self, fn): - # self.packed_weight = fn(self.packed_weight) - # self.scale_and_zero = fn(self.scale_and_zero) - # return self - return self.__class__( - fn(self.packed_weight), - fn(self.scale_and_zero), - self.transposed, - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = TensorCoreTiledAQTTensorImpl( - args[0].packed_weight, - args[0].scale_and_zero, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - - if func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - int_data, scale, zero_point = self.get_plain() - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return return_and_correct_aliasing(func, args, kwargs, sliced) - elif dim == 1: - int_data, scale, zero_point = self.get_plain() - assert step == 1, "Only step == 1 is supported in slicing right now" - data_len = int_data.shape[dim] - scale_len = scale.shape[dim] - ratio = data_len / scale_len - start_scale = int(start / ratio) - end_scale = int(end / ratio) - - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) - scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - zero_point = aten.slice.Tensor( - zero_point, dim, start_scale, end_scale, step - ) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return sliced - else: - raise NotImplementedError( - f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - raise NotImplementedError( - f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - from torchao.quantization.quant_primitives import ( - ZeroPointDomain, - quantize_affine, - ) - from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros - - scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) - - cur_shape = self.shape - assert len(cur_shape) == 4 - inner_k_tiles = cur_shape[-1] * 2 - original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) - eye_shape = original_shape[1] - groupsize = int(original_shape[1] / scale.shape[-2]) - block_size = (1, groupsize) - device = self.device - original_dtype = torch.bfloat16 - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - zero_point_domain = ZeroPointDomain.FLOAT - assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm( - torch.eye(eye_shape, device=device, dtype=original_dtype), - self.packed_weight, - groupsize, - self.scale_and_zero, - ) - dequantized = dequantized.t().contiguous() - # TODO: move this to `unpack_tinygemm_scales_and_zeros`? - scale = scale.reshape(scale.shape[:-1]).contiguous() - zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine( - dequantized, - block_size, - scale, - zero, - target_dtype, - quant_min, - quant_max, - zero_point_domain, - ) - return int_data, scale, zero - - def get_layout(self) -> Layout: - return self._layout - - -##################################################### -# torch functional and aten operator implementation # -##################################################### - - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 - and (aqt.quant_min is None or aqt.quant_min == -128) - and (aqt.quant_max is None or aqt.quant_max == 127) - ) - - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.tensor_impl.dtype == torch.int8 - and aqt.quant_min == -127 - and (aqt.quant_max is None or aqt.quant_max == 127) - ) - - -def _aqt_is_tensor_core_tile_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - # TODO: use torch.uint4 - return ( - aqt.tensor_impl.dtype == torch.int32 - and aqt.quant_min == 0 - and aqt.quant_max == 15 - ) - - -implements = AffineQuantizedTensor.implements - -# following are a list of (dispatch_condition, implementation) functions that takes the following args: -# input_tensor: dimension is (M1, M2, ..., in_features) -# weight_tensor: dimension is (out_features, in_features) -# bias: dimension is (out_features,) -# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches - - -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor._layout, PlainLayout) - ) - - -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in int_scaled_matmul - intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul( - tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) - ) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y = y + bias - return y - - -def _linear_int8_act_int8_weight_semi_structured_sparse_check( - input_tensor, weight_tensor, bias -): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor._layout, SemiSparseLayout) - ) - - -def _linear_int8_act_int8_weight_semi_structured_sparse_impl( - input_tensor, weight_tensor, bias -): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8 = weight_tensor.tensor_impl.int_data - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, - tmp.t(), - alpha=w_scales.to(torch.float32), - out_dtype=torch.bfloat16, - ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - output_dtype = input_tensor.dtype - # TODO: waiting for jesse's test/fix - y = y.to(output_dtype).contiguous() - if bias is not None: - y += bias - return y - - -def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor._layout, BlockSparseLayout) - ) - - -def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals = weight_tensor.tensor_impl - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - tmp_t = tmp.t() - - y = torch.ops.blocksparse.int_addmm( - w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1), - ) - y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) - y = y.reshape(*y_shape) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y - - -def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.dtype == torch.bfloat16 - and - # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_tensor_core_tile_uint4(weight_tensor) - and weight_tensor.dtype == torch.bfloat16 - and len(weight_tensor.shape) == 2 - and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT - and isinstance(weight_tensor._layout, TensorCoreTiledLayout) - ) - - -def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert ( - weight_tensor.block_size[0] == 1 - ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - # TODO: check groupsize quantization - # avoid circular dep, TODO: move this to a common util.py - act_mat = input_tensor - # weight is packed from padded (out_features, in_features) weight tensor - # (same dimension requirement as F.linear weight) - packed_weight = weight_tensor.tensor_impl.packed_weight - scale_and_zero = weight_tensor.tensor_impl.scale_and_zero - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape and pad activation - act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) - pad_size = find_multiple(act_mat.shape[-1], 1024) - act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) - - # groupwise int4 quantization - groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - if bias is not None: - y += bias - return y.to(orig_dtype) - - -def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and - # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_int8(weight_tensor) - and len(weight_tensor.shape) == 2 - and len(weight_tensor.block_size) == 2 - and weight_tensor.block_size[0] == 1 - and weight_tensor.block_size[1] == weight_tensor.shape[1] - and weight_tensor.zero_point_domain == ZeroPointDomain.INT - and isinstance(weight_tensor._layout, PlainLayout) - ) - - -def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # TODO: enable cpu and mps efficient path - # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) - - # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale - m = torch.mm( - input_tensor.reshape(-1, input_tensor.shape[-1]), - w_vals_int8_t.to(input_tensor.dtype), - ) - y = m * scale.to(m.dtype) - y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias.to(m.dtype) - return y - - -def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - return ( - # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and input_tensor.dtype in (torch.float16, torch.bfloat16) - and - # weight is floatx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) - and ( - # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) - or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) - or - # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) - or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) - ) - ) - - -def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import _SPLIT_K_MAP - from torchao.ops import quant_llm_linear - - act = input_tensor - weight = weight_tensor - - out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim) - - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - bsize = act_reshaped.shape[0] - splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - out = quant_llm_linear( - weight._layout.ebits, - weight._layout.mbits, - act_reshaped, - weight.tensor_impl.packed_floatx_data, - weight.tensor_impl.scale, - splitK=splitK, - ) - - if bias is not None: - out += bias - - return out.view(*act.shape[:-1], out_dim).to(act.dtype) - - -def _linear_fp8_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, AffineQuantizedTensor], - weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], - bias: Optional[torch.Tensor], -) -> bool: - def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: - return ( - isinstance(aqt, AffineQuantizedTensor) - and isinstance(aqt._layout, Float8Layout) - and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) - ) - - return check_aqt(input_tensor) and check_aqt(weight_tensor) - - -def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): - """Ensures input tensor is correctly formated for _scaled_mm""" - input_scale = input_scale.unsqueeze(-1) - - if input_scale.dim() > 2: - input_scale = input_scale.reshape(-1, input_scale.shape[-1]) - - return input_scale - - -def _linear_fp8_act_fp8_weight_impl( - input_tensor: AffineQuantizedTensor, - weight_tensor: AffineQuantizedTensor, - bias: Optional[torch.Tensor], -): - """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - scaled_mm_config = weight_tensor._layout.mm_config - out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) - - # Weight tensor preprocessing - w_tensor_impl = weight_tensor.tensor_impl - assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" - w_data = w_tensor_impl.float8_data - w_scale = w_tensor_impl.scale - - # Input tensor preprocessing - inpt_data = input_tensor.tensor_impl.float8_data - input_scale = input_tensor.tensor_impl.scale - # Handle case where input tensor is more than 2D - inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) - - # Handle rowwise case - if _is_rowwise_scaled(weight_tensor): - assert _is_rowwise_scaled( - input_tensor - ), "Input tensor must be rowwise block size" - w_scale = w_scale.unsqueeze(-1).T - input_scale = preprocess_scale(input_scale, input_tensor.shape) - - # Preprocess data - inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) - - # Perform the computation - return addmm_float8_unwrapped_inference( - inpt_data, - input_scale, - w_data, - w_scale, - output_dtype=input_tensor.dtype, - bias=bias, - use_fast_accum=scaled_mm_config.use_fast_accum, - ).reshape(out_shape) - - -def _linear_fp_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, AffineQuantizedTensor], - weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], - bias: Optional[torch.Tensor], -) -> bool: - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and - # weight is float8 quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, Float8Layout) - and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and ( - weight_tensor.shape == weight_tensor.block_size - or _is_rowwise_scaled(weight_tensor) - ) - ) - - -def _linear_fp_act_fp8_weight_impl( - input_tensor: torch.Tensor, - weight_tensor: AffineQuantizedTensor, - bias: Optional[torch.Tensor], -): - return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) - - -def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): - return ( - isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_tensor_core_tile_uint4(weight_tensor) - and input_tensor.dtype == torch.float16 - and len(weight_tensor.shape) == 2 - and weight_tensor.zero_point_domain == ZeroPointDomain.INT - and isinstance(weight_tensor._layout, MarlinSparseLayout) - ) - - -def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from torchao.ops import marlin_24_gemm - from torchao.sparsity.marlin import marlin_24_workspace - - assert isinstance(weight_tensor, AffineQuantizedTensor) - - sparse_w_int4 = weight_tensor.tensor_impl.int_data - scale = weight_tensor.tensor_impl.scale - meta = weight_tensor.tensor_impl.meta - original_shape = weight_tensor.tensor_impl.original_shape - num_bits = weight_tensor.tensor_impl.num_bits - - # Folds batch dimension into the first dimension - input_2d = input_tensor.view(-1, input_tensor.shape[-1]) - - size_m = input_2d.shape[0] - size_n = scale.shape[1] - size_k = input_2d.shape[1] - workspace_24 = marlin_24_workspace(original_shape[1]) - - out = marlin_24_gemm( - input_2d, - sparse_w_int4, - meta, - scale, - workspace_24, - num_bits, - size_m, - size_n, - size_k, - ) - - # Unfold the batch dimension - out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) - - if bias is not None: - out += bias.to(out.dtype) - return out - - -def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and input_tensor.dtype == torch.float16 - and input_tensor.tensor_impl.scale.dtype == torch.float32 - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.tensor_impl.dtype == torch.int32 - and len(weight_tensor.shape) == 2 - and isinstance(weight_tensor._layout, MarlinQQQLayout) - ) - - -def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): - from torchao.ops import marlin_qqq_gemm - from torchao.quantization.marlin_qqq import marlin_qqq_workspace - - assert isinstance(input_tensor, AffineQuantizedTensor) - assert isinstance(weight_tensor, AffineQuantizedTensor) - - input = input_tensor.tensor_impl.int_data - input_scale = input_tensor.tensor_impl.scale - - w_int4 = weight_tensor.tensor_impl.int_data - s_group = weight_tensor.tensor_impl.s_group - s_channel = weight_tensor.tensor_impl.s_channel - original_shape = weight_tensor.tensor_impl.original_shape - - # Folds batch dimension into the first dimension - input_2d = input.view(-1, input.shape[-1]) - input_scale = input_scale.view(1, -1) - - size_m = input_2d.shape[0] - size_n = s_channel.shape[1] - size_k = input_2d.shape[1] - workspace_qqq = marlin_qqq_workspace(original_shape[1]) - - out = marlin_qqq_gemm( - input_2d, - w_int4, - input_scale, - s_channel, - s_group, - workspace_qqq, - size_m, - size_n, - size_k, - ) - - # Unfold the batch dimension - out = out.reshape(input.shape[:-1] + (s_channel.shape[1],)) - - if bias is not None: - out += bias.to(out.dtype) - return out - - -def _register_aqt_quantized_linear_dispatches(): - for dispatch_condition, impl in [ - (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), - ( - _linear_int8_act_int8_weight_semi_structured_sparse_check, - _linear_int8_act_int8_weight_semi_structured_sparse_impl, - ), - ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, - ), - (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), - (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), - (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), - (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - ( - _linear_f16_bf16_act_floatx_weight_check, - _linear_f16_bf16_act_floatx_weight_impl, - ), - ( - _linear_fp_act_int4_weight_sparse_marlin_check, - _linear_fp_act_int4_weight_sparse_marlin_impl, - ), - ( - _linear_int8_act_int4_weight_marlin_qqq_check, - _linear_int8_act_int4_weight_marlin_qqq_impl, - ), - ]: - register_aqt_quantized_linear_dispatch(dispatch_condition, impl) - - -_register_aqt_quantized_linear_dispatches() - - -@implements([torch.nn.functional.linear, aten.linear.default]) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - if not input_tensor.is_floating_point(): - raise NotImplementedError( - f"{func} is not implemented for non floating point input" - ) - - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if ( - isinstance(weight_tensor, AffineQuantizedTensor) - and hasattr(weight_tensor._layout, "quantized_linear_impl") - and weight_tensor._layout.quantized_linear_impl is not None - ): - raise e - - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, AffineQuantizedTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - -@implements(torch.nn.functional.embedding) -def _(func, types, args, kwargs): - # new_arg1 = args[1].dequantize() - # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance( - args[1].tensor_impl, PlainAQTTensorImpl - ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert ( - kwargs["padding_idx"] is None - and kwargs["max_norm"] is None - and not kwargs["scale_grad_by_freq"] - and not kwargs["sparse"] - and kwargs["norm_type"] == 2.0 - ) - idx = args[0] - int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = ( - int_data[idx], - scale[idx], - zero_point[idx], - ) - # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so - # we need to increase block size to correct dim - new_blocks = idx.dim() - 1 - return dequantize_affine( - sliced_data, - new_blocks * [1] + list(args[1].block_size), - sliced_scale, - sliced_zero_point, - sliced_data.dtype, - args[1].quant_min, - args[1].quant_max, - args[1].zero_point_domain, - output_dtype=sliced_scale.dtype, - ) - - -@implements(aten.addmm.default) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[1], - args[2], - args[0], - ) - if not input_tensor.is_floating_point(): - raise NotImplementedError( - f"{func} is not implemented for non floating point input" - ) - - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - weight_tensor = weight_tensor.t() - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if ( - isinstance(weight_tensor, AffineQuantizedTensor) - and hasattr(weight_tensor._layout, "quantized_linear_impl") - and weight_tensor._layout.quantized_linear_impl is not None - ): - raise e - - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, AffineQuantizedTensor): - weight_tensor = weight_tensor.dequantize() - return func(bias, input_tensor, weight_tensor) - - -@implements(aten.mm.default) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = (args[0], args[1], None) - if not input_tensor.is_floating_point(): - raise NotImplementedError( - f"{func} is not implemented for non floating point input" - ) - - try: - weight_tensor = weight_tensor.t() - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if ( - isinstance(weight_tensor, AffineQuantizedTensor) - and hasattr(weight_tensor._layout, "quantized_linear_impl") - and weight_tensor._layout.quantized_linear_impl is not None - ): - raise e - - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, AffineQuantizedTensor): - weight_tensor = weight_tensor.dequantize() - return func(input_tensor, weight_tensor) - - -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -@implements(aten._to_copy.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - - -@implements(aten.t.default) -def _(func, types, args, kwargs): - block_size = args[0].block_size - assert len(block_size) == 2 - transposed_block_size = (block_size[1], block_size[0]) - tensor = args[0] - shape = tensor.shape[::-1] - new = tensor.__class__( - tensor.tensor_impl.t(), - transposed_block_size, - shape, - tensor.quant_min, - tensor.quant_max, - tensor.zero_point_domain, - dtype=tensor.dtype, - strides=tensor.stride(), - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1 - assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" - if end >= self.shape[dim]: - end = self.shape[dim] - shape = list(self.shape) - shape[dim] = end - start - block_size = self.block_size - assert ( - len(block_size) == 2 - ), f"Slice only works for 2d block_size right now, got: {block_size}" - # with slice, some shape dimension might be smaller than block_size dimension, so - # we need to make sure there is no overflow - block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__( - aten.slice.Tensor(self.tensor_impl, dim, start, end, step), - block_size, - shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - dtype=self.dtype, - strides=self.stride(), - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - -# this is needed for DTensor.from_local() and for flattening tensor -@implements(aten.view.default) -def _(func, types, args, kwargs): - self, shape = args - - if tuple(self.shape) == tuple(shape): - return self.__class__( - self.tensor_impl, - self.block_size, - self.shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - dtype=self.dtype, - strides=self.stride(), - ) - - if len(shape) == 1 and shape[0] == -1: - assert len(self.block_size) == 2 and self.block_size[0] == 1 - block_size = (self.block_size[1],) - return self.__class__( - self.tensor_impl, - block_size, - (self.numel(),), - self.quant_min, - self.quant_max, - self.zero_point_domain, - dtype=self.dtype, - strides=self.stride(), - ) - - raise ValueError( - f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]" - ) - - to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py new file mode 100644 index 0000000000..bd7ff7d333 --- /dev/null +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -0,0 +1,386 @@ +import logging + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, +) +from torchao.dtypes.floatx.float8_layout import ( + _linear_fp8_act_fp8_weight_check, + _linear_fp8_act_fp8_weight_impl, + _linear_fp_act_fp8_weight_check, + _linear_fp_act_fp8_weight_impl, +) +from torchao.dtypes.floatx.floatx_tensor_core_layout import ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, +) +from torchao.dtypes.uintx.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) +from torchao.dtypes.uintx.marlin_qqq_layout import ( + _linear_int8_act_int4_weight_marlin_qqq_check, + _linear_int8_act_int4_weight_marlin_qqq_impl, +) +from torchao.dtypes.uintx.marlin_sparse_layout import ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, +) +from torchao.dtypes.uintx.plain_layout import ( + PlainAQTTensorImpl, + _linear_fp_act_int8_weight_check, + _linear_fp_act_int8_weight_impl, + _linear_int8_act_int8_weight_check, + _linear_int8_act_int8_weight_impl, +) +from torchao.dtypes.uintx.semi_sparse_layout import ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, +) +from torchao.dtypes.uintx.tensor_core_tiled_layout import ( + _linear_bf16_act_uint4_weight_check, + _linear_bf16_act_uint4_weight_impl, +) +from torchao.quantization.quant_primitives import dequantize_affine +from torchao.utils import ( + fill_defaults, +) + +logger = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +_AQT_QLINEAR_DISPATCH_TABLE = {} + + +def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): + """Register a dispatch for quantized linear op with dispatch_condition function and impl function + both takes three arguments: + input_tensor: dimension is (M1, M2, ..., in_features) + weight_tensor: dimension is (out_features, in_features) + bias: dimension is (out_features,) + so that these can be shared by F.linear, aten.mm, aten.addmm dispatches + + Args: + `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]: the dispatch + condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight + `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: the specialized + quantized linear implementation + """ + _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl + + +def deregister_aqt_quantized_linear_dispatch(dispatch_condition): + if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: + del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] + else: + logger.warn( + f"Attempting to remove non-existant dispatch condition {dispatch_condition}" + ) + + +class QuantizedLinearNotImplementedError(NotImplementedError): + """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" + + pass + + +@staticmethod +def _quantized_linear_op(input_tensor, weight_tensor, bias): + for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): + if dispatch_condition(input_tensor, weight_tensor, bias): + return impl(input_tensor, weight_tensor, bias) + raise QuantizedLinearNotImplementedError( + "No specialized dispatch found for quantized linear op" + ) + + +# Attach the _quantized_linear_op to the AffineQuantizedTensor class +AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op + + +# _register_aqt_quantized_linear_dispatches function has a list of (dispatch_condition, implementation) functions, defined in their dtype layout classes, that takes the following args: +# input_tensor: dimension is (M1, M2, ..., in_features) +# weight_tensor: dimension is (out_features, in_features) +# bias: dimension is (out_features,) +# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches +def _register_aqt_quantized_linear_dispatches(): + for dispatch_condition, impl in [ + (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), + ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, + ), + ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, + ), + (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), + (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), + (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), + (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), + ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, + ), + ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, + ), + ( + _linear_int8_act_int4_weight_marlin_qqq_check, + _linear_int8_act_int4_weight_marlin_qqq_impl, + ), + ]: + register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + + +_register_aqt_quantized_linear_dispatches() + +implements = AffineQuantizedTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): + raise e + + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(torch.nn.functional.embedding) +def _(func, types, args, kwargs): + # new_arg1 = args[1].dequantize() + # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) + assert isinstance( + args[1].tensor_impl, PlainAQTTensorImpl + ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert ( + kwargs["padding_idx"] is None + and kwargs["max_norm"] is None + and not kwargs["scale_grad_by_freq"] + and not kwargs["sparse"] + and kwargs["norm_type"] == 2.0 + ) + idx = args[0] + int_data, scale, zero_point = args[1].tensor_impl.get_plain() + + sliced_data, sliced_scale, sliced_zero_point = ( + int_data[idx], + scale[idx], + zero_point[idx], + ) + # Block size is expecting 2 dimensions [1, group size] but + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # we need to increase block size to correct dim + new_blocks = idx.dim() - 1 + return dequantize_affine( + sliced_data, + new_blocks * [1] + list(args[1].block_size), + sliced_scale, + sliced_zero_point, + sliced_data.dtype, + args[1].quant_min, + args[1].quant_max, + args[1].zero_point_domain, + output_dtype=sliced_scale.dtype, + ) + + +@implements(aten.addmm.default) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[1], + args[2], + args[0], + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + weight_tensor = weight_tensor.t() + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): + raise e + + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(bias, input_tensor, weight_tensor) + + +@implements(aten.mm.default) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = (args[0], args[1], None) + if not input_tensor.is_floating_point(): + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) + + try: + weight_tensor = weight_tensor.t() + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): + raise e + + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(input_tensor, weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + block_size = args[0].block_size + assert len(block_size) == 2 + transposed_block_size = (block_size[1], block_size[0]) + tensor = args[0] + shape = tensor.shape[::-1] + new = tensor.__class__( + tensor.tensor_impl.t(), + transposed_block_size, + shape, + tensor.quant_min, + tensor.quant_max, + tensor.zero_point_domain, + dtype=tensor.dtype, + strides=tensor.stride(), + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + shape = list(self.shape) + shape[dim] = end - start + block_size = self.block_size + assert ( + len(block_size) == 2 + ), f"Slice only works for 2d block_size right now, got: {block_size}" + # with slice, some shape dimension might be smaller than block_size dimension, so + # we need to make sure there is no overflow + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__( + aten.slice.Tensor(self.tensor_impl, dim, start, end, step), + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +# this is needed for DTensor.from_local() and for flattening tensor +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, shape = args + + if tuple(self.shape) == tuple(shape): + return self.__class__( + self.tensor_impl, + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + + if len(shape) == 1 and shape[0] == -1: + assert len(self.block_size) == 2 and self.block_size[0] == 1 + block_size = (self.block_size[1],) + return self.__class__( + self.tensor_impl, + block_size, + (self.numel(),), + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + + raise ValueError( + f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]" + ) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 6ff0a903d2..3f0a1ccd5c 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,6 +1,5 @@ -from .floatx import ( - _SPLIT_K_MAP, - FloatxTensorCoreAQTTensorImpl, +from .float8_layout import Float8Layout +from .floatx_tensor_core_layout import ( FloatxTensorCoreLayout, from_scaled_tc_floatx, to_scaled_tc_floatx, @@ -8,8 +7,7 @@ __all__ = [ "FloatxTensorCoreLayout", - "FloatxTensorCoreAQTTensorImpl", "to_scaled_tc_floatx", "from_scaled_tc_floatx", - "_SPLIT_K_MAP", + "Float8Layout", ] diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py new file mode 100644 index 0000000000..dd995fb157 --- /dev/null +++ b/torchao/dtypes/floatx/float8_layout.py @@ -0,0 +1,313 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape +from torchao.float8.inference import ( + Float8MMConfig, + _is_rowwise_scaled, + addmm_float8_unwrapped_inference, + preprocess_data, +) +from torchao.utils import _is_float8_type, fill_defaults + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Float8Layout(Layout): + mm_config: Optional[Float8MMConfig] = None + + +@register_layout(Float8Layout) +class Float8AQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for float8 layout affine quantized tensor + + Note: technically we should not create a new layout for float8 we should merge this into + plain layout + """ + + float8_data: torch.Tensor + scale: torch.Tensor + transposed: bool + + def __new__( + cls, + float8_data: torch.Tensor, + scale: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = float8_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout + ) + kwargs["dtype"] = float8_data.dtype + kwargs["requires_grad"] = False + shape = float8_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + float8_data: torch.Tensor, + scale: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.float8_data = float8_data + self.scale = scale + self.transposed = transposed + self._layout = _layout + + def _apply_fn_to_data(self, fn): + """Applys a fn to all tensor components stored on this class""" + return self.__class__( + fn(self.float8_data), + fn(self.scale), + self.transposed, + self._layout, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.float8_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.transposed, + self._layout, + ) + + def __tensor_flatten__(self): + return ["float8_data", "scale"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] + ( + transposed, + _layout, + ) = tensor_attributes + return cls(float8_data, scale, transposed, _layout) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + elif func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + args[0].transposed = not args[0].transposed + return return_and_correct_aliasing(func, args, kwargs, args[0]) + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + # TODO: scale replecation should be dependent on block size + if self.scale.ndim == 1: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), + ) + elif self.scale.ndim == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor(self.float8_data, dim, start, end, step), + self.scale, + None, + self._layout, + ), + ) + else: + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" + ) + elif dim == 1: + return return_and_correct_aliasing( + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor( + self.float8_data, dim, start, end, step + ).contiguous(), + self.scale, + None, + self._layout, + ), + ) + else: + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + else: + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.float8_data, self.scale, None + + def get_layout(self) -> Layout: + return self._layout + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + """Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type( + data.dtype + ), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance( + _layout, Float8Layout + ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" + return cls(data, scale, False, _layout) + + def __repr__(self): + float8_data, scale, _ = self.get_plain() + _layout = self.get_layout() + return ( + f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"_layout={_layout})" + ) + + +########################## +# Float8 Dispatch Kernels +########################## + + +def _linear_fp8_act_fp8_weight_check( + input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + bias: Optional[torch.Tensor], +) -> bool: + def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: + return ( + isinstance(aqt, AffineQuantizedTensor) + and isinstance(aqt._layout, Float8Layout) + and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) + ) + + return check_aqt(input_tensor) and check_aqt(weight_tensor) + + +def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): + """Ensures input tensor is correctly formated for _scaled_mm""" + input_scale = input_scale.unsqueeze(-1) + + if input_scale.dim() > 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + return input_scale + + +def _linear_fp8_act_fp8_weight_impl( + input_tensor: "AffineQuantizedTensor", + weight_tensor: "AffineQuantizedTensor", + bias: Optional[torch.Tensor], +): + """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" + scaled_mm_config = weight_tensor._layout.mm_config + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + # Weight tensor preprocessing + w_tensor_impl = weight_tensor.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = input_tensor.tensor_impl.float8_data + input_scale = input_tensor.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled( + input_tensor + ), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, input_tensor.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + + +def _linear_fp_act_fp8_weight_check( + input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + bias: Optional[torch.Tensor], +) -> bool: + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and + # weight is float8 quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, Float8Layout) + and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and ( + weight_tensor.shape == weight_tensor.block_size + or _is_rowwise_scaled(weight_tensor) + ) + ) + + +def _linear_fp_act_fp8_weight_impl( + input_tensor: torch.Tensor, + weight_tensor: "AffineQuantizedTensor", + bias: Optional[torch.Tensor], +): + return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py similarity index 89% rename from torchao/dtypes/floatx/floatx.py rename to torchao/dtypes/floatx/floatx_tensor_core_layout.py index 6f99ab11d0..0f67e9826e 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -4,10 +4,17 @@ import torch from torch import Tensor -from torch.utils._python_dispatch import return_and_correct_aliasing +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) -from torchao.dtypes.affine_quantized_tensor import AQTTensorImpl, register_layout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) from torchao.dtypes.utils import ( + AQTTensorImpl, Layout, ) from torchao.prototype.custom_fp_utils import ( @@ -441,8 +448,6 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> # quantization api integrations - - @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): """Layout type for FloatxTensorCoreAQTTensorImpl""" @@ -600,3 +605,55 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def get_layout(self) -> Layout: return self._layout + + +def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + + return ( + # input is native float32 tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and + # weight is floatx Tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) + and ( + # weight is using fp6 quantization + (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) + or + # weight is using fp5 quantization + (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) + ) + ) + + +def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): + from torchao.ops import quant_llm_linear + + act = input_tensor + weight = weight_tensor + + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim) + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + out = quant_llm_linear( + weight._layout.ebits, + weight._layout.mbits, + act_reshaped, + weight.tensor_impl.packed_floatx_data, + weight.tensor_impl.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + return out.view(*act.shape[:-1], out_dim).to(act.dtype) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 3771f9d4ba..14a8c2d43e 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -980,35 +980,44 @@ def decorator(func): @implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): tensor = args[0] - if isinstance(args[1], torch.dtype): - # Tensor.to(dtype, non_blocking, copy, memory_format) - return tensor.get_original_weight().to(*args[1:], **kwargs) - elif ( - isinstance(args[1], torch.device) - or ( - isinstance(args[1], str) - and (args[1] == "cpu" or args[1].startswith("cuda")) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args[1:], **kwargs + ) + + # dtype is specified -> dequantize + if dtype is not None: + return tensor.get_original_weight().to( + device, dtype, non_blocking, memory_format=convert_to_format ) - ) and len(args) == 2: - # Tensor.to(device, non_blocking) - device = args[1] - updated_attrs = call_from_inner_tensors(tensor, "to", args[1:], kwargs) - updated_attrs["device"] = device - return NF4Tensor(*construct_nf4_args(tensor, updated_attrs)) - else: - # Tensor.to(device, dtype, non_blocking, copy, memory_format) - # Tensor.to(other, non_blocking, copy) - raise NotImplementedError( - f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch" + + # dtype is not specified -> keep NF4 + updated_attrs = dict(device=device) + tensor_attrs, _ = tensor.__tensor_flatten__() + for attr in tensor_attrs: + inner_tensor = getattr(tensor, attr) + updated_attrs[attr] = inner_tensor.to( + device, dtype, non_blocking, memory_format=convert_to_format ) + return NF4Tensor(*construct_nf4_args(tensor, updated_attrs)) @implements_torch_function(torch.Tensor.cpu) def function_cpu(*args, **kwargs): - nf4tensor = args[0] - updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs) - updated_attrs["device"] = "cpu" - return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + # Tensor.cpu(self, memory_format) + return args[0].to("cpu", *args[1:], **kwargs) + + +@implements_torch_function(torch.Tensor.cuda) +def function_cuda(*args, **kwargs): + # Tensor.cuda(self, device, non_blocking, memory_format) + tensor = args[0] + updated_attrs = dict() + tensor_attrs, _ = tensor.__tensor_flatten__() + for attr in tensor_attrs: + inner_tensor = getattr(tensor, attr) + updated_attrs[attr] = inner_tensor.cuda(*args[1:], **kwargs) + updated_attrs["device"] = updated_attrs[tensor_attrs[0]].device + return NF4Tensor(*construct_nf4_args(tensor, updated_attrs)) @implements_torch_function(F.linear) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index b52b37f5e6..8fba2bb678 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,15 +1,29 @@ -from .uintx import ( - _DTYPE_TO_BIT_WIDTH, - UintxAQTTensorImpl, +from .block_sparse_layout import ( + BlockSparseLayout, +) +from .marlin_qqq_layout import ( + MarlinQQQLayout, +) +from .marlin_sparse_layout import ( + MarlinSparseLayout, +) +from .semi_sparse_layout import ( + SemiSparseLayout, +) +from .tensor_core_tiled_layout import ( + Int4CPULayout, + TensorCoreTiledLayout, +) +from .uintx_layout import ( UintxLayout, - UintxTensor, - to_uintx, ) __all__ = [ - "UintxTensor", "UintxLayout", - "UintxAQTTensorImpl", - "to_uintx", - "_DTYPE_TO_BIT_WIDTH", + "BlockSparseLayout", + "MarlinSparseLayout", + "SemiSparseLayout", + "TensorCoreTiledLayout", + "Int4CPULayout", + "MarlinQQQLayout", ] diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py new file mode 100644 index 0000000000..0670986b13 --- /dev/null +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -0,0 +1,222 @@ +import logging +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + PlainAQTTensorImpl, + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import ( + Layout, + PlainLayout, +) + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class BlockSparseLayout(Layout): + blocksize: int = 64 + + +@register_layout(BlockSparseLayout) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( # noqa: PYI034 + self, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + self.bsr_crow_indices = bsr_crow_indices + self.bsr_col_indices = bsr_col_indices + self.bsr_values = bsr_values + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self._layout, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, _layout, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + scale=inner_tensors.get("scale", None), + zero_point=inner_tensors.get("zero_point", None), + _layout=_layout, + requires_grad=requires_grad, + ) + + @classmethod + def from_plain(cls, int_data, scale, zero_point, _layout): + bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) + return cls( + shape=int_data.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + scale=scale, + zero_point=zero_point, + _layout=_layout, + requires_grad=False, + ) + + def get_plain(self): + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) + return int_data_expanded, self.scale, self.zero_point + + def _apply_fn_to_data(self, func): + return self.__class__( + shape=self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + _layout=self._layout, + requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/marlin_qqq_layout.py b/torchao/dtypes/uintx/marlin_qqq_layout.py new file mode 100644 index 0000000000..c3b2a78394 --- /dev/null +++ b/torchao/dtypes/uintx/marlin_qqq_layout.py @@ -0,0 +1,281 @@ +import logging +from dataclasses import dataclass + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class MarlinQQQLayout(Layout): + pass + + +@register_layout(MarlinQQQLayout) +class MarlinQQQAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl storage class for sparse_qqq layout for affine quantized tensor. + + Can only be used with 4 bits quantization for now. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Marlin qqq information: + https://github.com/HandH1998/QQQ/tree/main + https://arxiv.org/pdf/2406.09904 + + fields: + original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.s_group = s_group + self.s_channel = s_channel + self._layout = _layout + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "s_group", "s_channel"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + s_group = tensor_data_dict["s_group"] + s_channel = tensor_data_dict["s_channel"] + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls( + int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits + ) + + def get_plain(self): + from torchao.quantization.marlin_qqq import ( + unpack_from_marlin_qqq, + ) # avoid circular import + + int_data_expanded, s_group_expanded, s_channel_expanded = ( + unpack_from_marlin_qqq( + self.int_data, + self.s_group, + self.s_channel, + self.original_shape, + self.num_bits, + self.group_size, + ) + ) + int_data_expanded_t = int_data_expanded.t() + s_group_expanded_t = s_group_expanded.t() + s_channel_expanded_t = s_channel_expanded.t() + return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + ): + from torchao.quantization.marlin_qqq import ( + const, + pack_to_marlin_qqq, + ) # avoid circular import + + assert isinstance(_layout, MarlinQQQLayout) + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w = int_data.t() + s_group_t = s_group.t() + s_channel_t = s_channel.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + ) + + if q_w.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w.shape + # (thread_k, thread_n) + thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] + if not any( + [ + in_features % thread_k == 0 and out_features % thread_n == 0 + for thread_k, thread_n in thread_config + ] + ): + raise ValueError( + "Not supported `in_features`: {} and `out_features`: {}.".format( + in_features, out_features + ) + ) + + num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + if s_group.numel() == 0: + group_size = -1 + else: + group_size = in_features // s_group_t.shape[0] + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin format + marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( + q_w, s_group_t, s_channel_t, num_bits, group_size + ) + + return cls( + marlin_qqq_q_w, + marlin_qqq_s_group, + marlin_qqq_s_channel, + _layout, + q_w.shape, + group_size, + num_bits, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.s_group = fn(self.s_group) + self.s_channel = fn(self.s_channel) + return self + + +def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and input_tensor.dtype == torch.float16 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.tensor_impl.dtype == torch.int32 + and len(weight_tensor.shape) == 2 + and isinstance(weight_tensor._layout, MarlinQQQLayout) + ) + + +def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): + from torchao.ops import marlin_qqq_gemm + from torchao.quantization.marlin_qqq import marlin_qqq_workspace + + assert isinstance(input_tensor, AffineQuantizedTensor) + assert isinstance(weight_tensor, AffineQuantizedTensor) + + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + w_int4 = weight_tensor.tensor_impl.int_data + s_group = weight_tensor.tensor_impl.s_group + s_channel = weight_tensor.tensor_impl.s_channel + original_shape = weight_tensor.tensor_impl.original_shape + + # Folds batch dimension into the first dimension + input_2d = input.view(-1, input.shape[-1]) + input_scale = input_scale.view(1, -1) + + size_m = input_2d.shape[0] + size_n = s_channel.shape[1] + size_k = input_2d.shape[1] + workspace_qqq = marlin_qqq_workspace(original_shape[1]) + + out = marlin_qqq_gemm( + input_2d, + w_int4, + input_scale, + s_channel, + s_group, + workspace_qqq, + size_m, + size_n, + size_k, + ) + + # Unfold the batch dimension + out = out.reshape(input.shape[:-1] + (s_channel.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py new file mode 100644 index 0000000000..e37623182a --- /dev/null +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -0,0 +1,289 @@ +from dataclasses import dataclass + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.tensor_core_tiled_layout import _aqt_is_tensor_core_tile_uint4 +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, +) + +aten = torch.ops.aten + + +def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): + return ( + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and input_tensor.dtype == torch.float16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, MarlinSparseLayout) + ) + + +def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): + from torchao.ops import marlin_24_gemm + from torchao.sparsity.marlin import marlin_24_workspace + + assert isinstance(weight_tensor, AffineQuantizedTensor) + + sparse_w_int4 = weight_tensor.tensor_impl.int_data + scale = weight_tensor.tensor_impl.scale + meta = weight_tensor.tensor_impl.meta + original_shape = weight_tensor.tensor_impl.original_shape + num_bits = weight_tensor.tensor_impl.num_bits + + # Folds batch dimension into the first dimension + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) + + size_m = input_2d.shape[0] + size_n = scale.shape[1] + size_k = input_2d.shape[1] + workspace_24 = marlin_24_workspace(original_shape[1]) + + out = marlin_24_gemm( + input_2d, + sparse_w_int4, + meta, + scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, + ) + + # Unfold the batch dimension + out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out + + +@dataclass(frozen=True) +class MarlinSparseLayout(Layout): + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. + - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format + - 2º: tensor is injected with 2:4 sparsity + - 3º: transposes it again because the quantization process will compute the scales for dim=-1 + + Args: + input (torch.Tensor): the input tensor to preprocess + + Returns: + torch.Tensor: the preprocessed tensor + """ + from torchao.sparsity.marlin import inject_24 # avoid circular import + + input_t = input.t() + w_24, _ = inject_24(input_t, *input_t.shape) + return w_24.t() + + +@register_layout(MarlinSparseLayout) +class MarlinSparseAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for sparse_marlin_24 layout for affine quantized tensor. + + Can be used with 4 bits and 8 bits quantization. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Sparse marlin documentation and information: + https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file + + fields: + original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + meta: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + meta: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self.meta = meta + self._layout = _layout + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point", "meta"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + zero_point = tensor_data_dict["zero_point"] + meta = tensor_data_dict["meta"] + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls( + int_data, + scale, + zero_point, + meta, + _layout, + original_shape, + group_size, + num_bits, + ) + + def get_plain(self): + from torchao.sparsity.marlin import ( + unpack_from_marlin_24, + ) # avoid circular import + + int_data_expanded, scales_expanded = unpack_from_marlin_24( + self.int_data, + self.scale, + self.meta, + self.original_shape, + self.group_size, + self.num_bits, + ) + int_data_expanded_t = int_data_expanded.t() + scales_expanded_t = scales_expanded.t() + return int_data_expanded_t, scales_expanded_t, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + from torchao.sparsity.marlin import ( + const, + pack_to_marlin_24, + ) # avoid circular import + + assert isinstance(_layout, MarlinSparseLayout) + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w_24 = int_data.t() + scale_t = scale.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + ) + + if q_w_24.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w_24.shape + if in_features % 128 != 0 or out_features != 256 == 0: + raise ValueError( + "`in_features` must be divisible by 64 and `out_features` by 256." + ) + + # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8 + # will require a bit more work to get our current quantization flow to work with it. + # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main + num_bits = 4 if torch.max(q_w_24) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + group_size = in_features // scale_t.shape[0] + if group_size == 0: + group_size = in_features + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin 2:4 format + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) + + return cls( + marlin_24_q_w_comp, + marlin_24_s, + zero_point, + meta, + _layout, + q_w_24.shape, + group_size, + num_bits, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.scale = fn(self.scale) + self.zero_point = fn(self.zero_point) + self.meta = fn(self.meta) + return self diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py new file mode 100644 index 0000000000..ed171634cd --- /dev/null +++ b/torchao/dtypes/uintx/plain_layout.py @@ -0,0 +1,268 @@ +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout +from torchao.kernel import ( + int_scaled_matmul, +) +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, +) +from torchao.utils import fill_defaults + +aten = torch.ops.aten + + +@register_layout(PlainLayout) +class PlainAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + tensors directly as plain tensors. + + fields: + int_data (torch.Tensor): the quantized integer data Tensor + scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor + zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data, scale, zero_point = ( + tensor_data_dict["int_data"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + (_layout,) = tensor_attributes + return cls(int_data, scale, zero_point, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]), + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.scale), + fn(self.zero_point), + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + elif func is aten.t.default: + tensor = args[0] + new = tensor.__class__( + tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), + ) + elif dim == 1: + assert ( + len(self.scale.shape) == 1 + ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTTensorImpl( + aten.slice.Tensor(self.int_data, dim, start, end, step), + self.scale.view(-1), + self.zero_point.view(-1), + self._layout, + ) + else: + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data, self.scale, self.zero_point + + def get_layout(self) -> Layout: + return self._layout + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, PlainLayout) + return cls(int_data, scale, zero_point, _layout) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and (aqt.quant_min is None or aqt.quant_min == -128) + and (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and + # weight is int8 per channel quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int8(weight_tensor) + and len(weight_tensor.shape) == 2 + and len(weight_tensor.block_size) == 2 + and weight_tensor.block_size[0] == 1 + and weight_tensor.block_size[1] == weight_tensor.shape[1] + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # TODO: enable cpu and mps efficient path + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) + + # per channel int8 weight only quantizated mm + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale + m = torch.mm( + input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + y = m * scale.to(m.dtype) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias.to(m.dtype) + return y + + +def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast fp16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py new file mode 100644 index 0000000000..d832731657 --- /dev/null +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + PlainAQTTensorImpl, + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import Layout, PlainLayout + +aten = torch.ops.aten + + +def _linear_int8_act_int8_weight_semi_structured_sparse_check( + input_tensor, weight_tensor, bias +): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, SemiSparseLayout) + ) + + +def _linear_int8_act_int8_weight_semi_structured_sparse_impl( + input_tensor, weight_tensor, bias +): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8 = weight_tensor.tensor_impl.int_data + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # must pad + row, col = tmp.shape + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) + # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( + w_vals_int8, + tmp_padded.t(), + alpha=w_scales.to(torch.float32), + out_dtype=torch.bfloat16, + ).t()[:row, :] + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + output_dtype = input_tensor.dtype + # TODO: waiting for jesse's test/fix + y = y.to(output_dtype).contiguous() + if bias is not None: + y += bias + return y + + +@dataclass(frozen=True) +class SemiSparseLayout(Layout): + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + temp = input.detach() + pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] + temp.view(-1, 4).scatter_(1, pruning_inds, value=0) + return temp + + +@register_layout(SemiSparseLayout) +class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): + """ + TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor + """ + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def get_plain(self): + # Currently we don't have cuSPARSELt expansion routines, so we matmul by + # the identity matrix to get the original dense matrix. This is slow though. + cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) + int_data_expanded = torch._cslt_sparse_mm( + self.int_data, + torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), + ) + return int_data_expanded, self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, SemiSparseLayout) + int_data_compressed = torch._cslt_compress(int_data) + return cls(int_data_compressed, scale, zero_point, _layout) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py new file mode 100644 index 0000000000..df79b653e8 --- /dev/null +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -0,0 +1,643 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device +from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + fill_defaults, + find_multiple, +) + +aten = torch.ops.aten + + +def _aqt_is_tensor_core_tile_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + # TODO: use torch.uint4 + return ( + aqt.tensor_impl.dtype == torch.int32 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native bfloat16 tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.dtype == torch.bfloat16 + and + # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and weight_tensor.dtype == torch.bfloat16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor._layout, TensorCoreTiledLayout) + ) + + +def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + # TODO: check groupsize quantization + # avoid circular dep, TODO: move this to a common util.py + act_mat = input_tensor + # weight is packed from padded (out_features, in_features) weight tensor + # (same dimension requirement as F.linear weight) + packed_weight = weight_tensor.tensor_impl.packed_weight + scale_and_zero = weight_tensor.tensor_impl.scale_and_zero + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape and pad activation + act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) + pad_size = find_multiple(act_mat.shape[-1], 1024) + act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias + return y.to(orig_dtype) + + +@dataclass(frozen=True) +class TensorCoreTiledLayout(Layout): + """ + inner_k_tiles is an internal argument for packing function of tensor core tiled layout + that can affect the performance of the matmul kernel + """ + + inner_k_tiles: int = 8 + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input + + def pre_process_static( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = self.pre_process(input) + orig_qparam_shape = scale.shape + new_qparam_shape, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + for dim in reduction_dims: + new_qparam_shape.pop(dim) + change_in_qparam_shape = [ + new_dim_size - orig_dim_size + for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape) + ] + padding_changes = [] + for dim_change in change_in_qparam_shape: + padding_changes = [0, dim_change] + padding_changes + scale = torch.nn.functional.pad(scale, padding_changes) + zero_point = torch.nn.functional.pad(zero_point, padding_changes) + return input, scale, zero_point + + def post_process(self, input: torch.Tensor) -> torch.Tensor: + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input + + def extra_repr(self): + return f"inner_k_tiles={self.inner_k_tiles}" + + +@register_layout(TensorCoreTiledLayout) +class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm` + + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of + dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] + (unpacked Tensor shape is n * k) + where inner_k_tiles is an internal argument for packing function of tensor core tiled layout + that can affect the performance of the matmul kernel (defaults to 8) + + Note: we also pack scale and zero point together here for tinygemm kernel + + Note: technically tensor core tiled layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass + + fields: + packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout + scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale_and_zero = scale_and_zero + self.transposed = False + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, TensorCoreTiledLayout) + + if TORCH_VERSION_AT_LEAST_2_5: + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + else: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + # tensor core tiled layout supports both cpu and cuda but does not support the conversion + # between these two devices, in the future we should not use the same layout for + # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 + if not is_device(torch.device(self.device).type, device): + raise ValueError( + f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}" + ) + return self.__class__( + self.packed_weight.to(device), + self.scale_and_zero.to(device), + self.transposed, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + # self.packed_weight = fn(self.packed_weight) + # self.scale_and_zero = fn(self.scale_and_zero) + # return self + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = TensorCoreTiledAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + quantize_affine, + ) + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + + cur_shape = self.shape + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + device = self.device + original_dtype = torch.bfloat16 + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + zero_point_domain = ZeroPointDomain.FLOAT + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? + scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return int_data, scale, zero + + def get_layout(self) -> Layout: + return self._layout + + +@dataclass(frozen=True) +class Int4CPULayout(Layout): + """Only for PyTorch version at least 2.6""" + + pass + + +@register_layout(Int4CPULayout) +class Int4CPUAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm_for_cpu` + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + (unpacked Tensor shape is n * k) + Note: we also pack scale and zero point together here for tinygemm kernel + Note: technically Int4 CPU layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale_and_zero = scale_and_zero + self.transposed = False + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Int4CPULayout) + + if TORCH_VERSION_AT_LEAST_2_6: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) + elif TORCH_VERSION_AT_LEAST_2_5: + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + else: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + if not is_device(torch.device(self.device).type, device): + raise ValueError( + f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" + ) + return self.__class__( + self.packed_weight.to(device), + self.scale_and_zero.to(device), + self.transposed, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = Int4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError( + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + quantize_affine, + ) + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + + cur_shape = self.shape + assert len(cur_shape) == 2 + original_shape = (cur_shape[0], cur_shape[1] * 2) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + device = self.device + original_dtype = torch.bfloat16 + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + zero_point_domain = ZeroPointDomain.FLOAT + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? + scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return int_data, scale, zero + + def get_layout(self) -> Layout: + return self._layout diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uintx/uint4_layout.py similarity index 100% rename from torchao/dtypes/uint4.py rename to torchao/dtypes/uintx/uint4_layout.py diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx_layout.py similarity index 98% rename from torchao/dtypes/uintx/uintx.py rename to torchao/dtypes/uintx/uintx_layout.py index 0f27d18eef..29c2ae93fe 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -4,7 +4,8 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout +from torchao.dtypes.affine_quantized_tensor import register_layout +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl from torchao.dtypes.utils import ( Layout, ) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 6579c1245d..774071f856 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -3,6 +3,8 @@ import torch +from torchao.utils import TorchAOBaseTensor + """ Base class for different layout, following the same design of PyTorch layout https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout, used to represent different @@ -72,3 +74,42 @@ def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[in out_dim = weight_shape[0] inpt_dims = input_shape[:-1] return (*inpt_dims, out_dim) + + +############################### +# Base Tensor Impl Subclass # +############################### +class AQTTensorImpl(TorchAOBaseTensor): + """ + Base class for the tensor impl for `AffineQuantizedTensor` + + Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct + the underlying implementation of a AQT based on layout + """ + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get the plain (unpacked) Tensor for the tensor impl + + Returns data, scale and zero_point + Can be overwritten if other types of AQTTensorImpl has different numbers of plain tensors + """ + pass + + def get_layout(self) -> Layout: + pass + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + """Construct a TensorImpl from data, scale, zero_point and the _layout""" + pass + + def __repr__(self): + data, scale, zero_point = self.get_plain() + _layout = self.get_layout() + return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})" diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index b641c07519..a90cc5884a 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -23,6 +23,10 @@ if(NOT TORCHAO_INCLUDE_DIRS) endif() option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) +if(TORCHAO_BUILD_KLEIDIAI) + message(STATUS "Building with Arm KleidiAI library") + add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) +endif() include(CMakePrintHelpers) add_compile_options("-Wall" "-Werror" "-Wno-deprecated") diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py index 9b9b53d5aa..97e6380f92 100644 --- a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py +++ b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py @@ -12,10 +12,10 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.affine_quantized_tensor import ( - AQTTensorImpl, - register_aqt_quantized_linear_dispatch, register_layout, ) +from torchao.dtypes.utils import AQTTensorImpl +from torchao.dtypes.affine_quantized_tensor_ops import register_aqt_quantized_linear_dispatch from torchao.dtypes.utils import Layout from torchao.quantization.quant_primitives import ( MappingType, @@ -65,7 +65,7 @@ def __init__( group_size: int, target: str, ): - assert nbit <= 7 + assert nbit <= 8 self.nbit = nbit self.group_size = group_size self.target = target_from_str(target) @@ -182,7 +182,7 @@ def from_plain( # Fallback assert layout.target == Target.FALLBACK - packed_weight = int_data.to(torch.int8) + packed_weight = int_data.to(torch.int32) return cls(packed_weight, scale, zero_point, layout) def _apply_fn_to_data(self, fn): diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index 6073425183..8751c38c81 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -4,8 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - -if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") +if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")) add_library( torchao_kernels_aarch64 ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp @@ -25,7 +24,7 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") # Temporarily exposing this to the parent scope until we wire # this up properly from the top level - set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE) + set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE) target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) endif() endif() diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h index cdac5829ec..dbda036efd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -47,14 +47,14 @@ size_t activation_data_size(int m, int k, int group_size) { } void prepare_activation_data( - void* activation_data, + void* prepared_activation_data, int m, int k, int group_size, const float* activations) { (void)group_size; // unused kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), activation_data, m, k, activations); + get_ukernel(), prepared_activation_data, m, k, activations); } size_t weight_data_size(int n, int k, int group_size) { @@ -63,7 +63,7 @@ size_t weight_data_size(int n, int k, int group_size) { } void prepare_weight_data( - void* weight_data, + void* prepared_weight_data, int n, int k, int group_size, @@ -73,7 +73,7 @@ void prepare_weight_data( const float* bias) { kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( get_ukernel(), - weight_data, + prepared_weight_data, n, k, group_size, diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h index a739dc4c8b..d3d7bd55d9 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -45,7 +45,7 @@ size_t activation_data_size(int m, int k, int group_size) { } void prepare_activation_data( - void* activation_data, + void* prepared_activation_data, int m, int k, int group_size, @@ -53,7 +53,7 @@ void prepare_activation_data( (void) group_size; // unused kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( get_ukernel(), - activation_data, + prepared_activation_data, m, k, activations); @@ -64,7 +64,7 @@ size_t weight_data_size(int n, int k, int group_size) { } void prepare_weight_data( - void* weight_data, + void* prepared_weight_data, int n, int k, int group_size, @@ -74,7 +74,7 @@ void prepare_weight_data( const float* bias) { kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( get_ukernel(), - weight_data, + prepared_weight_data, n, k, group_size, diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h new file mode 100644 index 0000000000..4ef499d72c --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -0,0 +1,120 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { +namespace neon_i8mm_8x4x32 { + +const Ukernel get_ukernel() { + return Ukernel{ + .get_m_step = + kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_n_step = + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_mr = + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_nr = + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_kr = + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_sr = + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_dst_offset = + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .get_dst_size = + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + .run_matmul = + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm}; +} + +size_t activation_data_size(int m, int k, int group_size) { + (void)group_size; // unused + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( + get_ukernel(), m, k); +} + +void prepare_activation_data( + void* prepared_activation_data, + int m, + int k, + int group_size, + const float* activations) { + (void)group_size; // unused + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( + get_ukernel(), prepared_activation_data, m, k, activations); +} + +size_t weight_data_size(int n, int k, int group_size) { + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( + get_ukernel(), n, k, group_size); +} + +void prepare_weight_data( + void* prepared_weight_data, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros, + const float* bias) { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( + get_ukernel(), + prepared_weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros, + bias); +} + +void kernel( + float32_t* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + float clamp_min, + float clamp_max) { + if (clamp_min == 0 && clamp_max == 0) { + clamp_min = std::numeric_limits::lowest(); + clamp_max = std::numeric_limits::max(); + } + + auto ukernel = get_ukernel(); + ukernel.run_matmul( + m, + n, + k, + group_size, + activation_data, + weight_data, + output, + /*dst_stride_row=*/n * sizeof(float), + /*dst_stride_col=*/sizeof(float), + clamp_min, + clamp_max); +} + +size_t get_preferred_alignement() { + return 16; +} +} // namespace neon_i8mm_8x4x32 +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h new file mode 100644 index 0000000000..d898cf3e5b --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -0,0 +1,122 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include + +#include + +namespace torchao::kernels::cpu::aarch64::kleidi { +namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { +namespace neon_i8mm_4x8x32 { + +const Ukernel get_ukernel() { + return Ukernel{ + .get_m_step = + kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_n_step = + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_mr = + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_nr = + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_kr = + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_sr = + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_rhs_packed_offset = + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_dst_offset = + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .get_dst_size = + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + .run_matmul = + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}; +} + +size_t activation_data_size(int m, int k, int group_size) { + (void)group_size; // unused + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( + get_ukernel(), m, k); +} + +void prepare_activation_data( + void* prepared_activation_data, + int m, + int k, + int group_size, + const float* activations) { + (void)group_size; // unused + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( + get_ukernel(), prepared_activation_data, m, k, activations); +} + +size_t weight_data_size(int n, int k, int group_size) { + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( + get_ukernel(), n, k, group_size); +} + +void prepare_weight_data( + void* prepared_weight_data, + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros, + const float* bias) { + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( + get_ukernel(), + prepared_weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros, + bias); +} + +void kernel( + float32_t* output, + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + float clamp_min, + float clamp_max) { + if (clamp_min == 0 && clamp_max == 0) { + clamp_min = std::numeric_limits::lowest(); + clamp_max = std::numeric_limits::max(); + } + + auto ukernel = get_ukernel(); + ukernel.run_matmul( + m, + n, + k, + group_size, + activation_data, + weight_data, + output, + /*dst_stride_row=*/n * sizeof(float), + /*dst_stride_col=*/sizeof(float), + clamp_min, + clamp_max); +} + +size_t get_preferred_alignement() { + return 16; +} + +} // namespace neon_i8mm_4x8x32 +} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p +} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 3712a36250..e4cafdc97a 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -15,6 +15,11 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(googletest) +if (ANDROID_ABI) + # We are cross compiling, delay test discovery till runtime + set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) +endif() + add_compile_options("-Wall" "-Werror") include(CMakePrintHelpers) @@ -35,13 +40,29 @@ endif() add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) -# The TORCHAO_ENABLE_KLEIDI cmake variable should be set by `torchao_kernels_aarch64" -if(TORCHAO_ENABLE_KLEIDI) +# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64" +if(TORCHAO_BUILD_KLEIDI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI) endif() +if(TORCHAO_BUILD_ARM_I8MM) + add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) +endif() + enable_testing() +if (ANDROID_ABI) + # Given where we are today this is sufficent. But needs to be revisited. + # This is also needed for native builds, but keeping it only for cross builds + # for now given the hacky nature. + file(GLOB DOTPROD_SRC_FILES test*.cpp) + message(SRC_FILES: ${DOTPROD_SRC_FILES}) + set_property(SOURCE + ${DOTPROD_SRC_FILES} + APPEND_STRING PROPERTY + COMPILE_FLAGS " -march=armv8.2-a+dotprod ") +endif() + add_executable(test_quantization test_quantization.cpp) target_link_libraries( test_quantization diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 4394b02ece..5c12d7184e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -8,22 +8,59 @@ set -eu SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests +export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_tests + +target=${1:-"native"} IS_ARM64=0 +BUILD_ARM_I8MM=0 +EXTRA_ARGS="" +if [[ "${target}" == "android" ]]; then + if [[ -z ${ANDROID_NDK} ]]; then + echo "Need to set ANDROID_NDK env variable to build for Android"; + exit 1; + fi + android_abi=arm64-v8a + android_platform=28 # must be >=28 for aligned_alloc + IS_ARM64=1 + BUILD_ARM_I8MM=1 # Hardcoded for now + CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} + toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" + if [[ -z ${toolchain_file} ]]; then + echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" + exit 1; + fi + EXTRA_ARGS="\ + -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ + -DANDROID_ABI=${android_abi} \ + -DANDROID_PLATFORM=${android_platform} + " + echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" +fi + hash arch; retval=$? if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then IS_ARM64=1 fi -cmake -DCMAKE_BUILD_TYPE=Debug \ +cmake \ + ${EXTRA_ARGS} \ + -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ + -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} +echo "Successfully built tests." + +if [[ "${target}" != "native" ]]; then + echo "Skip running tests when cross compiling."; + exit 0; +fi + # Run ${CMAKE_OUT}/test_quantization ${CMAKE_OUT}/test_reduction diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index b28e3bfdc4..f68106c7e8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -17,7 +17,11 @@ #ifdef TORCHAO_ENABLE_KLEIDI #include #include -#endif +#ifdef TORCHAO_ENABLE_ARM_I8MM +#include +#include +#endif // TORCHAO_ENABLE_ARM_I8MM +#endif // TORCHAO_ENABLE_KLEIDI float kTol = 0.0001; @@ -587,5 +591,235 @@ TEST( true /*has_clamp*/>( /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); } + +#ifdef TORCHAO_ENABLE_ARM_I8MM +template +void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( + int m, + int k, + int n, + int group_size) { + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + /*weight_nbit=*/4, + /*has_weight_zeros=*/false, + has_bias, + has_clamp, + /*round_weight_scales_to_bf16=*/true); + + using namespace torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32; + + std::vector activation_data(activation_data_size(m, k, group_size)); + + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data(weight_data_size(n, k, group_size)); + + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), + /*bias=*/test_case.bias.data()); + + std::vector output(m * n); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + k_eq_gs_32) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + large_k_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + even_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + m_clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +template +void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( + int m, + int k, + int n, + int group_size) { + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + /*weight_nbit=*/4, + /*has_weight_zeros=*/false, + has_bias, + has_clamp, + /*round_weight_scales_to_bf16=*/true); + + using namespace torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; + + std::vector activation_data(activation_data_size(m, k, group_size)); + + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data(weight_data_size(n, k, group_size)); + + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), + /*bias=*/test_case.bias.data()); + + std::vector output(m * n); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + k_eq_gs_32) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + large_k_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + even_n_gs32) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< + false /*has_bias*/, + false /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); +} + +TEST( + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + m_clamp_k_eq_gs128) { + test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< + false /*has_bias*/, + true /*has_clamp*/>( + /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); +} +#endif // TORCHAO_ENABLE_ARM_I8MM #endif // TORCHAO_ENABLE_KLEIDI #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py index 56dcc9b730..eea7e42666 100644 --- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py +++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py @@ -1,11 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + from typing import Optional import os +import sys import yaml -torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT") -assert torchao_root is not None, "TORCHAO_ROOT is not set" +if len(sys.argv) != 2: + print("Usage: gen_metal_shader_lib.py ") + sys.exit(1) + +# Output file where the generated code will be written +OUTPUT_FILE = sys.argv[1] -MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps") +MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) # Path to yaml file containing the list of .metal files to include METAL_YAML = os.path.join(MPS_DIR, "metal.yaml") @@ -21,9 +32,6 @@ # Path to the folder containing the .metal files METAL_DIR = os.path.join(MPS_DIR, "metal") -# Output file where the generated code will be written -OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h") - prefix = """/** * This file is generated by gen_metal_shader_lib.py */ @@ -48,6 +56,7 @@ """ +os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True) with open(OUTPUT_FILE, "w") as outf: outf.write(prefix) for file in metal_files: diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 398af237ae..2d86223034 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -101,7 +101,8 @@ void init() { int32_t ceil_K_group_size = (K + qGroupSize - 1) / qGroupSize; for (int idx = 0; idx < N * ceil_K_group_size; ++idx) { s_ptr[idx] = (idx + 1.0) / N; - z_ptr[idx] = int_distrib(generator); + auto zp = int_distrib(generator); + z_ptr[idx] = -zp * s_ptr[idx]; } for (int idx = 0; idx < M * N; ++idx) { c_ptr[idx] = -1.0; diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp index dfb61eb928..1b019609a6 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp @@ -36,6 +36,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { DEFINE_OP(5); DEFINE_OP(6); DEFINE_OP(7); + DEFINE_OP(8); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -46,6 +47,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) { DEFINE_CPU_IMPL(5); DEFINE_CPU_IMPL(6); DEFINE_CPU_IMPL(7); + DEFINE_CPU_IMPL(8); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { @@ -56,4 +58,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) { DEFINE_META_IMPL(5); DEFINE_META_IMPL(6); DEFINE_META_IMPL(7); + DEFINE_META_IMPL(8); } diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp index 1b79a5e035..f99a575cfe 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp @@ -37,3 +37,4 @@ DEFINE_OP(4); DEFINE_OP(5); DEFINE_OP(6); DEFINE_OP(7); +DEFINE_OP(8); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp index f69e51e4c9..24d4008969 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp @@ -68,6 +68,7 @@ TORCH_LIBRARY(torchao, m) { DEFINE_OP(5); DEFINE_OP(6); DEFINE_OP(7); + DEFINE_OP(8); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -78,6 +79,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) { DEFINE_CPU_IMPL(5); DEFINE_CPU_IMPL(6); DEFINE_CPU_IMPL(7); + DEFINE_CPU_IMPL(8); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { @@ -88,4 +90,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) { DEFINE_META_IMPL(5); DEFINE_META_IMPL(6); DEFINE_META_IMPL(7); + DEFINE_META_IMPL(8); } diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp new file mode 100644 index 0000000000..5257611d97 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8s.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow on +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +namespace { +Tensor _op_out( + RuntimeContext& ctx, + const Tensor& activations, + const Tensor& packed_weights, + const Tensor& group_size_tensor, + const Tensor& n_tensor, + const Tensor& k_tensor, + Tensor& out) { + (void)ctx; + linear_out_cpu( + activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out); + return out; +} +} // namespace + +EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit0zp_weight.out", _op_out); diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp new file mode 100644 index 0000000000..e26da69d67 --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w8sz.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// Unlike ATen, ExecuTorch op registration appears to only allow on +// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new +// file is needed for each variant + +#include + +namespace { +Tensor _op_out( + RuntimeContext& ctx, + const Tensor& activations, + const Tensor& packed_weights, + const Tensor& group_size_tensor, + const Tensor& n_tensor, + const Tensor& k_tensor, + Tensor& out) { + (void)ctx; + linear_out_cpu( + activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out); + return out; +} +} // namespace + +EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit_weight.out", _op_out); diff --git a/torchao/experimental/ops/mps/.gitignore b/torchao/experimental/ops/mps/.gitignore new file mode 100644 index 0000000000..d48f17d1c5 --- /dev/null +++ b/torchao/experimental/ops/mps/.gitignore @@ -0,0 +1 @@ +cmake-out/ diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt new file mode 100644 index 0000000000..044433ef95 --- /dev/null +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +project(torchao_ops_mps_linear_fp_act_xbit_weight) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED YES) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +if (CMAKE_SYSTEM_NAME STREQUAL "Darwin") + if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture") + endif() +else() + message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS") +endif() + +find_package(Torch REQUIRED) + +# Generate metal_shader_lib.h by running gen_metal_shader_lib.py +set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h) +add_custom_command( + OUTPUT ${GENERATED_METAL_SHADER_LIB} + COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB} + COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py" +) +add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB}) + +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) +endif() +message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") + +include_directories(${TORCHAO_INCLUDE_DIRS}) +include_directories(${CMAKE_INSTALL_PREFIX}/include) +add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm) +add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib) + +target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") +target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}") +target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1) + +# Enable Metal support +find_library(METAL_LIB Metal) +find_library(FOUNDATION_LIB Foundation) +target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) + +install( + TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten + EXPORT _targets + DESTINATION lib +) diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/aten/register.mm similarity index 97% rename from torchao/experimental/ops/mps/register.mm rename to torchao/experimental/ops/mps/aten/register.mm index a53f55d3d8..92a3ba89f0 100644 --- a/torchao/experimental/ops/mps/register.mm +++ b/torchao/experimental/ops/mps/aten/register.mm @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. // clang-format off -#include +#include #include #include // clang-format on @@ -58,6 +58,7 @@ void check_linear_mps_args( ": expect S to be 2d tensor with shape [:, ", N, "]"); + TORCH_CHECK(S.is_contiguous(), __func__, " : expect S to be contiguous."); TORCH_CHECK( Z.dim() == 2 && Z.size(1) == N, @@ -65,6 +66,7 @@ void check_linear_mps_args( ": expect Z to be 2d tensor with shape [:, ", N, "]"); + TORCH_CHECK(Z.is_contiguous(), __func__, " : expect Z to be contiguous."); } template @@ -145,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { return B; } -// Registers _C as a Python extension module. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} - TORCH_LIBRARY(torchao, m) { m.def("_pack_weight_1bit(Tensor W) -> Tensor"); m.def("_pack_weight_2bit(Tensor W) -> Tensor"); diff --git a/torchao/experimental/ops/mps/build.sh b/torchao/experimental/ops/mps/build.sh new file mode 100644 index 0000000000..1ea032f8c6 --- /dev/null +++ b/torchao/experimental/ops/mps/build.sh @@ -0,0 +1,19 @@ +#!/bin/bash -eu +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cd "$(dirname "$BASH_SOURCE")" + +export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" +export CMAKE_OUT=${PWD}/cmake-out +echo "CMAKE_OUT: ${CMAKE_OUT}" + +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -S . \ + -B ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} -j 16 --target install --config Release diff --git a/torchao/experimental/ops/mps/setup.py b/torchao/experimental/ops/mps/setup.py deleted file mode 100644 index 1205d43d45..0000000000 --- a/torchao/experimental/ops/mps/setup.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import os -from setuptools import setup -from torch.utils.cpp_extension import CppExtension, BuildExtension - -setup( - name="torchao_mps_ops", - version="1.0", - ext_modules=[ - CppExtension( - name="torchao_mps_ops", - sources=["register.mm"], - include_dirs=[os.getenv("TORCHAO_ROOT")], - extra_compile_args=["-DUSE_ATEN=1"], - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index f2d9d9c175..f4c460a368 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -4,25 +4,38 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import os +import sys import torch -import torchao_mps_ops import unittest +from parameterized import parameterized -def parameterized(test_cases): - def decorator(func): - def wrapper(self): - for case in test_cases: - with self.subTest(case=case): - func(self, *case) +libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libpath = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) +) - return wrapper - - return decorator +try: + for nbit in range(1, 8): + getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") +except AttributeError: + try: + torch.ops.load_library(libpath) + except: + raise RuntimeError(f"Failed to load library {libpath}") + else: + try: + for nbit in range(1, 8): + getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") + except AttributeError as e: + raise e class TestLowBitQuantWeightsLinear(unittest.TestCase): - cases = [ + CASES = [ (nbit, *param) for nbit in range(1, 8) for param in [ @@ -46,18 +59,18 @@ class TestLowBitQuantWeightsLinear(unittest.TestCase): ] def _init_tensors(self, group_size, M, K, N, nbit, device="mps"): - max_abs = 1 << (nbit - 1) ceil_K_group_size = (K + group_size - 1) // group_size - A = 2 * torch.rand(M, K, dtype=torch.float32, device=device) - 1 - W = torch.randint(0, 2 * max_abs, (N, K), dtype=torch.uint8, device=device) + A = torch.rand(M, K, dtype=torch.float32, device=device) + W = torch.randint(0, 1 << nbit, (N, K), dtype=torch.uint8, device=device) S = torch.rand(ceil_K_group_size, N, dtype=torch.float32, device=device) + 0.01 Z = torch.randint( 0, - 2 * max_abs, + 1 << nbit, (ceil_K_group_size, N), dtype=torch.float32, device=device, ) + Z = -Z * S return A, W, S, Z def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit): @@ -73,7 +86,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit): W = scales * W + zeros return torch.mm(A, W.t()) - @parameterized(cases) + @parameterized.expand(CASES) def test_linear(self, nbit, M=1, K=32, N=32, group_size=32): print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}") A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit) diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py new file mode 100644 index 0000000000..00c08738c2 --- /dev/null +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional +import copy +import itertools +import os +import sys + +import torch +import unittest + +from parameterized import parameterized +from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer +from torchao.experimental.quant_api import _quantize + +libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libpath = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) +) + +try: + for nbit in range(1, 8): + getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") +except AttributeError: + try: + torch.ops.load_library(libpath) + except: + raise RuntimeError(f"Failed to load library {libpath}") + else: + try: + for nbit in range(1, 8): + getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") + except AttributeError as e: + raise e + + +class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase): + BITWIDTHS = range(1, 8) + GROUPSIZES = [32, 64, 128, 256] + + # Currently, the quantization code in quant_api.py only supports K values + # multiple of group_size. + # TODO(mcandales): Generalize the code in quant_api.py and add tests to + # cover values of K not multiple of group_size. + def _model_setup(self): + group_size = 32 + k0 = 96 + k1 = 224 + k2 = 160 + n = 47 + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, k2, bias=False), + torch.nn.Linear(k2, n, bias=False), + ] + model = torch.nn.Sequential(*layers) + return model, group_size, k0, n + + def _quantize_model(self, model, precision, nbit, group_size): + quantizer = UIntxWeightOnlyLinearQuantizer( + device="mps", + precision=precision, + bitwidth=nbit, + groupsize=group_size, + ) + quantized_model = copy.deepcopy(model) + quantized_model = quantizer.quantize(quantized_model) + return quantized_model + + @parameterized.expand(BITWIDTHS) + def test_export(self, nbit): + model, group_size, k0, n = self._model_setup() + m = 3 + activations = torch.randn(m, k0, dtype=torch.float32, device="mps") + + quantized_model = self._quantize_model(model, torch.float32, nbit, group_size) + exported = torch.export.export(quantized_model, (activations,)) + + for node in exported.graph.nodes: + if node.op == "call_function": + self.assertTrue( + str(node.target) + == f"torchao._linear_fp_act_{nbit}bit_weight.default" + ) + + @parameterized.expand(BITWIDTHS) + def test_2d_output_device_and_shape(self, nbit): + model, group_size, k0, n = self._model_setup() + m = 3 + activations = torch.randn(m, k0, dtype=torch.float32, device="mps") + + quantized_model = self._quantize_model(model, torch.float32, nbit, group_size) + result = quantized_model(activations) + self.assertTrue(result.is_mps) + self.assertTrue(result.shape == (m, n)) + + @parameterized.expand(BITWIDTHS) + def test_3d_output_device_and_shape(self, nbit): + model, group_size, k0, n = self._model_setup() + leading_shape = (3, 5) + activations = torch.randn(*leading_shape, k0, dtype=torch.float32, device="mps") + + quantized_model = self._quantize_model(model, torch.float32, nbit, group_size) + result = quantized_model(activations) + self.assertTrue(result.is_mps) + self.assertTrue(result.shape == (*leading_shape, n)) + + @parameterized.expand(itertools.product(BITWIDTHS, GROUPSIZES)) + def test_valid_groupsizes(self, nbit, group_size): + k0 = 3 * group_size + k1 = 7 * group_size + n = 47 + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, n, bias=False), + ] + model = torch.nn.Sequential(*layers) + m = 5 + activations = torch.randn(m, k0, dtype=torch.float32, device="mps") + + quantized_model = self._quantize_model(model, torch.float32, nbit, group_size) + result = quantized_model(activations) + self.assertTrue(result.is_mps) + self.assertTrue(result.shape == (m, n)) + + @parameterized.expand(BITWIDTHS) + def test_invalid_groupsizes(self, nbit): + group_size = 16 + k0 = 3 * group_size + k1 = 7 * group_size + n = 47 + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, n, bias=False), + ] + model = torch.nn.Sequential(*layers) + + with self.assertRaises(ValueError): + self._quantize_model(model, torch.float32, nbit, group_size) + + # TODO(mcandales): Consolidate with the reference impl in test_lowbit.py + def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z): + N = W.shape[0] + K = W.shape[1] + W = W.to(torch.float32) + scales = S.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + zeros = Z.t().unsqueeze(2).repeat(1, 1, group_size).view(N, -1)[:, :K] + W = scales * W + zeros + return torch.mm(A, W.t()) + + @parameterized.expand(BITWIDTHS) + def test_accuracy(self, nbit): + group_size = 32 + m = 3 + n = 7 + k = 64 + with torch.no_grad(): + activations = torch.rand(m, k, dtype=torch.float32, device="mps") + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + quantized_model = self._quantize_model( + model, torch.float32, nbit, group_size + ) + result = quantized_model(activations) + + # Compute expected result + weight_cpu = model[0].weight.data + weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize( + weight_cpu, group_size, nbit, True, torch.uint8 + ) + weight_scales_cpu = weight_scales_cpu.t() + weight_zeros_cpu = -weight_zeros_cpu.t() * weight_scales_cpu + expected = self._reference_linear_lowbit_quant_weights( + activations.cpu(), + weight_qvals_cpu, + group_size, + weight_scales_cpu, + weight_zeros_cpu, + ) + + # Compare results + torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 1c04305d31..be72a59aab 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -25,10 +25,14 @@ logger.addHandler(handler) -def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool): +def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: bool, signed=True): assert nbit >= 1 and nbit <= 8 - qmin = -(1 << (nbit - 1)) - qmax = (1 << (nbit - 1)) - 1 + if signed: + qmin = -(1 << (nbit - 1)) + qmax = (1 << (nbit - 1)) - 1 + else: + qmin = 0 + qmax = (1 << nbit) - 1 n, k = vals.shape vals = vals.reshape(-1, group_size) @@ -51,7 +55,7 @@ def _quantize(vals: torch.Tensor, group_size: int, nbit: int, has_weight_zeros: zero_points=group_zeros, quant_min=qmin, quant_max=qmax, - dtype=torch.int8, + dtype=torch.int8 if signed else torch.uint8, group_size=group_size, ) @@ -198,7 +202,7 @@ def forward(self, x): def _maybe_get_quantized_linear_native(nbit, has_weight_zeros): try: - if nbit in [1, 2, 3, 4, 5, 6, 7]: + if nbit in [1, 2, 3, 4, 5, 6, 7, 8]: wzp_suffix = "" if has_weight_zeros else "0zp" return _Int8DynActIntxWeightQuantizedLinearNative( pack_weight_op=getattr( @@ -230,7 +234,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}): has_weight_zeros = kwargs["has_weight_zeros"] assert not isinstance(module, nn.Linear) - assert nbit >= 1 and nbit <= 7 + assert nbit >= 1 and nbit <= 8 for name, child in module.named_children(): if not isinstance(child, nn.Linear): @@ -366,9 +370,9 @@ def quantize_and_pack_weights(self, weights, group_size): weight_qvals, weight_scales, weight_zeros = _quantize( weights, self.group_size, self.nbit, has_weight_zeros=True ) - self.weight_qvals = weight_qvals.to(torch.int8) + self.weight_qvals = weight_qvals.to(torch.int32) self.weight_scales = weight_scales - self.weight_zeros = weight_zeros.to(torch.int8) + self.weight_zeros = weight_zeros.to(torch.int32) def forward(self, x): shape = x.shape @@ -394,7 +398,7 @@ def _replace_embedding_with_quantized_embedding(module: nn.Module, kwargs={}): nbit = kwargs["nbit"] assert not isinstance(module, nn.Embedding) - assert nbit >= 1 and nbit <= 7 + assert nbit >= 1 and nbit <= 8 for name, child in module.named_children(): if not isinstance(child, nn.Embedding): @@ -516,3 +520,121 @@ def apply(weight): ) return _get_linear_subclass_inserter(apply) + + +class UIntxWeightOnlyQuantizedLinear(nn.Module): + def __init__( + self, + pack_weight_op, + linear_op, + ): + super().__init__() + self._pack_weights_op = pack_weight_op + self._linear_op = linear_op + + def quantize_and_pack_weights(self, weights, nbit, group_size): + self.nbit = nbit + self.group_size = group_size + + weight_qvals, weight_scales, weight_zeros = _quantize( + weights, self.group_size, self.nbit, has_weight_zeros=True, signed=False + ) + weight_scales = torch.transpose_copy(weight_scales, 1, 0) + weight_zeros = torch.transpose_copy(weight_zeros, 1, 0) + self.weight_scales = weight_scales + self.weight_zeros = -weight_zeros * weight_scales + + self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps") + + def forward(self, x): + assert x.dim() >= 2 + if x.dim() == 2: + return self._linear_op( + x, self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros + ) + + lead_shape = x.shape[0:-1] + k = x.shape[-1] + n = self.weight_scales.shape[1] + return self._linear_op( + x.reshape(-1, k), self.packed_weights, self.group_size, self.weight_scales, self.weight_zeros + ).reshape(*lead_shape, n) + +# TODO(mcandales): Consolidate with _replace_linear_with_quantized_linear +def _replace_linear_with_quantized_linear_mps(module: nn.Module, kwargs={}): + group_size = kwargs["group_size"] + nbit = kwargs["nbit"] + + assert not isinstance(module, nn.Linear) + assert nbit >= 1 and nbit <= 7 + + for name, child in module.named_children(): + if not isinstance(child, nn.Linear): + _replace_linear_with_quantized_linear_mps(child, kwargs) + else: + assert child.bias is None + qlinear = UIntxWeightOnlyQuantizedLinear( + pack_weight_op=getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit"), + linear_op=getattr( + torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight" + ), + ) + setattr(module, name, qlinear) + qlinear.quantize_and_pack_weights( + child.weight, nbit, group_size + ) + + +class UIntxWeightOnlyLinearQuantizer: + def __init__( + self, + device, + precision, + *, + bitwidth: Optional[int] = None, + groupsize: Optional[int] = None, + ): + if device != "mps": + raise NotImplementedError( + "Only device=mps is currently supported in UIntxWeightOnlyLinearQuantizer" + ) + else: + self.device = device + + if precision not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + "Only precisions float32, float16 & bfloat16 are supported in UIntxWeightOnlyLinearQuantizer" + ) + else: + self.precision = precision + + if bitwidth is None: + bitwidth = 4 + logger.warning(f"bitwidth not specified, defaulting to {bitwidth}.") + if bitwidth not in range(1, 8): + raise ValueError( + "Only bitwidts 1 to 7 are supported in UIntxWeightOnlyLinearQuantizer" + ) + else: + self.bitwidth = bitwidth + + if groupsize is None: + groupsize = 128 + logger.warning(f"groupsize not specified, defaulting to {groupsize}.") + if groupsize not in [32, 64, 128, 256]: + raise ValueError( + "Only groupsizes 32, 64, 128 & 256 are supported in UIntxWeightOnlyLinearQuantizer" + ) + else: + self.groupsize = groupsize + + def quantize(self, model: nn.Module) -> nn.Module: + model = model.to(self.device).to(self.precision) + _replace_linear_with_quantized_linear_mps( + model, + kwargs={ + "group_size": self.groupsize, + "nbit": self.bitwidth, + }, + ) + return model diff --git a/torchao/experimental/temp_build.py b/torchao/experimental/temp_build.py new file mode 100644 index 0000000000..fb9d413037 --- /dev/null +++ b/torchao/experimental/temp_build.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import subprocess +import tempfile +import torch + +def cmake_build_torchao_ops(cmake_lists_path, temp_build_dir): + from distutils.sysconfig import get_python_lib + print("Building torchao ops for ATen target") + cmake_prefix_path = get_python_lib() + subprocess.run( + [ + "cmake", + "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, + "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, + "-S " + cmake_lists_path, + "-B " + temp_build_dir.name, + ] + ) + subprocess.run( + [ + "cmake", + "--build", + temp_build_dir.name, + "-j 16", + "--target install", + "--config Release", + ] + ) + +def temp_build_and_load_torchao_ops(cmake_lists_path): + temp_build_dir = tempfile.TemporaryDirectory() + cmake_build_torchao_ops(cmake_lists_path, temp_build_dir) + libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") + libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) + assert len(libs) == 1 + torch.ops.load_library(libs[0]) + print(f"TorchAO ops are loaded from {libs[0]}") diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index a6ba04b439..0eccf33cdb 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -65,7 +65,7 @@ def test_accuracy(self): model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)]) indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32) - for nbit in [1, 2, 3, 4, 5, 6, 7]: + for nbit in [1, 2, 3, 4, 5, 6, 7, 8]: print(f"Testing nbit={nbit}") quantized_model = copy.deepcopy(model) quantizer = IntxWeightEmbeddingQuantizer( diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py index 5d2828d9bc..aeb19555d7 100644 --- a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py +++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py @@ -67,7 +67,7 @@ def test_accuracy(self): activations = torch.randn(2, 3, m, k, dtype=torch.float32) model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - for nbit in [1, 2, 3, 4, 5, 6, 7]: + for nbit in [1, 2, 3, 4, 5, 6, 7, 8]: for has_weight_zeros in [True, False]: print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") quantized_model = copy.deepcopy(model) diff --git a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py b/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py index 44e63386ce..d9035cbe3f 100644 --- a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py +++ b/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py @@ -70,7 +70,7 @@ def test_accuracy(self): activations = torch.randn(m, k, dtype=torch.float32) model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - for nbit in [1, 2, 3, 4, 5, 6, 7]: + for nbit in [1, 2, 3, 4, 5, 6, 7, 8]: for has_weight_zeros in [True, False]: print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") quantized_model = copy.deepcopy(model) diff --git a/torchao/float8/README.md b/torchao/float8/README.md index b6bce5dbed..1a87770899 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -95,8 +95,6 @@ config = Float8LinearConfig( cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - # enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed - # enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed ) # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior @@ -111,8 +109,11 @@ for _ in range(10): y = m(x) y.sum().backward() - # specific to float8 with delayed scaling: separate step to sync scales/amaxes - # in the future, this may move to a context manager + # Specific to delayed scaling: separate step to sync scales/amaxes. + # On the first call, this function also sets the `is_amax_initialized` flag to + # mark the amax and scale buffers as initialized. + # Make sure you run this after every model forward+backward pass. + # In the future, this may move to a context manager. sync_float8_amax_and_scale_history(m) optimizer.step() diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 1011e93524..de57655e88 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -189,15 +189,12 @@ class Float8LinearConfig: # Per-linear configuration # - # If True, on the first iteration of Float8Linear the amaxes will be - # initialized with the incoming data. As of 2023-12-30, this doesn't work - # with autocast + torch.compile + FSDP. Enabling this option is nice for - # testing, but this is not necessary for real training jobs. + # This configuration option is deprecated and no longer has an effect. It may + # be removed in a future release. enable_amax_init: bool = True - # If True, pre-forward and post-forward functions are run. As of 2023-12-30, - # this doesn't work with autocast + torch.compile + FSDP. Enabling this - # option is useful for safety, but not strictly necessary. + # This configuration option is deprecated and no longer has an effect. It may + # be removed in a future release. enable_pre_and_post_forward: bool = True # If True, then uses a tensor subclass for the float8 linear module's weight that diff --git a/torchao/float8/distributed_utils.py b/torchao/float8/distributed_utils.py index 4c0b36c35d..cd1560fabd 100644 --- a/torchao/float8/distributed_utils.py +++ b/torchao/float8/distributed_utils.py @@ -3,110 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Any import torch -from fairscale.nn.model_parallel.initialize import get_model_parallel_group +import torch.distributed._functional_collectives as funcol +from torch.distributed._tensor import DTensor -# from float8_tensor import Float8Tensor from torchao.float8.float8_tensor import Float8Tensor -# additional differentiable distributed primitives for SP which are not in -# the Fairscale codebase - -def _gather_along_first_dim(input_: torch.Tensor): - # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67, - # but gather along first dim instead of last dim - group = get_model_parallel_group() - - # Bypass the function if we are using only 1 GPU. - if torch.distributed.get_world_size(group=group) == 1: - return input_ - - # Size and dimension. - first_dim = 0 - rank = torch.distributed.get_rank(group=group) - world_size = torch.distributed.get_world_size(group=group) - - # If the input is a float8 tensor, we need to do the transformation on the - # inner tensor and then return a new wrapper. - def _transform(t): - # tensors must be contiguous for all_gather to work - input_contig = t.contiguous() - - tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)] - tensor_list[rank] = input_contig - torch.distributed.all_gather(tensor_list, input_contig, group=group) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=first_dim).contiguous() - return output - - if isinstance(input_, Float8Tensor): - new_data = input_._data - new_data = new_data.view(torch.int8) - new_data = _transform(new_data) - new_data = new_data.view(input_._data.dtype) - output = Float8Tensor(new_data, input_._scale, input_._orig_dtype) - else: - output = _transform(input_) - - return output - - -def _reduce_scatter(ctx: Any, input_: torch.Tensor): - group = get_model_parallel_group() - world_size = torch.distributed.get_world_size(group) - - assert input_.shape[0] % world_size == 0 - output_shape = (input_.shape[0] // world_size, *input_.shape[1:]) - output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype) - - torch.distributed.reduce_scatter_tensor(output, input_, group=group) - return output - - -def _split_along_first_dim(input_: torch.Tensor): - # this is needed for testing - - # like fairscale.nn.model_parallel.mappings._split, but - # along the first dim instead of last dim - - group = get_model_parallel_group() - local_rank = torch.distributed.get_rank(group) - world_size = torch.distributed.get_world_size(group) - - assert input_.shape[0] % world_size == 0 - input_list = torch.split(input_, input_.shape[0] // world_size) - return input_list[local_rank] - - -class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _gather_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _reduce_scatter(ctx, grad_output) - - -class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _reduce_scatter(ctx, input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -class _AllGatherFwSplitBw(torch.autograd.Function): - @staticmethod - def forward(ctx, input_): - return _gather_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split_along_first_dim(grad_output) +def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: + """ + Check if the tensor is already casted to fp8, works if the local + tensor is wrapped in DTensor. + """ + if isinstance(tensor, Float8Tensor): + return True + elif isinstance(tensor, DTensor): + # TODO: shall we stick to public API and directly use tensor.to_local() here? + return tensor_already_casted_to_fp8(tensor._local_tensor) + elif isinstance(tensor, funcol.AsyncCollectiveTensor): + return tensor_already_casted_to_fp8(tensor.elem) + + return False diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 76b0fb9e7d..c34c5be670 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -13,6 +13,7 @@ import torch.utils.checkpoint as checkpoint from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType +from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( NoopFwToFloat8BwDelayed, NoopFwToFloat8BwDynamic, @@ -332,14 +333,6 @@ def __init__(self, *args, **kwargs): # TODO(future PR): add serialization for this flag self.is_amax_initialized = not self.config.enable_amax_init - # Syncing of amaxes and scales happens outside of this function. This - # flag is here to enforce that the user does not forget to do this. - self.amax_and_scale_synced = not self.config.enable_amax_init - - # This is needed to properly handle autocast in the amax/scale - # update function for torch.float16 - self.last_seen_input_dtype = None - # pre_forward and post_forward are currently broken with FSDP # and torch.compile, this option can disable them # Note that when using `self.config.enable_pre_and_post_forward = False`, @@ -472,7 +465,7 @@ def cast_input_to_float8( return input_fp8 def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: - if isinstance(weight, Float8Tensor): + if tensor_already_casted_to_fp8(weight): return None if self.scaling_type_weight is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name @@ -500,7 +493,7 @@ def cast_weight_to_float8_t( is_amax_initialized: bool, weight_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if isinstance(weight, Float8Tensor): + if tensor_already_casted_to_fp8(weight): return weight.t() weight_fp8 = hp_tensor_and_scale_to_float8( weight, @@ -547,25 +540,16 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: return output def float8_pre_forward(self, input): + # TODO(future PR): deprecate these functions and the corresponding + # config setting if not self.enable_pre_and_post_forward: return - if ( - self.is_amax_initialized - and (not self.amax_and_scale_synced) - and torch.is_grad_enabled() - ): - raise AssertionError( - "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward" - ) - self.last_seen_input_dtype = input.dtype def float8_post_forward(self): + # TODO(future PR): deprecate these functions and the corresponding + # config setting if not self.enable_pre_and_post_forward: return - # Ensure that calling forward again will fail until the user syncs - # amaxes and scales - self.is_amax_initialized = True - self.amax_and_scale_synced = False def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor: has_any_axiswise_scaling = any( diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 96a9cb90b4..37453d8cfe 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -191,6 +191,9 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) and we loop over all fp8_layers to sync and update amax scale histories. Users can use get_float8_layers to get all fp8 layers. """ + # TODO(future): consider adding a flag to control setting the `is_amax_initialized` + # flag only on the first iteration. + if fp8_layers is None: fp8_layers = get_float8_layers(model) @@ -225,7 +228,6 @@ def inner_func(): input_dtypes = set() weight_dtypes = set() grad_output_dtypes = set() - x_dtypes = set() scale_fn_recipes = set() for idx, child in enumerate(fp8_layers): @@ -240,20 +242,12 @@ def inner_func(): input_dtypes.add(child.config.cast_config_input.dtype) weight_dtypes.add(child.config.cast_config_weight.dtype) grad_output_dtypes.add(child.config.cast_config_grad_output.dtype) - x_dtypes.add(child.last_seen_input_dtype) scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) (input_dtype,) = input_dtypes (weight_dtype,) = weight_dtypes (grad_output_dtype,) = grad_output_dtypes - # TODO This way to get the activation dtype is not ideal - if len(x_dtypes) != 1: - raise ValueError( - f"All layers must have the same last seen input_dtype, got {x_dtypes}" - ) - x_dtype = next(iter(x_dtypes)) - if len(scale_fn_recipes) != 1: raise ValueError( f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" @@ -311,16 +305,13 @@ def inner_func(): # Calculate the new scales from the updated history stacks new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe + fp8_input_amax_history_stack, input_dtype, scale_fn_recipe ) new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, weight_dtype, x_dtype, scale_fn_recipe + fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe ) new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, - grad_output_dtype, - x_dtype, - scale_fn_recipe, + fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe, ) # Iterate through the layers and update the scales @@ -329,10 +320,10 @@ def inner_func(): child.fp8_scale_weight.copy_(new_weight_scales[idx]) child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) - # This allows for the compile to succede on the inner func and fail on the graph breaks + # This allows for the compile to succeed on the inner func and fail on the graph breaks # at the beginning and and of syncing inner_func() for child in fp8_layers: - # Set a flag to signal amaxes/scales are ready - child.amax_and_scale_synced = True + # Set a flag to signal that initialization is done + child.is_amax_initialized = True diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 921d50e093..2af4160de4 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -85,7 +85,10 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): ) def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) - new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + if args[0]._scale.ndim > 1: + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + else: + new_scale = args[0]._scale if aten_op == aten.transpose.int: _assert_tensorwise_scale(aten_op, args[0]._scale) diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 2da7c6028b..dec03f1ebb 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -13,12 +13,12 @@ import torch from torchao.float8.config import ScalingGranularity +from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, LinearMMConfig, hp_tensor_and_scale_to_float8, - tensor_already_casted_to_fp8, ) from torchao.float8.float8_utils import ( amax_history_to_scale, @@ -176,9 +176,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) cur_amax.fill_(new_amax) amax_history[0] = new_amax - new_scale = amax_history_to_scale( - amax_history, float8_dtype, x.dtype, scale_fn_name - ) + new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name) scale.copy_(new_scale) diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 1aed6cebdc..fe2498e2b0 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -7,7 +7,6 @@ from typing import Dict, NamedTuple, Optional import torch -import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( @@ -120,21 +119,6 @@ def choose_scaled_mm_config( raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}") -def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: - """ - Check if the tensor is already casted to fp8 - """ - if isinstance(tensor, Float8Tensor): - return True - elif isinstance(tensor, DTensor): - # TODO: shall we stick to public API and directly use tensor.to_local() here? - return tensor_already_casted_to_fp8(tensor._local_tensor) - elif isinstance(tensor, funcol.AsyncCollectiveTensor): - return tensor_already_casted_to_fp8(tensor.elem) - - return False - - @torch._dynamo.allow_in_graph class _ToFloat8ConstrFunc(torch.autograd.Function): """ diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 06735c30d4..90927659f8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist +from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce from torchao.float8.config import ScalingGranularity @@ -28,14 +29,11 @@ @torch.no_grad() -def amax_to_scale( - amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype -): +def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. float8_dtype: The float8 dtype. - orig_dtype: The original dtype of the tensor. """ # torch.compile and eager show different numerics for 1.0 / float32, # upcast to float64 to ensure same numeric between compile and eager @@ -45,11 +43,6 @@ def amax_to_scale( else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") - # Ensure that the scale is representable in float16, - # this helps when amax is small. We are assuming that we don't need - # to care about this for float32/bfloat16. - if orig_dtype is torch.float16: - res = torch.clamp(res, max=torch.finfo(torch.float16).max) return res.to(torch.float32) @@ -57,19 +50,17 @@ def amax_to_scale( def amax_history_to_scale( amax_history: torch.Tensor, float8_dtype: torch.Tensor, - orig_dtype: torch.dtype, history_to_scale_fn_type: Literal["max"], ): """Takes in a history of amax values and returns a scale tensor. Args: amax_history: A tensor containing the history of amax values. float8_dtype: The float8 dtype. - orig_dtype: The original dtype of the tensor. history_to_scale_fn_type: The type of function to use to convert the history to a scale. """ if history_to_scale_fn_type == "max": amax = torch.max(amax_history) - return amax_to_scale(amax, float8_dtype, orig_dtype) + return amax_to_scale(amax, float8_dtype) raise NotImplementedError() @@ -77,19 +68,17 @@ def amax_history_to_scale( def amax_history_to_scale_stack( amax_history: torch.Tensor, float8_dtype: torch.dtype, - orig_dtype: torch.dtype, history_to_scale_fn_type: Literal["max"], ) -> torch.Tensor: """Takes in a stack of amax_history tensors and returns a scale tensor. Args: amax_history: A 2D tensor containing a stack of amax histories. float8_dtype: The float8 dtype. - orig_dtype: The original dtype of the tensor. history_to_scale_fn_type: The type of function to use to convert the history to a scale. """ if history_to_scale_fn_type == "max": amax_stack = torch.max(amax_history, dim=1).values - return amax_to_scale(amax_stack, float8_dtype, orig_dtype) + return amax_to_scale(amax_stack, float8_dtype) raise NotImplementedError( f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" ) @@ -115,7 +104,11 @@ def tensor_to_amax( # happen elsewhere. if reduce_amax and dist.is_initialized(): pg = device_mesh.get_group() if device_mesh is not None else None - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg) + # dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg) + group = list(range(dist.get_world_size())) if pg is None else pg + amax = all_reduce(amax, "MAX", group) + if isinstance(amax, AsyncCollectiveTensor): + amax = amax.wait() return amax @@ -136,7 +129,7 @@ def tensor_to_scale( scaling_granularity, axiswise_dim, ) - return amax_to_scale(amax, float8_dtype, x.dtype) + return amax_to_scale(amax, float8_dtype) def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): diff --git a/torchao/ops.py b/torchao/ops.py index 9713f68eb2..2774deb08a 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -3,13 +3,22 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - lib = torch.library.Library("torchao", "FRAGMENT") -lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") -lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") -lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") -lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") -lib.define("marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor") +lib.define( + "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" +) +lib.define( + "unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor" +) +lib.define( + "dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor" +) +lib.define( + "marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor" +) +lib.define( + "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" +) def register_custom_op(name): @@ -18,6 +27,7 @@ def decorator(func): return torch.library.register_fake(f"{name}")(func) else: return torch.library.impl_abstract(f"{name}")(func) + return decorator @@ -43,7 +53,9 @@ def quant_llm_linear( Returns output of linear layer """ - return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK) + return torch.ops.torchao.quant_llm_linear.default( + EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK + ) @register_custom_op("torchao::quant_llm_linear") @@ -55,12 +67,29 @@ def _( _scales: Tensor, splitK: int = 1, ) -> Tensor: - torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") - torch._check(_in_feats.dtype in (torch.float16, torch.bfloat16), lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}") - torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") - torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") - torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype in (torch.float16, torch.bfloat16), lambda: f"scale must be FP16 or BF16, got {_scales.dtype}") + torch._check( + _in_feats.dim() == 2, + lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D", + ) + torch._check( + _in_feats.dtype in (torch.float16, torch.bfloat16), + lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}", + ) + torch._check( + _weights.dim() == 2, + lambda: f"weight should be a 2d tensor, got {_weights.dim()}D", + ) + torch._check( + _weights.dtype is torch.uint8, + lambda: f"weight must be UINT8, got {_weights.dtype}", + ) + torch._check( + _scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D" + ) + torch._check( + _scales.dtype in (torch.float16, torch.bfloat16), + lambda: f"scale must be FP16 or BF16, got {_scales.dtype}", + ) BS, IC = _in_feats.shape OC, _ = _weights.shape @@ -71,7 +100,6 @@ def _( return _in_feats.new_empty((BS, OC)) - def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. @@ -115,7 +143,10 @@ def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor: return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) -def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: + +def dequantize_tensor_core_tiled_layout( + packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int +) -> Tensor: """ Dequantizes by: - Unpacking weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` @@ -143,7 +174,9 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens @register_custom_op("torchao::dequantize_tensor_core_tiled_layout") -def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: +def _( + packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int +) -> Tensor: # packed_w preconditions torch._check( packed_w.dim() == 4, @@ -166,12 +199,28 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles K = packed_w.size(1) * inner_k_tiles * 16 # scales_and_zeros preconditions - torch._check(scales_and_zeros.dtype is torch.bfloat16, lambda: "scales_and_zeros must be bfloat16") - torch._check(scales_and_zeros.dim() == 3, lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}") - torch._check(group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, lambda: "qGroupSize must be 32, 64, 128, or 256") - torch._check(scales_and_zeros.size(0) == K // group_size, lambda: "scales_and_zeros must have K // qGroupSize at dim 0") - torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") - torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") + torch._check( + scales_and_zeros.dtype is torch.bfloat16, + lambda: "scales_and_zeros must be bfloat16", + ) + torch._check( + scales_and_zeros.dim() == 3, + lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}", + ) + torch._check( + group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, + lambda: "qGroupSize must be 32, 64, 128, or 256", + ) + torch._check( + scales_and_zeros.size(0) == K // group_size, + lambda: "scales_and_zeros must have K // qGroupSize at dim 0", + ) + torch._check( + scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1" + ) + torch._check( + scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2" + ) return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) @@ -224,27 +273,55 @@ def _( MAX_PARALLELISM = 64 # Verify num_bits - torch._check(bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}") + torch._check( + bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}" + ) pack_factor = 32 // bits # Verify M - torch._check(size_m == x.size(0), lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}") + torch._check( + size_m == x.size(0), + lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}", + ) # Verify K - torch._check(size_k == x.size(1), lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}") - torch._check(size_k % TILE_SIZE == 0, lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}") - torch._check((size_k // TILE_SIZE // 2) == weight_marlin.size(0), lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}") + torch._check( + size_k == x.size(1), + lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}", + ) + torch._check( + size_k % TILE_SIZE == 0, + lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}", + ) + torch._check( + (size_k // TILE_SIZE // 2) == weight_marlin.size(0), + lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}", + ) # Verify N - torch._check(s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}") - torch._check(weight_marlin.size(1) % TILE_SIZE == 0, lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}") + torch._check( + s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}" + ) + torch._check( + weight_marlin.size(1) % TILE_SIZE == 0, + lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}", + ) actual_size_n = (weight_marlin.size(1) // TILE_SIZE) * pack_factor - torch._check(size_n == actual_size_n, lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}") + torch._check( + size_n == actual_size_n, + lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}", + ) # Verify meta - torch._check(meta.size(0) == size_k // 8 // 2 // 2, lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}") - torch._check(meta.size(1) == size_n * 2, lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}") + torch._check( + meta.size(0) == size_k // 8 // 2 // 2, + lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}", + ) + torch._check( + meta.size(1) == size_n * 2, + lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}", + ) # Verify A device and strides torch._check(x.is_cuda, lambda: "x is not on GPU") @@ -252,7 +329,9 @@ def _( # Verify B device and strides torch._check(weight_marlin.is_cuda, lambda: "weight_marlin is not on GPU") - torch._check(weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous") + torch._check( + weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous" + ) # Verify meta device and strides torch._check(meta.is_cuda, lambda: "meta is not on GPU") @@ -265,15 +344,27 @@ def _( # Verify groupsize groupsize = -1 if s.size(0) > 1: - torch._check(size_k % s.size(0) == 0, lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}") + torch._check( + size_k % s.size(0) == 0, + lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}", + ) groupsize = size_k // s.size(0) groupsize //= 2 # Because of 24 - torch._check(groupsize == -1 or groupsize == 64, lambda: f"Unexpected groupsize = {groupsize}") + torch._check( + groupsize == -1 or groupsize == 64, + lambda: f"Unexpected groupsize = {groupsize}", + ) # Verify workspace size - torch._check(size_n % MIN_THREAD_N == 0, lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}") + torch._check( + size_n % MIN_THREAD_N == 0, + lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}", + ) min_workspace_size = (size_n // MIN_THREAD_N) * MAX_PARALLELISM - torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}") + torch._check( + workspace.numel() >= min_workspace_size, + lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}", + ) return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 0a26ab98d3..d0f3ebc0d6 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -9,7 +9,7 @@ ) from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayout +from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.dtypes import( to_affine_quantized_intx, TensorCoreTiledLayout, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 034d73639e..89f615e9ea 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayout +from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.granularity import Granularity from torchao.quantization.quant_primitives import ( diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 8bf1d34260..1bdbcd96e1 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f - Times are in `ms`, see `benchmarks/benchmark_hqq.py`. - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). -- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. +- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. GPU details: diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index 07d5dea205..eb12b2b45e 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -2,13 +2,12 @@ from torchao.prototype.hqq.core import HQQQuantizer from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, - ZeroPointDomain, - PlainAQTTensorImpl, - PlainLayout, - TensorCoreTiledAQTTensorImpl, - TensorCoreTiledLayout, - MappingType, ) +from torchao.quantization import ( + ZeroPointDomain, + MappingType, +) +from torchao.dtypes import TensorCoreTiledLayout, PlainLayout #Parameters device, compute_dtype = "cuda:0", torch.bfloat16 diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 8abdad039a..743c6128a7 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -12,7 +12,8 @@ from hqq.core.utils import * import torch.nn.functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 +from torchao.dtypes.utils import is_device class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta): W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits ) - self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - W_q_torch, self.inner_k_tiles - ) + if is_device(W_q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + W_q_torch, self.inner_k_tiles + ) + else: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + W_q_torch, self.inner_k_tiles + ) self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch) del W_q_torch, scales_torch, zeros_torch @@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants( .contiguous() ) if TORCH_VERSION_AT_LEAST_2_5: - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val @@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros): def matmul(self, x): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x, self.weight_int4pack, self.groupsize, self.scales_and_zeros - ) + if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) + else: + c = torch.ops.aten._weight_int4pack_mm( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) new_shape = origin_x_size[:-1] + (self.out_features,) c = c.reshape(new_shape) return c diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index bd66262609..6358574e45 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -80,7 +80,7 @@ All of our low-bit optimizers mentioned above also support `bf16_stochastic_roun ## Optimizer CPU offload -This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload. +This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA and XPU is supported. For multi-GPU training, you can use FSDP's built-in CPU offload. ```python import torch @@ -97,7 +97,7 @@ optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradi This will reduce GPU memory usage by optimizer state size, and additionally gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer. -For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction CPU->CUDA and CUDA->CPU, in case they are out of sync.) +For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize GPU and CPU params in either direction CPU->GPU and GPU->CPU, in case they are out of sync.) ```python ckpt = torch.load("checkpoint.pth") diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 1c3718972b..9cad9777bf 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -55,23 +55,29 @@ def __setstate__(self, state): def _subclass_zeros(p: Tensor, signed: bool, block_size: int): raise NotImplementedError - # follow bitsandbytes, only quantize tensors >= 4096 values - # also wrap subclass in DTensor when needed def _new_buffer(self, p: Tensor, signed: bool): - if p.numel() >= 4096 and p.numel() % self.block_size == 0: - if isinstance(p, DTensor): - out = DTensor.from_local( - local_tensor=self._subclass_zeros( - p.to_local(), signed, self.block_size - ), - device_mesh=p.device_mesh, - placements=p.placements, - run_check=False, - ) - else: - out = self._subclass_zeros(p, signed, self.block_size) + local_p = p.to_local() if isinstance(p, DTensor) else p + + # follow bitsandbytes, only quantize tensors >= 4096 values + if local_p.numel() >= 4096 and local_p.numel() % self.block_size == 0: + out = self._subclass_zeros(local_p, signed, self.block_size) else: - out = torch.zeros_like(p) + out = torch.zeros_like(local_p) + + # wrap subclass in DTensor as needed + # NOTE: local tensor may have different shapes across ranks. + # this happens when the 1st dim is not divisible by WORLD_SIZE. + # thus, we must supply shape (and stride) to DTensor.from_local() + if isinstance(p, DTensor): + out = DTensor.from_local( + local_tensor=out, + device_mesh=p.device_mesh, + placements=p.placements, + run_check=False, + shape=p.shape, + stride=p.stride(), + ) + return out @torch.no_grad() @@ -111,8 +117,12 @@ def step(self, closure=None): "optim.param_groups[0]['lr'].fill_(new_lr)" ) + # without calling p.detach(), torch.compile() will have issues with FSDP2 in some cases + # https://github.com/pytorch/ao/issues/652#issuecomment-2285040894 + # thus, by calling p.detach(), DTensor won't have .grad anymore, which is ok since we + # are passing grad separately anyway. torch.compile(single_param_adam, fullgraph=True, dynamic=False)( - p, + p.detach(), grad, state["step"], state["exp_avg"], diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py index c69932aa4c..90008f67fe 100644 --- a/torchao/prototype/low_bit_optim/cpu_offload.py +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -3,7 +3,7 @@ import torch from torch.optim.optimizer import Optimizer, ParamsT -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices class CPUOffloadOptimizer: @@ -38,51 +38,56 @@ def __init__( if not isinstance(param_groups[0], dict): param_groups = [{"params": param_groups}] - self.param_cuda2cpu_map = dict() + self.param_d2h_map = dict() self.optim_dict = dict() - self.stream = torch.cuda.Stream() + self.device = get_available_devices()[-1] + assert self.device in [ + "cuda", + "xpu", + ], "CPU Offload currently only supports CUDA & XPU" + self.stream = getattr(torch, self.device).Stream() # the queue maintains the order which param we should do optim step on first. self.queue = dict() - def backward_hook(p_cuda): - if p_cuda.grad is not None: - p_cpu = self.param_cuda2cpu_map[p_cuda] + def backward_hook(p_device): + if p_device.grad is not None: + p_host = self.param_d2h_map[p_device] # make sure backward for this param finishes - self.stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.stream): - p_cpu.grad.copy_(p_cuda.grad, non_blocking=True) + self.stream.wait_stream(getattr(torch, self.device).current_stream()) + with getattr(torch, self.device).stream(self.stream): + p_host.grad.copy_(p_device.grad, non_blocking=True) # we rely on CPython implementation of dictionary, which preserves insertion order. # if a param is added again (e.g. due to gradient accumulation), it is moved to the # end of the queue by removing and inserting it again. - if p_cuda in self.queue: - del self.queue[p_cuda] - self.queue[p_cuda] = self.stream.record_event() + if p_device in self.queue: + del self.queue[p_device] + self.queue[p_device] = self.stream.record_event() - # deallocate CUDA gradients once D2H transfer finishes. + # deallocate DEVICE gradients once D2H transfer finishes. if offload_gradients: - p_cuda.grad.record_stream(self.stream) - p_cuda.grad = None + p_device.grad.record_stream(self.stream) + p_device.grad = None for param_group in param_groups: params = param_group.pop("params") - for p_cuda in params: - if not p_cuda.requires_grad: + for p_device in params: + if not p_device.requires_grad: continue # pre-allocate CPU params and grads - p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True) - p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True) + p_host = torch.empty_like(p_device, device="cpu", pin_memory=True) + p_host.grad = torch.empty_like(p_host, pin_memory=True) - p_cpu.copy_(p_cuda.detach(), non_blocking=True) - self.param_cuda2cpu_map[p_cuda] = p_cpu + p_host.copy_(p_device.detach(), non_blocking=True) + self.param_d2h_map[p_device] = p_host - p_cuda.register_post_accumulate_grad_hook(backward_hook) - self.optim_dict[p_cuda] = optimizer_class( - [{"params": p_cpu, **param_group}], **kwargs + p_device.register_post_accumulate_grad_hook(backward_hook) + self.optim_dict[p_device] = optimizer_class( + [{"params": p_host, **param_group}], **kwargs ) @torch.no_grad() @@ -91,16 +96,16 @@ def step(self, closure=None): if closure is not None: loss = closure() - for p_cuda, grad_d2h_event in self.queue.items(): + for p_device, grad_d2h_event in self.queue.items(): grad_d2h_event.synchronize() - self.optim_dict[p_cuda].step() + self.optim_dict[p_device].step() # submit more job to self.stream. it guarantees that we only start # moving param H2D once all backwards finish, since self.stream # will wait for current_stream when moving grad D2H. - p_cpu = self.param_cuda2cpu_map[p_cuda] - with torch.cuda.stream(self.stream): - p_cuda.copy_(p_cpu, non_blocking=True) + p_host = self.param_d2h_map[p_device] + with getattr(torch, self.device).stream(self.stream): + p_device.copy_(p_host, non_blocking=True) self.queue.clear() return loss @@ -108,9 +113,9 @@ def step(self, closure=None): def zero_grad(self, set_to_none=True): assert set_to_none - # only clear CUDA grad. CPU grad will always be overwritten by CUDA grad. - for p_cuda in self.param_cuda2cpu_map.keys(): - p_cuda.grad = None + # only clear DEVICE grad. CPU grad will always be overwritten by DEVICE grad. + for p_device in self.param_d2h_map.keys(): + p_device.grad = None @property def param_groups(self): diff --git a/torchao/profiler/__init__.py b/torchao/prototype/profiler/__init__.py similarity index 99% rename from torchao/profiler/__init__.py rename to torchao/prototype/profiler/__init__.py index e748438e87..976d4e3a05 100644 --- a/torchao/profiler/__init__.py +++ b/torchao/prototype/profiler/__init__.py @@ -1,4 +1,3 @@ - # Re-exports from .device_spec import CUDADeviceSpec, DeviceSpec from .performance_counter import ( @@ -20,4 +19,3 @@ "DeviceSpec", "total_model_params", ] - diff --git a/torchao/profiler/device_spec.py b/torchao/prototype/profiler/device_spec.py similarity index 100% rename from torchao/profiler/device_spec.py rename to torchao/prototype/profiler/device_spec.py diff --git a/torchao/profiler/performance_counter.py b/torchao/prototype/profiler/performance_counter.py similarity index 100% rename from torchao/profiler/performance_counter.py rename to torchao/prototype/profiler/performance_counter.py diff --git a/torchao/profiler/utils.py b/torchao/prototype/profiler/utils.py similarity index 100% rename from torchao/profiler/utils.py rename to torchao/prototype/profiler/utils.py diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py new file mode 100644 index 0000000000..bf6dbb2a46 --- /dev/null +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -0,0 +1,1424 @@ +import copy +import csv +import logging +import os +import re +from itertools import chain + +import torch +import torch.nn.functional as F +from torch.utils._python_dispatch import return_and_correct_aliasing +from torch.utils._pytree import tree_map + +import torchao +from torchao.dtypes import ( + AffineQuantizedTensor, + Float8Layout, + PlainLayout, + TensorCoreTiledLayout, +) +from torchao.float8.inference import Float8MMConfig +from torchao.kernel import safe_int_mm +from torchao.quantization.linear_activation_quantized_tensor import ( + LinearActivationQuantizedTensor, +) +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) +from torchao.quantization.utils import quantize_activation_per_token_absmax +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) + +from torchao.quantization.granularity import ( + PerRow, + PerTensor, +) +from torchao.quantization.subclass import ( # noqa + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + QuantizedLinearWeightBase, +) +from .subgraph_utils.extract_subgraphs import ( + debug_linears_for_float8, + prepare_target_folder, +) +from torchao.quantization.subclass import QuantizedLinearWeightBase +from torchao.quantization.autoquant import AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1 +from torchao.dtypes import AffineQuantizedTensor +from torchao.quantization import LinearActivationQuantizedTensor + +logging.basicConfig(level=logging.ERROR) # Set the root logger level to ERROR + + +target_folder = "/home/jerryzh/local/tmp/20241104_dynamo_test" + +__all__ = [ + "AutoQuantizableLinearWeight", + "autoquant_v2", + "DEFAULT_AUTOQUANT_CLASS_LIST", + "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", + "OTHER_AUTOQUANT_CLASS_LIST", + "_is_linear", +] + +def _is_linear(mod, *args): + # avoid circular dependencies + from torchao.quantization.qat.affine_fake_quantized_tensor import ( + AffineFakeQuantizedTensor, + ) + + # adding weight tensor subclass isinstance check to make sure the weight is only quantized once + # when it is shared by multiple linear modules + return ( + isinstance(mod, torch.nn.Linear) + and hasattr(mod, "weight") + and not isinstance(mod.weight, QuantizedLinearWeightBase) + and not isinstance(mod.weight, AutoQuantizableLinearWeightV1) + and not isinstance(mod.weight, AffineQuantizedTensor) + and not isinstance(mod.weight, LinearActivationQuantizedTensor) + and not isinstance(mod.weight, AffineFakeQuantizedTensor) + and not isinstance(mod, torch.nn.modules.linear.NonDynamicallyQuantizableLinear) + ) + + +# TODO: use SubgraphMatcher +def _graph_equals(g1, g2): + if len(g1.nodes) != len(g2.nodes): + return False + + for n1, n2 in zip(g1.nodes, g2.nodes): + if n1.op != n2.op: + return False + + if n1.op in ["call_function", "call_method"] and n1.target != n2.target: + return False + + if len(n1.args) != len(n2.args): + return False + return True + + +aten = torch.ops.aten + +AUTOQUANT_CACHE = {} + +# This is a flag to control whether we do some rewrite for graph +# to account for different batch sizes, it's a temporary solution for llama model +# we'll need to think about how to support this more generally +LLAMA = True + +def check_cache(gm, cls, shapes_and_dtype): + for gm_, cls_, shapes_and_dtype_ in AUTOQUANT_CACHE.keys(): + graph_equals = _graph_equals(gm_.graph, gm.graph) + if graph_equals and cls_ is cls and shapes_and_dtype_ == shapes_and_dtype: + return AUTOQUANT_CACHE[(gm_, cls_, shapes_and_dtype_)] + return None + + +def update_cache(gm, cls, shapes_and_dtype, res): + AUTOQUANT_CACHE[(gm, cls, shapes_and_dtype)] = res + + +# adjust each input's bsz to target_bsz +# enable grad +# a hacky solution but should work in the use cases we are testing now +# we went through the list of sizes and swap the dimension that matches extracted_bsz to target_bsz +def resize_input(t, extracted_bsz, target_bsz): + if len(t.shape) > 1: + new_shape = [] + for i in range(len(t.size())): + if t.size(i) == extracted_bsz: + new_shape.append(target_bsz) + else: + new_shape.append(t.size(i)) + t = torch.randn(*new_shape, dtype=t.dtype, device=t.device) + return t + + +# a hacky solution but should work in the use cases we are testing now +# we went through the list of sizes and swap the dimension that matches extracted_bsz to target_bsz +def maybe_adjust_model_bsz(m, extracted_bsz, target_bsz): + """ + Makes guesses on how to adjust the model graph to account for the + fact that we changed the batch size. Note: this is very brittle + """ + for n in m.graph.nodes: + if n.op == "call_method" and n.target == "view": + new_args = [] + for arg in n.args: + if arg == extracted_bsz: + new_args.append(target_bsz) + else: + new_args.append(arg) + n.args = tuple(new_args) + + m.recompile() + + +# TODO: Document the methods +class AutoQuantizableLinearWeight(torch.Tensor): + """ + A subclass of torch.Tensor that, when run, finds the best type of quantization for itself and swaps + its data with the quantized version. + + Args: + weight (torch.Tensor): The initial weight tensor. + qtensor_class_list (list): A list of tensor classes to be considered for quantization. + *args: Additional positional arguments. + mode (list, optional): A list containing mode settings for quantization. The first element is the mode type + (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None]. + **kwargs: Additional keyword arguments. + """ + + @staticmethod + def __new__( + cls, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + model=None, + fqn=None, + example_inputs=None, + fqn_to_submodule=None, + batch_size=None, + **kwargs, + ): + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = ( + kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype + ) + kwargs["requires_grad"] = False + shape = kwargs.pop("shape", weight.shape) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + model=None, + fqn=None, + example_inputs=None, + fqn_to_submodule=None, + batch_size=None, + **kwargs, + ): + self.weight = weight + self.qtensor_class_list = qtensor_class_list + self.logged_data = {} + self.mode = mode + self.model = model + self.fqn = fqn + self.example_inputs = example_inputs + self.fqn_to_submodule = fqn_to_submodule + self.batch_size = batch_size + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})" + ) + + @staticmethod + def log_shape(act_mat, w_autoquant, bias): + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + logged_dtype = act_mat.dtype + logged_shapes = ( + act_mat.shape, + w_autoquant.shape, + None if bias is None else bias.shape, + ) + shapes_and_dtype = logged_shapes + (logged_dtype,) + w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get( + shapes_and_dtype, 0 + ) + + def tune_autoquant2( + self, fqn, m, batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape + ): + act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype + + with torch.no_grad(): + try: + m_copy = copy.deepcopy(m) + for name, module in m_copy.named_modules(): + if isinstance(module, torch.nn.Linear): + linear_module = module + weight = q_cls.from_float(linear_module.weight) + linear_module.weight = torch.nn.Parameter(weight, requires_grad=False) + if batch_size is not None: + extracted_bsz = batch_size + target_bsz = act_shape[0] + inputs = tree_map( + lambda t: resize_input(t, extracted_bsz, target_bsz), inputs + ) + maybe_adjust_model_bsz(m_copy, extracted_bsz, target_bsz) + + m_copy = torch.compile(m_copy, mode="max-autotune-no-cudagraphs") + + if isinstance(inputs, (list, tuple)): + cur_time = do_autoquant_bench(m_copy, *inputs, warmup=25, rep=100) + else: + cur_time = do_autoquant_bench(m_copy, **inputs, warmup=25, rep=100) + print( + f">>time: {cur_time:0.3f}ms for {q_cls}, to_beat: {time_for_best_shape}" + ) + if cur_time < time_for_best_shape: + update_cache(m, q_cls, shapes_and_dtype, cur_time) + res = cur_time + return res + except Exception as e: + print(f"warning: failed to autoquant {q_cls.__name__} due to {e}") + return None + + @torch.no_grad() + def to_quantized(self, error_on_unseen, **kwargs): + if error_on_unseen and self.logged_data == {}: + raise RuntimeError( + "must run module normally to get shape, dtype info for autoquant" + ) + elif (self.logged_data == {}) and not error_on_unseen: + # default back to non-quantized weight if not seen + self = AQDefaultLinearWeight.from_float(self.weight) + return self + + # only want to print shape (at start) and final result (at end) + # once per shape+quantization subclass combination. + ran_new_benchmarks = False + print_shape_once = True + + def count_shapes(self, do_print=True): + differe_shape_count = 0 + for shapes_and_dtype, times_seen in self.logged_data.items(): + differe_shape_count += 1 + if do_print: + act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype + print(f"activation_shapes: {act_shape}, times_seen: {times_seen}") + if do_print: + print( + f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}" + ) + return differe_shape_count + + # check each class + best_time = torch.inf + best_cls = None + fqn = self.fqn + print(f"autoquant for {fqn}") + for q_cls in self.qtensor_class_list: + # for each logged shape+dtype, benchmark + cur_time = 0 + total_seen = 0 + shape_count = count_shapes(self, do_print=False) + # copied from https://github.com/pytorch/pytorch/blob/75eeefbfab3862abe887e1d85a0b1b18c227d9f3/torch/_dynamo/variables/builder.py#L963 + modified_fqn = "L__self___" + re.sub(r"[^a-zA-Z0-9]+", "_", fqn) + m, inputs = self.fqn_to_submodule[modified_fqn] + for shapes_and_dtype, times_seen in self.logged_data.items(): + if check_cache(m, q_cls, shapes_and_dtype) is None: + # only print shapes once + if print_shape_once is True: + print_shape_once = False + count_shapes(self, do_print=True) + + time_for_best_shape = check_cache(m, q_cls, shapes_and_dtype) + time_for_best_shape = ( + torch.inf + if time_for_best_shape is None + else time_for_best_shape + ) + self.tune_autoquant2( + fqn, m, self.batch_size, inputs, q_cls, shapes_and_dtype, time_for_best_shape + ) + ran_new_benchmarks = True + torch._dynamo.reset() + if check_cache(m, q_cls, shapes_and_dtype) is not None: + cur_time += check_cache(m, q_cls, shapes_and_dtype) * times_seen + total_seen += times_seen + + if total_seen != 0: + cur_time = cur_time / total_seen + + # print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done + if shape_count is not None and shape_count > 1 and ran_new_benchmarks: + print( + f">time (all shapes): {cur_time:0.4f}ms for {q_cls}, prev_best: {best_time:0.4f}ms" + ) + if best_time >= cur_time: + best_time = cur_time + best_cls = q_cls + # if no new benchmarking was done, don't print the final result, it will be the same as for another layer + if ran_new_benchmarks: + print(f"best_cls={best_cls}\n") + # TODO handle random cls args/kwargs? or should they be curried? + if best_cls is None: + best_cls = AQDefaultLinearWeight + + self = best_cls.from_float(self.weight) + return self + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + self.qtensor_class_list, + dtype=self.dtype, + mode=self.mode, + model=self.model, + fqn=self.fqn, + example_inputs=self.example_inputs, + fqn_to_submodule=self.fqn_to_submodule, + batch_size=self.batch_size, + ) + + def __tensor_flatten__(self): + return ["weight"], [ + self.qtensor_class_list, + self.mode, + self.model, + self.fqn, + self.example_inputs, + self.fqn_to_submodule, + self.batch_size, + self.dtype, + self.shape, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None + ): + weight = tensor_data_dict["weight"] + ( + qtensor_class_list, + mode, + model, + fqn, + example_inputs, + fqn_to_submodule, + batch_size, + dtype, + shape, + ) = tensor_attributes + return cls( + weight, + qtensor_class_list, + mode, + model=model, + fqn=fqn, + example_inputs=example_inputs, + fqn_to_submodule=fqn_to_submodule, + batch_size=batch_size, + shape=shape if outer_size is None else outer_size, + dtype=dtype, + strides=outer_stride, + ) + + @classmethod + def from_float(cls, weight, qtensor_class_list, **kwargs): + return cls(weight, qtensor_class_list, **kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + mat1, w_autoquant, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + cls.log_shape(mat1, w_autoquant, bias) + return func(mat1, w_autoquant.weight, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except Exception: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@torch.no_grad() +def do_autoquant_bench(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ + rep = kwargs.pop("rep", 100) + warmup = kwargs.pop("warmup", 25) + with torch.no_grad(): + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + op(*args, **kwargs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + op(*args, **kwargs) + if TORCH_VERSION_AT_LEAST_2_5: + from torch._inductor.runtime.benchmarking import benchmarker + + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + elif TORCH_VERSION_AT_LEAST_2_3: + from torch._inductor.runtime.runtime_utils import do_bench_gpu + + res = do_bench_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + else: + from torch._inductor.utils import do_bench + + res = do_bench( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + return res + + +def _is_interpolate_mode(mode): + if ( + isinstance(mode, list) + and mode[0] == "interpolate" + and len(mode) == 2 + and isinstance(mode[1], float) + ): + return True + return False + + +class AQMixin: + """ + Tests and benchmarks the autoquantization process for the given activation matrix, weight, and bias. + + Args: + act_mat (torch.Tensor): The activation matrix. + weight (torch.Tensor): The weight tensor. + bias (torch.Tensor or None): The bias tensor. + best_time (float): The best time to beat for the quantization process. + mode (list, optional): A list containing mode settings for quantization. The first element is the mode type + (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None]. + + Returns: + float: The benchmarked time for the autoquantization process. + """ + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + w_qtensor = cls.from_float(weight) + if _is_interpolate_mode(mode): + q_c_op = torch.compile( + cls._quantized_linear_op, mode="max-autotune-no-cudagraphs" + ) + else: + func = lambda a, b, c: F.relu(cls._quantized_linear_op(F.relu(a), b, c)) + q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100) + if res < best_time * 1.1: + res2 = do_autoquant_bench( + q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900 + ) + res = res2 * 0.9 + res * 0.1 + print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") + return res + + +class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor): + """ + AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight + """ + + @classmethod + def from_float(cls, weight): + # TODO test if this is valid + # in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + # if in_features <= 16: + # return weight + + # avoid circular dep + from torchao.dtypes import to_affine_quantized_intx + + # weight settings + mapping_type = MappingType.SYMMETRIC + + def get_weight_block_size(x): + return (1, x.shape[1]) + + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size) - 1): + block_size[i] = 1 + return block_size + + input_mapping_type = MappingType.SYMMETRIC + input_target_dtype = torch.int8 + input_eps = 1e-5 + input_quant_min = -127 + input_quant_max = 127 + _layout = PlainLayout() + input_quant_func = lambda x: to_affine_quantized_intx( + x, + input_mapping_type, + get_per_token_block_size(x), + input_target_dtype, + eps=input_eps, + quant_min=input_quant_min, + quant_max=input_quant_max, + scale_dtype=torch.float32 if x.dtype == torch.float16 else None, + ) + + block_size = get_weight_block_size(weight) + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=_layout, + ) + weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float( + weight, input_quant_func + ) + return weight + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + """ + Tests and benchmarks the autoquantization process with special handling for interpolate mode. + + Args: + act_mat (torch.Tensor): The activation matrix. + weight (torch.Tensor): The weight tensor. + bias (torch.Tensor or None): The bias tensor. + best_time (float): The best time to beat for the quantization process. + mode (list, optional): A list containing mode settings for quantization. The first element is the mode type + (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None]. + + Returns: + float: The benchmarked time for the autoquantization process. + """ + if not _is_interpolate_mode(mode): + return super()._autoquant_test(act_mat, weight, bias, best_time, mode) + + # SAM best is between .8 and 1, SDXL also performs best in this range + INTERPOLATION_CONSTANT = mode[1] + w_qtensor = cls.from_float(weight) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_int8, x_scales, w_vals_int8: safe_int_mm( + x_vals_int8, w_vals_int8 + ) + * x_scales + ) + q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") + with torch.no_grad(): + w_vals_int8 = ( + w_qtensor.original_weight_tensor.tensor_impl.int_data.contiguous().t() + ) + res_matmul = do_autoquant_bench( + q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_vals_int8 + ) + print( + f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms" + ) + + # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op + if res_matmul >= best_time: + return res_matmul + + # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT + to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * ( + best_time - res_matmul + ) + res = super()._autoquant_test(act_mat, weight, bias, to_beat) + max_int_const_win = (best_time - res_matmul) / (res - res_matmul) + res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul + print( + f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}" + ) + return res_f + + +class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight + """ + + @classmethod + def from_float(cls, weight): + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + block_size = (1, weight.shape[1]) + return super(AQInt8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + + +class AQInt8WeightOnlyQuantizedLinearWeight2( + AQInt8WeightOnlyQuantizedLinearWeight, AQMixin +): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + """ + Performs the quantized linear operations + + Args: + act_mat (torch.Tensor): The activation matrix. + w_qtensor (torch.Tensor): The quantized weight tensor. + bias (torch.Tensor or None): The bias tensor. + + Returns: + torch.Tensor: The result of the quantized operation. + """ + orig_dtype = act_mat.dtype + orig_shape = act_mat.shape + act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) + y = (act_mat * w_qtensor.tensor_impl.int_data.t().unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.tensor_impl.scale + if bias is not None: + y += bias + return y.to(orig_dtype) + + @classmethod + def _autoquant_test(cls, act_mat, *args): + # if act_mat has batchsize>2 don't use this kernel + if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 32: + return torch.inf + return super()._autoquant_test(act_mat, *args) + + +class AQInt8WeightOnlyQuantizedLinearWeight3( + AQInt8WeightOnlyQuantizedLinearWeight, AQMixin +): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + orig_shape = act_mat.shape + y = torch.mm( + act_mat.reshape(-1, orig_shape[-1]), + w_qtensor.tensor_impl.int_data.t() * w_qtensor.tensor_impl.scale, + ) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y + + +class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight + """ + + group_size: int = 32 + + @classmethod + def from_float(cls, weight): + group_size = cls.group_size + _layout = TensorCoreTiledLayout(inner_k_tiles=8) + + if weight.shape[-1] % group_size != 0: + return weight + use_hqq = True + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + _layout=_layout, + use_hqq=use_hqq, + ) + + +class AQInt4G64WeightOnlyQuantizedLinearWeight( + AQInt4G32WeightOnlyQuantizedLinearWeight +): + group_size: int = 64 + + +class AQInt4G128WeightOnlyQuantizedLinearWeight( + AQInt4G32WeightOnlyQuantizedLinearWeight +): + group_size: int = 128 + + +class AQInt4G256WeightOnlyQuantizedLinearWeight( + AQInt4G32WeightOnlyQuantizedLinearWeight +): + group_size: int = 256 + + +class AQDefaultLinearWeight(torch.Tensor, AQMixin): + """ + A class to be used in concert with AutoQuantizableLinearWeight to provide a + default/non-quantized option. Only implements the bare minimum needed to work with the + AutoQuantizableLinearWeight class using the same interfaces that would normally be + used by QTensor subclasses but for a default linear op instead. Result of from_float + is not a tensor subclass, but rather the float tensor. + """ + + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor, bias) + + @classmethod + def from_float(cls, weight): + return weight + + +class Float32Tensor(TorchAOBaseTensor): + """ Tensor subclass tensor for fp32 dtype + """ + def __init__(self, weight): + self.weight = weight.to(torch.float32) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float32 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + ) + + @classmethod + def from_float(cls, weight): + return cls(weight) + +@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + +@Float32Tensor.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@Float32Tensor.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@Float32Tensor.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +class BFloat16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.bfloat16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.bfloat16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class Float16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.float16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class AQFloat32LinearWeight(Float32Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQFloat32LinearWeight, cls).from_float(weight) + + +class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin): + """ + AutoQuantizable version for bfloat16 precision weight + + (also converts input activation and bias to bfloat16, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQBFloat16LinearWeight, cls).from_float(weight) + + +class AQFloat16LinearWeight(Float16Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQFloat16LinearWeight, cls).from_float(weight) + + +class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): + """ + AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn + """ + + target_dtype: torch.dtype = torch.float8_e4m3fn + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias) + + @classmethod + def from_float(cls, weight): + block_size = (1, weight.shape[1]) + return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx( + weight, block_size, target_dtype=cls.target_dtype, _layout=Float8Layout() + ) + + +class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight( + AQMixin, LinearActivationQuantizedTensor +): + """ + AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling + """ + + activation_granularity = PerRow() + + @classmethod + def from_float(cls, weight): + # avoid circular dep + from torchao.dtypes import to_affine_quantized_floatx + from torchao.quantization.quant_api import _input_activation_quant_func_fp8 + + # weight settings + def get_weight_block_size(x): + return (1, x.shape[1]) + + target_dtype = torch.float8_e4m3fn + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size) - 1): + block_size[i] = 1 + return block_size + + input_target_dtype = torch.float8_e4m3fn + _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) + input_quant_func = lambda x: _input_activation_quant_func_fp8( + x=x, + activation_granularity=cls.activation_granularity, + activation_dtype=input_target_dtype, + ) + block_size = get_weight_block_size(weight) + weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=target_dtype, + _layout=_layout, + scale_dtype=torch.float32, + ) + weight = super( + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls + ).from_float(weight, input_quant_func) + return weight + + +class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight( + AQMixin, LinearActivationQuantizedTensor +): + """ + AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per tensor scaling + """ + + activation_granularity = PerTensor() + + @classmethod + def from_float(cls, weight): + # avoid circular dep + from torchao.dtypes import to_affine_quantized_floatx + from torchao.quantization.quant_api import _input_activation_quant_func_fp8 + + # weight settings + def get_weight_block_size(x): + assert x.ndim == 2, "Only works for 2D tensors" + return x.shape + + target_dtype = torch.float8_e4m3fn + + input_target_dtype = torch.float8_e4m3fn + _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) + input_quant_func = lambda x: _input_activation_quant_func_fp8( + x=x, + activation_granularity=cls.activation_granularity, + activation_dtype=input_target_dtype, + ) + block_size = get_weight_block_size(weight) + weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=target_dtype, + _layout=_layout, + scale_dtype=torch.float32, + ) + weight = super( + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls + ).from_float(weight, input_quant_func) + return weight + + +# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison +DEFAULT_AUTOQUANT_CLASS_LIST = [ + AQDefaultLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight, + AQInt8WeightOnlyQuantizedLinearWeight2, + # AQInt8WeightOnlyQuantizedLinearWeight3, + # TODO this gets picked in places where it makes perf worse, why? + AQInt8DynamicallyQuantizedLinearWeight, +] + +DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ + AQDefaultLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQInt4G64WeightOnlyQuantizedLinearWeight, +] + +DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [ + AQFloat32LinearWeight, + AQBFloat16LinearWeight, + AQFloat16LinearWeight, +] + +OTHER_AUTOQUANT_CLASS_LIST = [ + AQFloat8WeightOnlyQuantizedLinearWeight, + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, +] + + +def _replace_with_custom_fn_if_matches_filter( + model, + replacement_fn, + filter_fn, + cur_fqn="", + device=None, +) -> None: + """ + Recursively replaces each child module in `model` with the result of `replacement_fn(child)` + if `filter_fn(child)` returns `True`. + Args: + model (torch.nn.Module): The model containing modules to be replaced. + replacement_fn (Callable[[torch.nn.Module], torch.nn.Module]): The function to replace matching modules. + filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. + cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". + device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. + Returns: + None + """ + if filter_fn(model, cur_fqn[:-1]): + if device is not None: + model.to(device=device) # move to device before quantization + model = replacement_fn(model, cur_fqn[:-1]) + return model + else: + for name, child in model.named_children(): + new_child = _replace_with_custom_fn_if_matches_filter( + child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device + ) + if new_child is not child: + setattr(model, name, new_child) + if device is not None: + model.to(device=device) # move parent module to device + return model + + +def dict_union(*args): + return dict(chain.from_iterable(d.items() for d in args)) + + +def _change_linears_to_autoquantizable( + model, example_input, fqn_to_submodule, batch_size, **kwargs +): + """ + Converts all linear weight tensors to the + AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed + by running the model and then calling _change_autoquantizable_to_quantized + """ + # from torchao.quantization.quant_api import _is_linear + + filter_fn = kwargs.pop("filter_fn", _is_linear) + _ = kwargs.pop( + "error_on_unseen", True + ) # same kwargs used for this and to_quantized + kwargs["qtensor_class_list"] = kwargs.get( + "qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST + ) + kwargs["mode"] = kwargs.get("mode", ["relu", None]) + kwargs["model"] = model + kwargs["example_inputs"] = example_input + kwargs["fqn_to_submodule"] = fqn_to_submodule + kwargs["batch_size"] = batch_size + from torchao.quantization.quant_api import _get_subclass_inserter + + _replace_with_custom_fn_if_matches_filter( + model, + lambda model, fqn: _get_subclass_inserter( + AutoQuantizableLinearWeight, **dict_union(kwargs, {"fqn": fqn}) + )(model), + filter_fn if filter_fn is not None else _is_linear, + ) + + +def _change_autoquantizable_to_quantized( + model, supress_autoquant_errors=True, **kwargs +): + """ + Converts AutoQuantizableLinearWeight tensor subclasses + to various quantized/non-quantized tensor subclasses depending + on benchmark results. Expectation is that these modules are + torch.compiled afterwards. + """ + hold_automatic_dynamic_shapes = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False + + if supress_autoquant_errors: + hold_supress_errors = torch._dynamo.config.suppress_errors + torch._dynamo.config.suppress_errors = True + import logging + + torch._logging.set_logs(inductor=logging.CRITICAL, dynamo=logging.CRITICAL) + filter_fn = kwargs.pop( + "filter_fn", + lambda mod, *args: hasattr(mod, "weight") + and isinstance(mod.weight, AutoQuantizableLinearWeight), + ) + error_on_unseen = kwargs.pop("error_on_unseen", True) + from torchao.quantization.quant_api import ( + _get_subclass_inserter, + _replace_with_custom_fn_if_matches_filter, + ) + + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter( + AutoQuantizableLinearWeight, + method="to_quantized", + error_on_unseen=error_on_unseen, + **kwargs, + ), + filter_fn, + ) + # undo dynamic shape change + torch._dynamo.config.automatic_dynamic_shapes = hold_automatic_dynamic_shapes + + # undo error supression + if supress_autoquant_errors: + torch._dynamo.config.suppress_errors = hold_supress_errors + torch._logging.set_logs() + torch._dynamo.reset() + + +# TODO: example_input seems weird to include in the API +# TODO: Document all the modes +# TODO: Mode being a list is weird, should be a string or some object +@torch.no_grad() +def autoquant_v2( + model, + example_input=None, + qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST, + filter_fn=None, + mode=["interpolate", 0.85], + manual=False, + set_inductor_config=True, + supress_autoquant_errors=True, + batch_size=None, + **aq_kwargs, +): + """ + Autoquantization is a process which identifies the fastest way to quantize each layer of a model over some set of potential + qtensor subclasses. + + Autoquantization happens in three steps: + + 1-Prepare Model: the model is searched for Linear layers whose weights are exchanged for AutoQuantizableLinearWeight. + 2-Shape Calibration: the user runs the model on one or more inputs, the details of the activation shape/dtype seen by + the AutoQuantizableLinearWeight are recorded so we know what shapes/dtypes to use in order to optimize the quantized op in step 3 + 3-Finalize Autoquantization: for each AutoQuantizableLinearWeight, benchmarks are run for each shape/dtype on each member of the qtensor_class_list. + the fastest option is picked, resulting in a highly performant model + + This autoquant function performs step 1. Steps 2 and 3 can be completed by simply running the model. + If `example_input` is provided, this function also runs the model (which completes steps 2 and 3). + This autoquant api can handle models which have already had torch.compile applied to them, in which case, once the model is run and quantized, + the torch.compile process normally proceeds as well. + + To optimize over a combination of input shapes/dtypes, the user can set manual=True, run the model with all desired shapes/dtypes, then + call model.finalize_autoquant to finalize the quantization once the desired set of inputs have been logged. + + Args: + model (torch.nn.Module): The model to be autoquantized. + example_input (Any, optional): An example input for the model. If provided, the function performs a forward pass + on this input (which fully autoquantizes the model unless manual=True). Defaults to None. + qtensor_class_list (list, optional): A list of tensor classes to be used for quantization. Defaults to DEFAULT_AUTOQUANT_CLASS_LIST. + filter_fn (callable, optional): A filter function to apply to the model parameters. Defaults to None. + mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"), + and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85]. + manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for + the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged. + set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) + supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True) + **aq_kwargs: Additional keyword arguments for the autoquantization process. + + Returns: + torch.nn.Module: The autoquantized and wrapped model. If `example_input` is provided, the function performs a forward pass + on the input and returns the result of the forward pass. + + Example usage: + torchao.autoquant(torch.compile(model)) + model(*example_input) + + # multiple input shapes + torchao.autoquant(model, manual=True) + model(*example_input1) + model(*example_input2) + model.finalize_autoquant() + """ + if set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + if qtensor_class_list is OTHER_AUTOQUANT_CLASS_LIST: + assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + 8, + 9, + ), "float8 requires CUDA arch >= 8.9" + + assert example_input is not None + + prepare_target_folder(target_folder) + torch._dynamo.reset() + # TODO: explore using node.meta to retrieve the subgraph and fqn information + # disable nn module inlining, our subgraph extraction logic depends on this + torch._dynamo.config.inline_inbuilt_nn_modules = False + torch._inductor.config.pre_grad_custom_pass = lambda g: debug_linears_for_float8( + g, target_folder + ) + model = torch.compile(model) + if isinstance(example_input, torch.Tensor): + example_input = [example_input] + if isinstance(example_input, (list, tuple)): + model(*example_input) + elif isinstance(example_input, dict): + model(**example_input) + else: + raise Exception("Unexpected example_input:", example_input) + + torch._inductor.config.pre_grad_custom_pass = None + + # verify debug logs and summary got saved + assert os.path.isfile( + os.path.join(target_folder, "debug_logs_0.txt") + ), "No debug log saved, autoquant_v2 can't work for this model right now" + assert os.path.isfile( + os.path.join(target_folder, "summary_0.csv") + ), "No debug log saved, autoquant_v2 can't work for this model right now" + + # first, find how many torch.compile'd regions we have + extraction_idxs = [] + for f in os.listdir(target_folder): + match = re.match(r"summary_([0-9]+).csv", f) + if match: + extraction_idxs.append(int(match.group(1))) + extraction_idxs.sort() + + fqn_to_submodule = {} + + for extraction_idx in extraction_idxs: + summary_filename = os.path.join(target_folder, f"summary_{extraction_idx}.csv") + summary_rows = [] + with open(summary_filename, "r") as f: + reader = csv.reader(f) + for row in reader: + summary_rows.append(row) + + # [1:] to skip header row + for row_idx, row in enumerate(summary_rows[1:]): + subgraph_idx = row[2] + fqn = row[-1] + subgraph_fname = f"subgraph_with_inputs_{extraction_idx}_{subgraph_idx}.pt" + print(f"loading {subgraph_fname} fqn {fqn}") + subgraph_fname = os.path.join(target_folder, subgraph_fname) + m, inputs = torch.load(subgraph_fname, weights_only=False) + + # for now, force cast to bf16 + # TODO(future): configure this + m = m.to(torch.bfloat16) + inputs = tree_map(lambda x: x.to(torch.bfloat16), inputs) + + m = m.to(torch.bfloat16) + inputs = tree_map(lambda x: x.to(torch.bfloat16), inputs) + + fqn_to_submodule[fqn] = m, inputs + + model = model._orig_mod + + # perform initial swap from linear weights + # to AutoQuantizableLinearWeight + _change_linears_to_autoquantizable( + model, + example_input, + fqn_to_submodule, + batch_size, + filter_fn=filter_fn, + qtensor_class_list=qtensor_class_list, + mode=mode, + **aq_kwargs, + ) + + # access actual model of torch.compile wrapper if needed + is_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + if is_compiled: + real_model = model._orig_mod + else: + real_model = model + + if manual: + # we don't want model.forward to trigger + # torch.compilation + if is_compiled: + real_model.old_forward = model.forward + model.forward = real_model.forward + + # we want to automatically do autoquant after a single model run + # and have it occur before torch.compilation if applicable + else: + # the hook we will use to intercept the model forward and perform + # autoquantization + def autoquant_prehook(module, args, kwargs): + real_model.forward(*args, **kwargs) + module.finalize_autoquant() + return args, kwargs + + # the autoquant_prehook intercepts the forward call, performs logging then + # does autoquantization. if model is a torch.compile wrapper, it then + # does the tracing/compile since the prehook is naturally followed by the normal. + # model run. + handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True) + + # note the torch.compile wrapper (eval_frame) moves the assignment of any assigned + # attributes to the inner model that didn't exist before, so we have to call delattr on the inner model + def finalize_autoquant(): + _change_autoquantizable_to_quantized( + real_model, + supress_autoquant_errors, + **aq_kwargs, + ) + if hasattr(real_model, "old_forward"): + model.forward = real_model.old_forward + delattr(real_model, "old_forward") + if hasattr(real_model, "finalize_autoquant"): + delattr(real_model, "finalize_autoquant") + if not manual: + handle.remove() + + real_model.finalize_autoquant = finalize_autoquant + + # if example input was provided, check it and run it + if isinstance(example_input, torch.Tensor): + example_input = [example_input] + if isinstance(example_input, (tuple, list)): + model(*example_input) + elif isinstance(example_input, dict): + model(**example_input) + + return model diff --git a/torchao/prototype/quantization/subgraph_utils/extract_subgraphs.py b/torchao/prototype/quantization/subgraph_utils/extract_subgraphs.py new file mode 100644 index 0000000000..4f22ea84c2 --- /dev/null +++ b/torchao/prototype/quantization/subgraph_utils/extract_subgraphs.py @@ -0,0 +1,680 @@ +import csv +import os +import traceback +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch.utils._pytree import tree_map + +graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") + +# TODO(future): might be nice to have input shapes here, but needs a refactor since +# they are easiest to get from the subgraph extraction function, but currently summary +# is generated from the subgraph debug function. +summary_headers = [ + "extraction_idx", + "orig_node_name", + "subgraph_idx", + "lin1_shape", + "lin2_shape", + "subgraph_summary", +] + + +# A model can have multiple regions torch.compile'd separately. Because we currently +# depend on torch._inductor.config.pre_grad_custom_pass, that means we will run +# the extraction logic once per torch.compile'd region. The variable below tracks +# how many times we have called the top level subgraph extractor, so we can +# save results from each run without overwriting data. +DEBUG_LINEARS_CALL_COUNTER = 0 + + +def maybe_short_name(torch_fn): + """ + Tries to format things like + + '' + + as + + 'torch.cat' + """ + if hasattr(torch_fn, "__name__"): + # torch.cat -> cat + if hasattr(torch, torch_fn.__name__): + if getattr(torch, torch_fn.__name__) == torch_fn: + return torch_fn.__name__ + + # F.layer_norm -> layer_norm + if hasattr(F, torch_fn.__name__): + if getattr(F, torch_fn.__name__) == torch_fn: + return torch_fn.__name__ + + # builtin function mul + # note: there is definitely a more generic way to do this + if torch_fn.__name__ == "mul": + return "mul" + if torch_fn.__name__ == "add": + return "add" + + # activation modules + # note: there is definitely a more generic way to do this + if "torch.nn.modules.activation" in str(torch_fn): + return torch_fn.__name__ + + return torch_fn + + +def get_meta_val(n: torch.fx.Node): + # from https://github.com/pytorch/pytorch/blob/8d708090c0eb306facfd8f85d58c578a8cbbe689/torch/fx/graph.py#L644-L647 + meta_val = n.meta.get( + "val", n.meta.get("tensor_meta", n.meta.get("example_value", None)) + ) + return meta_val + + +def get_stack_summary(n: torch.fx.Node): + # from https://github.com/pytorch/pytorch/blob/8d708090c0eb306facfd8f85d58c578a8cbbe689/torch/fx/graph.py#L609 + if n.stack_trace: + parsed_stack_trace = torch.fx.graph._parse_stack_trace(n.stack_trace) + summary = parsed_stack_trace.get_summary_str() + return summary + return None + + +def is_first_node_of_dual_linear(gm: torch.fx.GraphModule, n: torch.fx.Node): + first_user = list(n.users.items())[0][0] + if first_user.op == "call_module": + first_user_mod = getattr(gm, first_user.target) + if type(first_user_mod) is torch.nn.Linear: + return True + elif first_user.op == "call_function": + if first_user.target is torch._C._nn.linear: + return True + return False + + +def debug_single_linear( + gm: torch.fx.GraphModule, + linear_node: torch.fx.Node, + linear_mod: torch.nn.Module, + debug_logs_filename: str, + subgraph_idx: int, +): + def printme(s): + # write both to stdout and log file + # print(s) + with open(debug_logs_filename, "a") as f: + f.write(s + "\n") + + printme(f"\ndebugging linear {linear_node.target} {subgraph_idx}") + printme("\ndebugging details\n") + + prev_input_shape = None + prev_node_type = None + if linear_mod is not None: + cur_linear_size = linear_mod.in_features, linear_mod.out_features + else: + cur_linear_weight = linear_node.args[1] + weight_shape = get_meta_val(cur_linear_weight).shape + cur_linear_size = weight_shape[1], weight_shape[0] + cur_linear_2_size = None, None + next_node_types = [] + + # look at the preceding activation + for prev_n in linear_node.all_input_nodes: + if prev_n.op == "placeholder": + continue + # to get the shape of the input, we need to look at the previous node + for prev_prev_n in prev_n.all_input_nodes: + prev_prev_meta = get_meta_val(prev_prev_n) + if isinstance(prev_prev_meta, tuple): + prev_input_shape = ",".join(str(x.shape) for x in prev_prev_meta) + else: + prev_input_shape = prev_prev_meta.shape + printme(f"prev input shape: {prev_input_shape}") + printme(f"prev node: {prev_n.format_node()}") + printme(f"prev stack_summary: {get_stack_summary(prev_n)}") + if prev_n.op == "call_module": + mod = getattr(gm, prev_n.target) + prev_node_type = type(mod) + printme(f"prev mod: {mod}") + else: + prev_node_type = prev_n.target + + # print info about current linear + printme(f"cur_linear node: {linear_node.format_node()}") + printme(f"cur_linear mod: {linear_mod}") + printme(f"cur_linear stack_summary: {get_stack_summary(linear_node)}") + + # if there is a dual linear, print that too + linear_node_to_use = linear_node + dual_linear = False + if is_first_node_of_dual_linear(gm, linear_node): + dual_linear = True + linear_node_2 = list(linear_node.users.items())[0][0] + linear_mod_2 = None + if linear_mod is not None: + linear_mod_2 = getattr(gm, linear_node_2.target) + cur_linear_2_size = linear_mod_2.in_features, linear_mod_2.out_features + else: + cur_linear_2_weight = linear_node_2.args[1] + weight_shape = get_meta_val(cur_linear_2_weight).shape + cur_linear_2_size = weight_shape[1], weight_shape[0] + + printme(f"cur_linear 2 node: {linear_node_2.format_node()}") + printme(f"cur_linear 2 mod: {linear_mod_2}") + printme(f"cur_linear 2 stack_summary: {get_stack_summary(linear_node_2)}") + linear_node_to_use = linear_node_2 + + # look at the subsequent ops + # note: sometimes this is a view, so might need to look farther + printme(f"num users: {len(linear_node_to_use.users)}") + for next_n, _ in linear_node_to_use.users.items(): + for next_n_input in next_n.all_input_nodes: + printme(f"next input shape: {get_meta_val(next_n_input).shape}") + printme(f"next node: {next_n.format_node()}") + printme(f"next stack_summary: {get_stack_summary(next_n)}") + if next_n.op == "call_module": + mod = getattr(gm, next_n.target) + printme(f"next mod: {mod}") + next_node_types.append(type(mod)) + else: + next_node_types.append(next_n.target) + + printme("\ndebugging summary\n") + if not dual_linear: + linear_shape_str = f"{cur_linear_size}" + linear_str = "Linear" + else: + linear_shape_str = f"{cur_linear_size} {cur_linear_2_size}" + linear_str = "Linear -> Linear" + printme(f"input_shape {prev_input_shape}, (K, N) {linear_shape_str}") + subgraph_summary = f"{maybe_short_name(prev_node_type)} -> {linear_str} -> {[maybe_short_name(t) for t in next_node_types]}" + printme(subgraph_summary) + printme("\n") + + summary_result = [ + DEBUG_LINEARS_CALL_COUNTER, # extraction_idx + linear_node.target, # orig_node_name + subgraph_idx, + cur_linear_size, + cur_linear_2_size, + subgraph_summary, + ] + return summary_result + + +def extract_linear_subgraph( + old_gm: torch.fx.GraphModule, + old_linear_node: torch.fx.Node, + old_linear_mod: torch.nn.Module, + subgraph_save_filename: str, +) -> None: + """ + Input: a GraphModule with a `linear_node` calling `linear_mod`. + + This function does the following: + * find the subgraph prev_op -> linear_node -> [*next_ops] + * create a new GraphModule containing this subgraph + * save it to disk for further debugging + """ + + # to start, create a new module which just calls the linear + if old_linear_mod is not None: + new_m = torch.nn.Sequential(old_linear_mod) + else: + weight_val = get_meta_val(old_linear_node.args[1]) + + # handle dual linear for inlined here + # TODO(future): merge with code below for dual linear for non-inlined + + old_shape = weight_val.shape + new_shape = old_shape[1], old_shape[0] + + if not is_first_node_of_dual_linear(old_gm, old_linear_node): + new_m = torch.nn.Sequential( + torch.nn.Linear(*new_shape, dtype=weight_val.dtype), + ).cuda() + + else: + old_2nd_linear_node = list(old_linear_node.users.items())[0][0] + weight2_val = get_meta_val(old_2nd_linear_node.args[1]) + # TODO handle no bias and kwargs + get_meta_val(old_2nd_linear_node.args[2]) + + old_shape2 = weight2_val.shape + new_shape2 = old_shape2[1], old_shape2[0] + new_m = torch.nn.Sequential( + torch.nn.Linear(*new_shape, dtype=weight_val.dtype), + torch.nn.Linear(*new_shape2, dtype=weight2_val.dtype), + ).cuda() + + new_gm = torch.fx.symbolic_trace(new_m) + new_g = new_gm.graph + new_linear_node = list(new_gm.graph.nodes)[1] + # print(f'new_gm: {new_gm}') + # print(f'new_linear_node: {new_linear_node}') + + # copy the linear metadata over + new_linear_node.meta = old_linear_node.meta + new_linear_node.args[0].meta = old_linear_node.args[0].meta + + # + # step 1: add the preceding activation node + # + # before: input -> linear + # after: input_args -> prev_op -> linear + + # add the node inputs as placeholders, and copy the non-node inputs as is + prev_old_arg_to_new_arg = {} + + def prev_node_map_arg(old_arg): + if isinstance(old_arg, torch.fx.Node): + if old_arg in prev_old_arg_to_new_arg: + return prev_old_arg_to_new_arg[old_arg] + + with new_g.inserting_before(new_linear_node): + new_arg = new_g.placeholder(old_arg.name) + # copy the metadata over + new_arg.meta = old_arg.meta + prev_old_arg_to_new_arg[old_arg] = new_arg + return new_arg + return old_arg + + old_prev_node = old_linear_node.all_input_nodes[0] + + if old_prev_node.op == "call_module": + prev_mod = getattr(old_gm, old_prev_node.target) + new_name = "prev_mod" + setattr(new_gm, new_name, prev_mod) + new_args = tree_map(prev_node_map_arg, old_prev_node.args) + new_kwargs = tree_map(prev_node_map_arg, old_prev_node.kwargs) + with new_g.inserting_before(new_linear_node): + new_prev_node = new_g.call_module(new_name, new_args, new_kwargs) + + elif old_prev_node.op == "call_function": + new_args = tree_map(prev_node_map_arg, old_prev_node.args) + new_kwargs = tree_map(prev_node_map_arg, old_prev_node.kwargs) + with new_g.inserting_before(new_linear_node): + new_prev_node = new_g.call_function( + old_prev_node.target, new_args, new_kwargs + ) + + elif old_prev_node.op == "call_method": + new_args = tree_map(prev_node_map_arg, old_prev_node.args) + new_kwargs = tree_map(prev_node_map_arg, old_prev_node.kwargs) + with new_g.inserting_before(new_linear_node): + new_prev_node = new_g.call_method( + old_prev_node.target, new_args, new_kwargs + ) + + elif old_prev_node.op == "placeholder": + new_prev_node = new_linear_node.args[0] + + else: + raise AssertionError(f"old_prev_node.op: {old_prev_node.op} is unsupported") + + # only erase placeholder if there is a previous op + if old_prev_node.op != "placeholder": + prev_placeholder = new_linear_node.args[0] + new_linear_node.args = (new_prev_node, *new_linear_node.args[1:]) + new_g.erase_node(prev_placeholder) + + new_prev_node.meta = old_prev_node.meta + new_gm.recompile() + + # + # step 2 (optional): if there is a dual linear and dynamo is not inlining, + # handle it in a single subgraph. Note: the inlined case is handled above. + # + # before: input_args -> prev_op -> linear + # after: input_args -> prev_op -> linear -> linear2 + # + # then, in step 3, next_ops will be after linear2 + + # we still need to refer to the first linear in some places, + # save it + old_old_linear_node = old_linear_node + first_new_linear_node = new_linear_node + + if is_first_node_of_dual_linear(old_gm, old_linear_node): + print("DUAL LINEAR") + if old_linear_mod is not None: + old_first_user = list(old_linear_node.users.items())[0][0] + old_first_user_mod = getattr(old_gm, old_first_user.target) + dual_linear_name = "1" + setattr(new_gm, dual_linear_name, old_first_user_mod) + new_args, new_kwargs = (new_linear_node,), {} + + with new_g.inserting_after(new_linear_node): + new_dual_linear_node = new_g.call_module( + dual_linear_name, new_args, new_kwargs + ) + new_dual_linear_node.meta = old_first_user.meta + + # make the following code treat the second linear as the root + new_linear_node = new_dual_linear_node + old_linear_node = old_first_user + else: + # make the following code treat the second linear as the root + old_first_user = list(old_linear_node.users.items())[0][0] + new_linear_node = list(new_linear_node.users.items())[0][0] + old_linear_node = old_first_user + + # + # step 3: add the subsequent nodes (can be multiple users) + # + # before: input_args -> prev_op -> linear + # after: input_args -> prev_op -> linear -> next_op_1 + # -> ... + # -> next_op_n + + # create last_node to ensure graph order matches the original if there + # are multiple users of linear's output + new_last_node = new_linear_node + new_output_nodes = [] + next_old_arg_to_new_arg = {} + + def next_node_map_arg(old_arg): + if isinstance(old_arg, torch.fx.Node): + if old_arg in next_old_arg_to_new_arg: + # handle the same arg being used multiple times + return next_old_arg_to_new_arg[old_arg] + if old_arg == old_linear_node: + next_old_arg_to_new_arg[old_arg] = new_linear_node + return new_linear_node + elif old_arg == old_old_linear_node: + next_old_arg_to_new_arg[old_arg] = first_new_linear_node + return first_new_linear_node + elif old_arg == old_prev_node: + next_old_arg_to_new_arg[old_arg] = new_prev_node + return new_prev_node + elif old_arg in prev_old_arg_to_new_arg: + return prev_old_arg_to_new_arg[old_arg] + else: + # this is something else, make it a graph input + with new_g.inserting_before(new_linear_node): + new_arg = new_g.placeholder(old_arg.name) + # copy the metadata over + new_arg.meta = old_arg.meta + next_old_arg_to_new_arg[old_arg] = new_arg + return new_arg + return old_arg + + next_node_is_output = False + for counter, (old_next_n, _) in enumerate(old_linear_node.users.items()): + if old_next_n.op == "output": + # nothing to do + next_node_is_output = True + break + + new_args = tree_map(next_node_map_arg, old_next_n.args) + new_kwargs = tree_map(next_node_map_arg, old_next_n.kwargs) + if old_next_n.op == "call_function": + with new_g.inserting_after(new_last_node): + new_next_n = new_g.call_function( + old_next_n.target, + new_args, + new_kwargs, + ) + new_output_nodes.append(new_next_n) + new_last_node = new_next_n + elif old_next_n.op == "call_method": + with new_g.inserting_after(new_last_node): + new_next_n = new_g.call_method( + old_next_n.target, + new_args, + new_kwargs, + ) + new_output_nodes.append(new_next_n) + new_last_node = new_next_n + elif old_next_n.op == "call_module": + prev_mod = getattr(old_gm, old_next_n.target) + new_name = f"next_mod_{counter}" + setattr(new_gm, new_name, prev_mod) + with new_g.inserting_after(new_last_node): + new_next_n = new_g.call_module(new_name, new_args, new_kwargs) + new_output_nodes.append(new_next_n) + new_last_node = new_next_n + else: + assert False, f"unsupported old_next_n.op, {old_next_n.op}" + new_next_n.meta = old_next_n.meta + new_gm.recompile() + # print(f'after adding next_node, {new_gm}') + + if not next_node_is_output: + # reroute graph outputs from `linear` to `new_output_nodes` + cur_output_node = list(new_g.nodes)[-1] + # print(f'cur_output_node: {cur_output_node.format_node()}') + + if len(new_output_nodes) == 1: + new_g.output(new_output_nodes[0]) + else: + new_g.output(tuple(new_output_nodes)) + # print(f'new_output_node: {cur_output_node.format_node()}') + new_g.erase_node(cur_output_node) + new_gm.recompile() + # print(f'after new output, {new_gm}') + + # ensure every node has metas + for n in new_g.nodes: + if n.op == "output": + continue + assert n.meta is not None and n.meta != {}, f"{n}.meta is {n.meta}!" + + test_inputs = [] + for node in new_g.nodes: + if node.op != "placeholder": + continue + meta = get_meta_val(node) + new_inputs = None + if isinstance(meta, tuple): + new_inputs = tuple( + torch.randn(*inner_meta.shape, dtype=inner_meta.dtype, device="cuda") + for inner_meta in meta + ) + else: + new_inputs = torch.randn(*meta.shape, dtype=meta.dtype, device="cuda") + test_inputs.append(new_inputs) + + # save subgraph and inputs + torch.save((new_gm, test_inputs), subgraph_save_filename) + + # test fwd + new_gm(*test_inputs) + + # Note: cannot verify runnable after save/load here, because loading + # from disk in this file seems to try to use device meta as we are inside + # of dynamo tracing. Need to load from a separate process to properly test. + + +def print_and_append_to_logs(logger, filename, s): + logger.debug(s) + with open(filename, "a") as f: + f.write(s + "\n") + + +def prepare_target_folder(target_folder: str): + # ensure target folder exists + if not os.path.isdir(target_folder): + os.makedirs(target_folder) + + # ensure target folder only has file extensions we could have written + for root, dirs, files in os.walk(target_folder): + for file in files: + if not ( + file.endswith(".txt") + or file.endswith(".pt") + or file.endswith(".swp") + or file.endswith(".csv") + or file.endswith(".json") + ): + raise AssertionError(f"unknown file in target_dir: {file}") + + # delete any existing files from previous run for this target_folder + for root, dirs, files in os.walk(target_folder): + for file in files: + os.unlink(os.path.join(root, file)) + + global DEBUG_LINEARS_CALL_COUNTER + DEBUG_LINEARS_CALL_COUNTER = 0 + + +def debug_linears_for_float8( + g: torch.fx.Graph, + target_folder: str, + linear_mod_filter_fn: Optional[Callable] = None, + linear_node_filter_fn: Optional[Callable] = None, +) -> None: + """ + This function: + 1. looks for subgraphs containing `torch.nn.Linear` modules, including the preceding + and subsequent ops + 2. for each found subgraph + - extracts metadata about the subgraph (ops, shapes, modeling code location) and saves it to disk + - extracts it into a new `torch.fx.Graphmodule` instance and saves it to disk, this + can then be loaded elsewhere to run microbenchmarks + + Inputs: + - `g` - the graph to debug, assumed to come from dynamo's pre-dispatch trace and have torch IR + - `target_folder` - the folder to save metadata and microbenchmarks to, note that all folder + content is overwritten every time the script is run. The contents of this folder will be: + target_folder/ + debug_logs_0.txt + skip_logs_0.txt + summary_0.csv + subgraph_with_inputs_0_0.pt + ... + subgraph_with_inputs_0_(n-1).pt + - `linear_mod_filter_fn`: optional filtering function on linear modules, if it returns false then subgraph + extraction is skipped for that linear + - `linear_node_filter_fn`: optional filtering function on linear nodes, if it returns false then subgraph + extraction is skipped for that linear + + Format of summary_0.csv (column: example_value): + extraction_idx: 0 + orig_node_name: fn_1 + subgraph_idx: 0 + lin1_shape: (2, 3) + lin2_shape: (3, 4) # only applies to dual linear subgraphs + subgraph_summary: ReLU -> Linear -> ["cat"] + + Format of subgraph_with_inputs_0_0.pt: Tuple[nn.Module, Tuple[torch.tensor]] + """ + global DEBUG_LINEARS_CALL_COUNTER + debug_logs_filename = os.path.join( + target_folder, f"debug_logs_{DEBUG_LINEARS_CALL_COUNTER}.txt" + ) + skip_logs_filename = os.path.join( + target_folder, f"skip_logs_{DEBUG_LINEARS_CALL_COUNTER}.txt" + ) + summary_filename = os.path.join( + target_folder, f"summary_{DEBUG_LINEARS_CALL_COUNTER}.csv" + ) + summary_results = [summary_headers] + + gm = g.owning_module + assert gm is not None, "unsupported, gm needs to be specified" + graph_tabular_log.debug("\nstarting linear debug\n") + + def log_skip_linear(n, mod, reason): + print_and_append_to_logs( + graph_tabular_log, skip_logs_filename, f"SKIP: {reason}" + ) + print_and_append_to_logs( + graph_tabular_log, skip_logs_filename, f"node: {n.format_node()}" + ) + print_and_append_to_logs( + graph_tabular_log, skip_logs_filename, f"node.meta: {get_meta_val(n)}" + ) + print_and_append_to_logs( + graph_tabular_log, skip_logs_filename, f"node.stack: {get_stack_summary(n)}" + ) + print_and_append_to_logs(graph_tabular_log, skip_logs_filename, f"mod: {mod}") + print_and_append_to_logs(graph_tabular_log, skip_logs_filename, "\n") + + subgraph_idx = 0 + module_fqn = None + for n in gm.graph.nodes: + if n.op == "call_module": + # check for linear + module_fqn = n.target + module_instance = getattr(gm, n.target) + if type(module_instance) is not torch.nn.Linear: + continue + + if linear_mod_filter_fn is not None and not linear_mod_filter_fn( + module_instance + ): + log_skip_linear(n, module_instance, "failed filter function") + continue + + # Note: we special case for linear -> linear, + # so if we are at the second linear then skip debug/extract to avoid duplication + is_second_linear_of_dual_linear = ( + n.args[0].op == "call_module" + and type(getattr(gm, n.args[0].target)) is torch.nn.Linear + and len(n.args[0].users) == 1 + ) + if is_second_linear_of_dual_linear: + log_skip_linear(n, module_instance, "second of dual linear") + continue + + elif n.op == "call_function": + if n.target != torch._C._nn.linear: + continue + + if linear_node_filter_fn is not None and not linear_node_filter_fn(n): + log_skip_linear(n, None, "failed filter function") + continue + + # Note: we special case for linear -> linear, + # so if we are at the second linear then skip debug/extract to avoid duplication + is_second_linear_of_dual_linear = ( + n.args[0].op == "call_function" + and n.args[0].target is torch._C._nn.linear + and len(n.args[0].users) == 1 + ) + if is_second_linear_of_dual_linear: + log_skip_linear(n, module_instance, "second of dual linear") + continue + + module_instance = None + + else: + continue + + # for now, the case where the linear's input is a graph input is not supported + if False: + is_input_placeholder = n.args[0].op == "placeholder" + if is_input_placeholder: + log_skip_linear(n, module_instance, "input is placeholder") + continue + + try: + summary_result = debug_single_linear( + gm, n, module_instance, debug_logs_filename, subgraph_idx + ) + subgraph_save_filename = os.path.join( + target_folder, + f"subgraph_with_inputs_{DEBUG_LINEARS_CALL_COUNTER}_{subgraph_idx}.pt", + ) + extract_linear_subgraph(gm, n, module_instance, subgraph_save_filename) + summary_results.append(summary_result + [module_fqn]) + except Exception as e: + print(e) + log_skip_linear( + n, + module_instance, + f"{subgraph_idx}, {str(e)}, {traceback.format_exc()}", + ) + subgraph_idx += 1 + + with open(summary_filename, "w") as f: + csv.writer(f).writerows(summary_results) + + graph_tabular_log.debug("\nending linear debug\n") + + DEBUG_LINEARS_CALL_COUNTER += 1 diff --git a/torchao/prototype/sparsity/superblock/utils.py b/torchao/prototype/sparsity/superblock/utils.py index 9ed38e50d3..e2b546db24 100644 --- a/torchao/prototype/sparsity/superblock/utils.py +++ b/torchao/prototype/sparsity/superblock/utils.py @@ -387,7 +387,7 @@ def accelerate_with_sparsity(model, args): if args.sparsity == "bsr": apply_sparsity(model) if args.quantization: - from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout + from torchao.dtypes import BlockSparseLayout quantize_( model, @@ -401,7 +401,7 @@ def accelerate_with_sparsity(model, args): sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only) elif args.sparsity == "semi_structured": if args.quantization: - from torchao.dtypes.affine_quantized_tensor import SemiSparseLayout + from torchao.dtypes import SemiSparseLayout quantize_( model, diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 470e71ae36..cb7c8d0481 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -17,15 +17,16 @@ import torch.nn.functional as F from torch.utils._pytree import tree_flatten, tree_unflatten +from torchao.dtypes.utils import is_device from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_6, find_multiple, ) from .quant_primitives import MappingType from .unified import Quantizer from .utils import ( - _lm_eval_available, _MultiInput, get_group_qparams_symmetric, get_groupwise_affine_qparams, @@ -39,10 +40,6 @@ aten = torch.ops.aten - -if not _lm_eval_available: - logging.info("lm_eval is not installed, GPTQ may not be usable") - add_ons = [] if TORCH_VERSION_AT_LEAST_2_3: @@ -542,12 +539,20 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x.to(precision), - weight_int4pack, - groupsize, - scales_and_zeros.to(scales_precision), - ).to(dtype=x.dtype) + if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) + else: + c = torch.ops.aten._weight_int4pack_mm( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c @@ -575,8 +580,6 @@ def __init__( super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) if self.padding: - from .utils import find_multiple - self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) @@ -596,19 +599,32 @@ def __init__( assert ( in_features % (inner_k_tiles * 16) == 0 ), "require in_features % (innerKTiles * 16) == 0" - self.register_buffer( - "weight", - torch.zeros( - ( - out_features // 8, - in_features // (inner_k_tiles * 16), - 32, - inner_k_tiles // 2, + if is_device(device.type, "cpu"): + self.register_buffer( + "weight", + torch.zeros( + ( + out_features, + in_features // 2, + ), + dtype=torch.uint8, + device=device, ), - dtype=torch.int32, - device=device, - ), - ) + ) + else: + self.register_buffer( + "weight", + torch.zeros( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + device=device, + ), + ) self.dtype = dtype self.register_buffer( "scales_and_zeros", @@ -743,8 +759,6 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - from .utils import find_multiple - logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) @@ -765,9 +779,19 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - w_int4x8.to(self.device), self.inner_k_tiles - ) + if ( + is_device(w_int4x8.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + weight_int4pack = ( + torch.ops.aten._convert_weight_to_int4pack_for_cpu( + w_int4x8.to(self.device), self.inner_k_tiles + ) + ) + else: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + w_int4x8.to(self.device), self.inner_k_tiles + ) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( self.device @@ -851,9 +875,14 @@ def make_names_and_values_dict_func(q, qparams): # how much we need to pad the weight delta_k = int((new_k - k) / 2) q = q.to(self.device) - final_q = torch.ops.aten._convert_weight_to_int4pack( - F.pad(q, pad=(0, delta_k)), inner_k_tiles - ) + if is_device(self.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) + else: + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) @@ -1118,8 +1147,6 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - from .utils import find_multiple - logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 90e898debd..3fc2cb5ef0 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -3,7 +3,7 @@ Typically quantization algorithms will have different schemes for how the activa ## Benchmarks Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B. - +### CUDA backend | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | @@ -20,9 +20,16 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | +### XPU backend +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (bfloat16) | NA | 42.20 | 557.71 | 13.89 | 13.21 | +| | int8dq | NA | 9.87 | 65.35 | 14.60 | 6.62 | +| | int8wo | NA | 66.24 | 438.61 | 14.60 | 6.62 -Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. + +### CUDA backend | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 | @@ -31,6 +38,15 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma | | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 | | | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 | | | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 | +### XPU backend +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 | +| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 | +| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52 + + +Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU or Intel-Max1100 using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. @@ -77,7 +93,7 @@ import torchao.quantization # After the first forward pass (when quantization was done) from torchao.quantization.autoquant import AUTOQUANT_CACHE with open("quantization-cache.pkl", "wb") as f: - pickle.dump(AUTOQUANT_CACHE) + pickle.dump(AUTOQUANT_CACHE, f) # On load from torchao.quantization.autoquant import AUTOQUANT_CACHE @@ -96,7 +112,7 @@ be applied individually. While there are a large variety of quantization apis, t from torchao.quantization import quantize_, int4_weight_only group_size = 32 -# you can enable [hqq](https://ithub.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through +# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through # use_hqq flag for `int4_weight_only` quantization use_hqq = False quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) @@ -333,7 +349,16 @@ We're trying to develop kernels for low bit quantization for intx quantization f You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. +### int8_dynamic_activation_intx_weight Quantization +We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. + +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ------------- | -------------------------------------------------| --------------| ------------------------| ---------------- | ----------------| +| Llama-3.1-8B | Base (bfloat16) | 1.24 | 18.62 | NA | 15.01 | +| | int8_dynamic_activation_intx_weight-4-256-false | 16.03 | 65.81 | NA | 4.11 | +| | int8_dynamic_activation_intx_weight-3-256-false | 18.94 | 59.97 | NA | 3.17 | +You try can out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. ### Automatic Inductor Configuration The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ff66e23cc9..344bdeea41 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -11,6 +11,7 @@ from .autoquant import ( DEFAULT_AUTOQUANT_CLASS_LIST, + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, DEFAULT_INT4_AUTOQUANT_CLASS_LIST, OTHER_AUTOQUANT_CLASS_LIST, autoquant, @@ -89,6 +90,7 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index ee6bf98852..b486683290 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -18,8 +18,15 @@ MappingType, ZeroPointDomain, ) -from torchao.quantization.utils import quantize_activation_per_token_absmax -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +from torchao.quantization.utils import ( + compute_error, + quantize_activation_per_token_absmax, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) from .granularity import ( PerRow, @@ -36,6 +43,7 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", ] @@ -69,7 +77,15 @@ class AutoQuantizableLinearWeight(torch.Tensor): """ @staticmethod - def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + def __new__( + cls, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + min_sqnr=None, + **kwargs, + ): kwargs["device"] = weight.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else weight.layout @@ -82,12 +98,19 @@ def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwarg return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( - self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs + self, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + min_sqnr=None, + **kwargs, ): self.weight = weight self.qtensor_class_list = qtensor_class_list self.logged_data = {} self.mode = mode + self.min_sqnr = min_sqnr def __repr__(self): return ( @@ -123,9 +146,25 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): else torch.randn(bias_shape, dtype=act_dtype, device=self.device) ) try: - res = q_cls._autoquant_test( - act_mat, self.weight, bias, best_time, self.mode + ref_output = AQDefaultLinearWeight._quantized_linear_op( + act_mat, self.weight, bias ) + q_output = q_cls._quantized_linear_op( + act_mat, q_cls.from_float(self.weight), bias + ) + if ( + self.min_sqnr is not None + and (sqnr := compute_error(q_output, ref_output)) + < self.min_sqnr + ): + print( + f"skipping q_cls: {q_cls} because the sqnr is too small, minimum expected sqnr: {self.min_sqnr}, got {sqnr}" + ) + res = torch.inf + else: + res = q_cls._autoquant_test( + act_mat, self.weight, bias, best_time, self.mode + ) except Exception as e: print( f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}" @@ -141,7 +180,7 @@ def to_quantized(self, error_on_unseen, **kwargs): ) elif (self.logged_data == {}) and not error_on_unseen: # default back to non-quantized weight if not seen - self = AQFloatLinearWeight.from_float(self.weight) + self = AQDefaultLinearWeight.from_float(self.weight) return self # only want to print shape (at start) and final result (at end) @@ -194,34 +233,49 @@ def count_shapes(self, do_print=True): print( f">time (all shapes): {cur_time:0.4f}ms for {q_cls}, prev_best: {best_time:0.4f}ms" ) - if best_time >= cur_time: + if cur_time != torch.inf and best_time >= cur_time: best_time = cur_time best_cls = q_cls # if no new benchmarking was done, don't print the final result, it will be the same as for another layer if ran_new_benchmarks: print(f"best_cls={best_cls}\n") + + if best_cls is None: + best_cls = AQDefaultLinearWeight + # TODO handle random cls args/kwargs? or should they be curried? self = best_cls.from_float(self.weight) return self def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode + fn(self.weight), + self.qtensor_class_list, + dtype=self.dtype, + mode=self.mode, + min_sqnr=self.min_sqnr, ) def __tensor_flatten__(self): - return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] + return ["weight"], [ + self.qtensor_class_list, + self.mode, + self.min_sqnr, + self.dtype, + self.shape, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None ): weight = tensor_data_dict["weight"] - qtensor_class_list, mode, dtype, shape = tensor_attributes[0] + qtensor_class_list, mode, min_sqnr, dtype, shape = tensor_attributes return cls( weight, qtensor_class_list, - mode, + mode=mode, + min_sqnr=min_sqnr, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride, @@ -608,7 +662,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight( group_size: int = 256 -class AQFloatLinearWeight(torch.Tensor, AQMixin): +class AQDefaultLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the @@ -629,6 +683,135 @@ def from_float(cls, weight): return weight +class Float32Tensor(TorchAOBaseTensor): + """Tensor subclass tensor for fp32 dtype""" + + def __init__(self, weight): + self.weight = weight.to(torch.float32) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float32 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + ) + + @classmethod + def from_float(cls, weight): + return cls(weight) + + +@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + + +@Float32Tensor.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@Float32Tensor.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@Float32Tensor.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +class BFloat16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.bfloat16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.bfloat16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class Float16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.float16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class AQFloat32LinearWeight(Float32Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + + @classmethod + def from_float(cls, weight): + return super(AQFloat32LinearWeight, cls).from_float(weight) + + +class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin): + """ + AutoQuantizable version for bfloat16 precision weight + + (also converts input activation and bias to bfloat16, and restores the original precision after + linear) + """ + + @classmethod + def from_float(cls, weight): + return super(AQBFloat16LinearWeight, cls).from_float(weight) + + +class AQFloat16LinearWeight(Float16Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ + + @classmethod + def from_float(cls, weight): + return super(AQFloat16LinearWeight, cls).from_float(weight) + + class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn @@ -742,7 +925,7 @@ def get_weight_block_size(x): # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, # AQInt8WeightOnlyQuantizedLinearWeight3, @@ -751,11 +934,17 @@ def get_weight_block_size(x): ] DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt4G64WeightOnlyQuantizedLinearWeight, ] +DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [ + AQFloat32LinearWeight, + AQBFloat16LinearWeight, + AQFloat16LinearWeight, +] + OTHER_AUTOQUANT_CLASS_LIST = [ AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, @@ -779,6 +968,7 @@ def _change_linears_to_autoquantizable(model, **kwargs): "qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST ) kwargs["mode"] = kwargs.get("mode", ["relu", None]) + kwargs["min_sqnr"] = kwargs.get("min_sqnr", None) from torchao.quantization.quant_api import ( _get_subclass_inserter, _replace_with_custom_fn_if_matches_filter, @@ -853,6 +1043,7 @@ def autoquant( manual=False, set_inductor_config=True, supress_autoquant_errors=True, + min_sqnr=None, **aq_kwargs, ): """ @@ -887,6 +1078,9 @@ def autoquant( the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged. set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True) + min_sqnr (float, optional): minimum acceptable signal to quantization noise ration (https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio) for output of quantized layer v.s. non-quantized layer, this is used to filter + out quantization methods that causes too large numerical impact, user can start with a resaonable + number like 40 and adjust depending on the result **aq_kwargs: Additional keyword arguments for the autoquantization process. Returns: @@ -919,6 +1113,7 @@ def autoquant( filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, + min_sqnr=min_sqnr, **aq_kwargs, ) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index cbe6296407..d5f2dca5b4 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -9,6 +9,7 @@ import torch import torch.nn.functional as F +from torchao.dtypes.utils import is_device from torchao.quantization.GPTQ import ( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, @@ -23,6 +24,7 @@ ) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 from .api import FakeQuantizeConfig from .fake_quantizer import FakeQuantizer @@ -363,6 +365,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): inner_k_tiles=inner_k_tiles, precision=child.weight.dtype, scales_precision=config.scale_precision, + device=next(child.parameters()).device, ) setattr(module, name, quantized_linear) @@ -373,10 +376,19 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): n_bit, config.group_size, ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), - child.inner_k_tiles, - ) + if ( + is_device(q_weight.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) + else: + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) quantized_linear.weight = q_weight quantized_linear.scales_and_zeros = scales_and_zeros else: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f96a1198a1..96ccb1889c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -33,12 +33,13 @@ PlainLayout, SemiSparseLayout, TensorCoreTiledLayout, + UintxLayout, to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) -from torchao.dtypes.uintx.uintx import UintxLayout +from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -51,6 +52,9 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + is_MI300, + is_sm_at_least_89, + is_sm_at_least_90, ) from .autoquant import AutoQuantizableLinearWeight, autoquant @@ -219,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ + if isinstance(model, Float8Linear): + with torch.device("meta"): + new_module = nn.Linear(model.in_features, model.out_features) + new_module.weight = model.weight + new_module.bias = model.bias + model = new_module if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization @@ -627,7 +637,8 @@ def int4_weight_only( "tensor_core_tiled" layout for speedup with tinygemm kernel Note: - This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference + This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm` + and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference of quantization algorithm compared to the more traditional type of integer quantization is the following: 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) 2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`) @@ -827,10 +838,11 @@ def _normalize_granularity( Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] ], ) -> Tuple[_fp8_granularities, _fp8_granularities]: + processed_granularity = None if granularity is None: - return (PerTensor(), PerTensor()) + processed_granularity = (PerTensor(), PerTensor()) elif isinstance(granularity, (PerTensor, PerRow)): - return (granularity, granularity) + processed_granularity = (granularity, granularity) elif isinstance(granularity, tuple) and len(granularity) == 2: if not ( isinstance(granularity[0], (PerTensor, PerRow)) @@ -843,11 +855,25 @@ def _normalize_granularity( raise ValueError( f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." ) - return granularity + processed_granularity = granularity else: raise ValueError( f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported." ) + # Validate granularity with supported Hardware + for _granularity in processed_granularity: + if isinstance(_granularity, PerTensor): + assert ( + is_sm_at_least_89() or is_MI300() + ), "PerTensor quantization only works for CUDA>=8.9 and MI300+" + elif isinstance(_granularity, PerRow): + assert ( + is_sm_at_least_90() or is_MI300() + ), "PerRow quantization only works for CUDA>=9.0 and MI300+" + else: + raise ValueError(f"Invalid granularity type: {_granularity}") + + return processed_granularity def _input_activation_quant_func_fp8( @@ -939,6 +965,9 @@ def float8_dynamic_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) @@ -993,6 +1022,9 @@ def float8_static_activation_float8_weight( weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 036109bc8d..9715d99e08 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -8,6 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import is_device from torchao.quantization.utils import ( dequantize_per_channel, dynamically_quantize_per_channel, @@ -15,7 +16,7 @@ quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from torchao.utils import find_multiple +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple __all__ = [ "Int8DynamicallyQuantizedLinearWeight", @@ -458,12 +459,20 @@ def _quantized_op(act_mat, w_qtensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # matmul - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) + if is_device(act_mat.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + y = aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) + else: + y = aten._weight_int4pack_mm( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) # remove out_feature padding orig_out_features = ( @@ -609,5 +618,10 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( input_float, 4, groupsize, dtype=input_float.dtype ) - int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) + if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + int_data = aten._convert_weight_to_int4pack_for_cpu( + input_int4x8, inner_k_tiles + ) + else: + int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) return int_data, scales_and_zeros, False, groupsize, inner_k_tiles diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 9083dd7621..e1cf98b549 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -9,6 +9,7 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode +from torchao.dtypes.utils import is_device from torchao.kernel import ( int_scaled_matmul, ) @@ -19,7 +20,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 __all__ = [ "compute_error", @@ -402,13 +403,8 @@ def groupwise_affine_quantize_tensor_from_qparams( zero_point_domain=ZeroPointDomain.FLOAT, ) if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: - int_data_device_type = int_data.device.type - # Move to cpu, until issue with MPS memory management of temporary tensors is resolved - if int_data_device_type == "mps": - int_data = int_data.cpu() - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - if int_data_device_type == "mps": - int_data = int_data.to(device="mps") + if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) return int_data @@ -422,8 +418,10 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert groupsize > 1 assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path - if TORCH_VERSION_AT_LEAST_2_5 and ( - w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1 + if ( + TORCH_VERSION_AT_LEAST_2_5 + and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) + and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/float8/dtensor_utils.py new file mode 100644 index 0000000000..84e4095263 --- /dev/null +++ b/torchao/testing/float8/dtensor_utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch.nn.functional as F + + +class FeedForward(nn.Module): + """MLP based model""" + + def __init__(self): + super(FeedForward, self).__init__() + self.w1 = nn.Linear(16, 32, bias=False) + self.w2 = nn.Linear(16, 32, bias=False) + self.out_proj = nn.Linear(32, 16, bias=False) + + def forward(self, x): + return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.ffn = FeedForward() + + def forward(self, x): + return self.ffn(x) diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index 7744ae4e92..af46b7fa71 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -1,16 +1,13 @@ -import contextlib -from typing import List, Optional +from typing import List import torch import torch.distributed as dist import torch.nn as nn -import torchao.float8.config as config from torchao.float8.config import ( Float8LinearConfig, ScalingType, ) - from torchao.float8.float8_linear_utils import ( linear_requires_sync, sync_float8_amax_and_scale_history, @@ -52,7 +49,11 @@ def check_parity_no_mp( ): precompute_float8_dynamic_scale_for_fsdp(model) - test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) def check_parity_bf16_mp( @@ -87,7 +88,11 @@ def check_parity_bf16_mp( ref_model.parameters(), ref_model_bf16.parameters() ): param_bf16.detach().copy_(param_fp32) - test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) def check_parity_fp8_comm_only( @@ -104,7 +109,6 @@ def check_parity_fp8_comm_only( for iter_idx in range(10): losses: List[torch.Tensor] = [] for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): - optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) losses.append(model(local_inp).sum()) losses[-1].backward() @@ -123,9 +127,15 @@ def check_parity_fp8_comm_only( and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC ): precompute_float8_dynamic_scale_for_fsdp(model) - + if compile: # When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now. - assert (torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" + assert ( + torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any() + ), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" else: - test_cls.assertEqual(losses[0], losses[1], f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 7f37c3f30a..7b8ac121b6 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -1,9 +1,9 @@ import torch + from torchao.float8.config import ( - ScalingGranularity, - ScalingType, - CastConfig, + CastConfig, Float8LinearConfig, + ScalingType, ) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 39edc50085..d88241783f 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -1,15 +1,19 @@ -import unittest -import functools import copy -import torch -import torchao -import os +import functools +import unittest +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils -from torchao.dtypes import AffineQuantizedTensor -from torchao.dtypes import to_affine_quantized_intx +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + +import torchao +from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx +from torchao.quantization import int8_weight_only, quantize_ from torchao.quantization.quant_primitives import MappingType -from torchao.quantization import quantize_, int8_weight_only from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 """ @@ -36,10 +40,9 @@ class MyTestCase(TorchAOBasicTestCase): unittest.main() """ + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 -def copy_tests( - my_cls, other_cls, suffix, test_failures=None, xfail_prop=None -): # noqa: B902 +def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): if name.startswith("test_"): # You cannot copy functions in Python, so we use closures here to @@ -70,7 +73,6 @@ def new_test(self, value=value): setattr(other_cls, f"{name}_{suffix}", new_test) - class TorchAOBasicTestCase(common_utils.TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -90,17 +92,21 @@ def test_flatten_unflatten(self): hp_tensor = torch.randn(4, 128) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict} + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } outer_size = lp_tensor.size() outer_stride = lp_tensor.stride() - reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize()) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_hp_tensor_device_dtype(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) - lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + self.FACTORY_FN(hp_tensor, **self.kwargs) @common_utils.parametrize("device1", COMMON_DEVICES) @common_utils.parametrize("device2", COMMON_DEVICES) @@ -141,7 +147,10 @@ def test_linear(self, device, dtype): hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) - self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + self.assertGreater( + torchao.quantization.utils.compute_error(hp_res, lp_res), + self.LINEAR_MIN_SQNR, + ) class TorchAOCompileTestCase(common_utils.TestCase): @@ -165,6 +174,7 @@ class TorchAOCompileTestCase(common_utils.TestCase): def test_input_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): return tensor @@ -179,6 +189,7 @@ def f(tensor): def test_input_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): return tensor.dequantize() @@ -192,6 +203,7 @@ def f(tensor): @common_utils.parametrize("dtype", COMMON_DTYPES) def test_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + def f(hp_tensor): return self.FACTORY_FN(hp_tensor, **self.kwargs) @@ -201,7 +213,12 @@ def f(hp_tensor): self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS)) # bfloat16 seems to result in much larger numerical differences if dtype != torch.bfloat16: - self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR) + self.assertGreater( + torchao.quantization.utils.compute_error( + ref.dequantize(), compiled.dequantize() + ), + self.COMPILE_MIN_SQNR, + ) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -211,22 +228,18 @@ def test_linear_compile(self, device, dtype): hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) - l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) - l.weight = torch.nn.Parameter(lp_tensor) - lp_res = torch.compile(l)(hp_act_tensor) - self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) + linear.weight = torch.nn.Parameter(lp_tensor) + lp_res = torch.compile(linear)(hp_act_tensor) + self.assertGreater( + torchao.quantization.utils.compute_error(hp_res, lp_res), + self.LINEAR_MIN_SQNR, + ) -import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, - NUM_DEVICES, -) class TorchAOTensorParallelTestCase(DTensorTestBase): - """Basic test case for tensor subclasses - """ + """Basic test case for tensor subclasses""" + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor @@ -247,9 +260,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m @staticmethod @@ -266,9 +277,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m def quantize(self, m: torch.nn.Module) -> torch.nn.Module: @@ -289,7 +298,9 @@ def test_tp(self, dtype): class M(torch.nn.Module): def __init__(self, in_features, out_features, **kwargs) -> None: super().__init__(**kwargs) - self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + self.linear = torch.nn.Linear( + in_features, out_features, bias=False, device="cuda" + ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -301,12 +312,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: proj_up = M(1024, 2048).to(device).to(dtype) proj_dn = M(2048, 1024).to(device).to(dtype) example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - y = proj_dn(proj_up(example_input)) + proj_dn(proj_up(example_input)) # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) - y_q = dn_quant(up_quant(example_input)) + dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() mesh.device_type = "cuda" @@ -316,11 +327,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist = self.rowwise_shard(dn_quant, mesh) # We need to turn inputs into DTensor form as well -- just a format change - input_dtensor = DTensor.from_local( - example_input, mesh, [Replicate()] - ) + input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()]) - y_d = dn_dist(up_dist(input_dtensor)) + dn_dist(up_dist(input_dtensor)) if not TORCH_VERSION_AT_LEAST_2_6: # Need torch 2.6 to support compiled tensor parallelism @@ -329,7 +338,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) - y_dn = dn_compiled(y_up) + dn_compiled(y_up) + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) diff --git a/torchao/utils.py b/torchao/utils.py index e474824135..d56191ed6b 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -13,6 +13,7 @@ __all__ = [ "benchmark_model", "profiler_runner", + "get_available_devices", "get_compute_capability", "skip_if_compute_capability_less_than", "benchmark_torch_function_in_microseconds", @@ -31,6 +32,9 @@ "TORCH_VERSION_AFTER_2_3", "TORCH_VERSION_AFTER_2_4", "TORCH_VERSION_AFTER_2_5", + "is_MI300", + "is_sm_at_least_89", + "is_sm_at_least_90", ] @@ -121,6 +125,18 @@ def profiler_runner(path, fn, *args, **kwargs): return result +def get_available_devices(): + devices = ["cpu"] + if torch.cuda.is_available(): + devices.append("cuda") + elif torch.xpu.is_available(): + devices.append("xpu") + if TORCH_VERSION_AT_LEAST_2_5: + if torch.mps.is_available(): + devices.append("mps") + return devices + + def get_compute_capability(): if torch.cuda.is_available(): capability = torch.cuda.get_device_capability() @@ -586,6 +602,32 @@ def _torch_version_at_least(min_version): return is_fbcode() or version("torch") >= min_version +def is_MI300(): + if torch.cuda.is_available() and torch.version.hip: + mxArchName = ["gfx940", "gfx941", "gfx942"] + archName = torch.cuda.get_device_properties().gcnArchName + for arch in mxArchName: + if arch in archName: + return True + return False + + +def is_sm_at_least_89(): + return ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (8, 9) + ) + + +def is_sm_at_least_90(): + return ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (9, 0) + ) + + TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") diff --git a/version.txt b/version.txt index faef31a435..a3df0a6959 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.7.0 +0.8.0