diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py index 56dcc9b73..7176be5bc 100644 --- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py +++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py @@ -1,11 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + from typing import Optional import os +import sys import yaml -torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT") -assert torchao_root is not None, "TORCHAO_ROOT is not set" +if len(sys.argv) != 2: + print("Usage: gen_metal_shader_lib.py ") + sys.exit(1) + +# Output file where the generated code will be written +OUTPUT_FILE = sys.argv[1] -MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps") +MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) # Path to yaml file containing the list of .metal files to include METAL_YAML = os.path.join(MPS_DIR, "metal.yaml") @@ -21,9 +32,6 @@ # Path to the folder containing the .metal files METAL_DIR = os.path.join(MPS_DIR, "metal") -# Output file where the generated code will be written -OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h") - prefix = """/** * This file is generated by gen_metal_shader_lib.py */ diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt new file mode 100644 index 000000000..2e2576dbe --- /dev/null +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +project(torchao_ops_mps_linear_fp_act_xbit_weight) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED YES) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +if (CMAKE_SYSTEM_NAME STREQUAL "Darwin") + if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + message(FATAL_ERROR "Unified Memory requires Apple Silicon arquitecture") + endif() +else() + message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS") +endif() + +find_package(Torch REQUIRED) + +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) +endif() +message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") + +include_directories(${TORCHAO_INCLUDE_DIRS}) +include_directories(${CMAKE_INSTALL_PREFIX}/include) +add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm) + +target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") +target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}") +target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1) + +# Enable Metal support +find_library(METAL_LIB Metal) +find_library(FOUNDATION_LIB Foundation) +target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) + +install( + TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten + EXPORT _targets + DESTINATION lib +) diff --git a/torchao/experimental/ops/mps/register.mm b/torchao/experimental/ops/mps/aten/register.mm similarity index 98% rename from torchao/experimental/ops/mps/register.mm rename to torchao/experimental/ops/mps/aten/register.mm index 44946a30f..92a3ba89f 100644 --- a/torchao/experimental/ops/mps/register.mm +++ b/torchao/experimental/ops/mps/aten/register.mm @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. // clang-format off -#include +#include #include #include // clang-format on @@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { return B; } -// Registers _C as a Python extension module. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} - TORCH_LIBRARY(torchao, m) { m.def("_pack_weight_1bit(Tensor W) -> Tensor"); m.def("_pack_weight_2bit(Tensor W) -> Tensor"); diff --git a/torchao/experimental/ops/mps/build.sh b/torchao/experimental/ops/mps/build.sh new file mode 100644 index 000000000..124e58e64 --- /dev/null +++ b/torchao/experimental/ops/mps/build.sh @@ -0,0 +1,25 @@ +#!/bin/bash -eu +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cd "$(dirname "$BASH_SOURCE")" + +export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" +export CMAKE_OUT=$(python -c "import sys; print(sys.prefix)")/torchao_mps/cmake-out +echo "CMAKE_OUT: ${CMAKE_OUT}" + +export INCLUDE_PATH=${CMAKE_OUT}/include +mkdir -p ${INCLUDE_PATH}/torchao/experimental/kernels/mps/src/ +export GENERATED_METAL_SHADER_LIB=${INCLUDE_PATH}/torchao/experimental/kernels/mps/src/metal_shader_lib.h +python ../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB} +echo "GENERATED_METAL_SHADER_LIB: ${GENERATED_METAL_SHADER_LIB}" + +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -S . \ + -B ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} -j 16 --target install --config Release diff --git a/torchao/experimental/ops/mps/setup.py b/torchao/experimental/ops/mps/setup.py deleted file mode 100644 index 1205d43d4..000000000 --- a/torchao/experimental/ops/mps/setup.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import os -from setuptools import setup -from torch.utils.cpp_extension import CppExtension, BuildExtension - -setup( - name="torchao_mps_ops", - version="1.0", - ext_modules=[ - CppExtension( - name="torchao_mps_ops", - sources=["register.mm"], - include_dirs=[os.getenv("TORCHAO_ROOT")], - extra_compile_args=["-DUSE_ATEN=1"], - ), - ], - cmdclass={"build_ext": BuildExtension}, -) diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index 797c5dac2..61eb41aa0 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -4,25 +4,31 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import os +import sys import torch -import torchao_mps_ops import unittest +from parameterized import parameterized -def parameterized(test_cases): - def decorator(func): - def wrapper(self): - for case in test_cases: - with self.subTest(case=case): - func(self, *case) +libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib/", libname) - return wrapper +try: + torch.ops.load_library(libpath) +except: + print(f"Failed to load library {libpath}") + raise - return decorator +for nbit in range(1, 8): + op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + assert op is not None + op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") + assert op is not None class TestLowBitQuantWeightsLinear(unittest.TestCase): - cases = [ + CASES = [ (nbit, *param) for nbit in range(1, 8) for param in [ @@ -73,7 +79,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit): W = scales * W + zeros return torch.mm(A, W.t()) - @parameterized(cases) + @parameterized.expand(CASES) def test_linear(self, nbit, M=1, K=32, N=32, group_size=32): print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}") A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit) diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 87f67c545..92cc90e8e 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -11,13 +11,27 @@ import sys import torch -import torchao_mps_ops import unittest from parameterized import parameterized from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer from torchao.experimental.quant_api import _quantize +libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib/", libname) + +try: + torch.ops.load_library(libpath) +except: + print(f"Failed to load library {libpath}") + raise + +for nbit in range(1, 8): + op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight") + assert op is not None + op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit") + assert op is not None + class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase): BITWIDTHS = range(1, 8)