cmake torchao_ops_mps_linear_fp_act_xbit_weight (#1304)
Summary:
Pull Request resolved: #1304

Move from setup.py to cmake for building custom torchao mps ops

Differential Revision: D66120124
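
For context, the consumer-facing change: the ops were previously built as a Python extension module (import torchao_mps_ops) and are now built as a plain shared library that callers load explicitly. A minimal sketch of the new pattern, taken from the updated tests (the dylib path assumes build.sh's default install layout):

import os
import sys

import torch

libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib", libname)

# load_library runs the TORCH_LIBRARY(torchao, m) registration in register.mm,
# exposing the custom ops under the torch.ops.torchao namespace.
torch.ops.load_library(libpath)
assert torch.ops.torchao._pack_weight_2bit is not None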
manuelcandales authored and facebook-github-bot committed Nov 19, 2024
1 parent b714026 commit a3bb86d
Showing 7 changed files with 122 additions and 45 deletions.
20 changes: 14 additions & 6 deletions torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
@@ -1,11 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import os
import sys
import yaml

torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"
if len(sys.argv) != 2:
    print("Usage: gen_metal_shader_lib.py <output_file>")
    sys.exit(1)

# Output file where the generated code will be written
OUTPUT_FILE = sys.argv[1]

MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
@@ -21,9 +32,6 @@
# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Output file where the generated code will be written
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

prefix = """/**
* This file is generated by gen_metal_shader_lib.py
*/
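The generator's job is unchanged: read the list of .metal files from metal.yaml and embed their source into a C++ header so the shaders can be compiled at runtime. A rough sketch of that pattern, assuming metal.yaml is a list of entries with a "file" key and that the header exposes a single raw-string constant (neither detail is shown in this hunk):

import os
import sys

import yaml

OUTPUT_FILE = sys.argv[1]
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
METAL_DIR = os.path.join(MPS_DIR, "metal")

with open(METAL_YAML) as f:
    entries = yaml.safe_load(f)  # assumed schema: [{"file": "kernel.metal"}, ...]

shader_src = "\n".join(
    open(os.path.join(METAL_DIR, e["file"])).read() for e in entries
)

# Emit a header that embeds the concatenated Metal source as a raw string,
# so the built library needs no .metal files on disk at runtime.
with open(OUTPUT_FILE, "w") as out:
    out.write("/**\n * This file is generated by gen_metal_shader_lib.py\n */\n")
    out.write('static const char* metal_shader_lib = R"METAL(\n')
    out.write(shader_src)
    out.write('\n)METAL";\n')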
50 changes: 50 additions & 0 deletions torchao/experimental/ops/mps/CMakeLists.txt
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

project(torchao_ops_mps_linear_fp_act_xbit_weight)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED YES)

if (NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release)
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
    if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
        message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture")
    endif()
else()
    message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
endif()

find_package(Torch REQUIRED)

if(NOT TORCHAO_INCLUDE_DIRS)
    set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
endif()
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")

include_directories(${TORCHAO_INCLUDE_DIRS})
include_directories(${CMAKE_INSTALL_PREFIX}/include)
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)

target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)

# Enable Metal support
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})

install(
    TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
    EXPORT _targets
    DESTINATION lib
)
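
One note on find_package(Torch REQUIRED): it only succeeds if CMAKE_PREFIX_PATH reaches the installed PyTorch. build.sh below derives this from the site-packages directory; PyTorch also exposes the location directly, which is a handy way to double-check (assuming a reasonably recent torch):

import torch

# PyTorch ships its CMake package config; passing this directory via
# -DCMAKE_PREFIX_PATH lets find_package(Torch) resolve the TORCH_LIBRARIES
# and TORCH_INCLUDE_DIRS used above.
print(torch.utils.cmake_prefix_path)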
5 changes: 1 addition & 4 deletions torchao/experimental/ops/mps/aten/register.mm
@@ -5,7 +5,7 @@
// LICENSE file in the root directory of this source tree.

// clang-format off
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torchao/experimental/kernels/mps/src/lowbit.h>
// clang-format on
@@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
return B;
}

// Registers _C as a Python extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

TORCH_LIBRARY(torchao, m) {
  m.def("_pack_weight_1bit(Tensor W) -> Tensor");
  m.def("_pack_weight_2bit(Tensor W) -> Tensor");
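With PYBIND11_MODULE gone, the dylib is no longer importable as a Python module; the TORCH_LIBRARY block is the only registration point, which is why switching the include from torch/extension.h to torch/library.h suffices. Once the library is loaded (see the sketch in the summary above), each m.def schema resolves through the dispatcher:

import torch

# Assumes torch.ops.load_library(...) has already been called on the dylib.
# Each m.def("...") schema becomes an entry under torch.ops.torchao.
op = torch.ops.torchao._pack_weight_1bit
print(op)  # prints the resolved op rather than raising AttributeError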
25 changes: 25 additions & 0 deletions torchao/experimental/ops/mps/build.sh
@@ -0,0 +1,25 @@
#!/bin/bash -eu
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cd "$(dirname "$BASH_SOURCE")"

export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=$(python -c "import sys; print(sys.prefix)")/torchao_mps/cmake-out
echo "CMAKE_OUT: ${CMAKE_OUT}"

export INCLUDE_PATH=${CMAKE_OUT}/include
mkdir -p ${INCLUDE_PATH}/torchao/experimental/kernels/mps/src/
export GENERATED_METAL_SHADER_LIB=${INCLUDE_PATH}/torchao/experimental/kernels/mps/src/metal_shader_lib.h
python ../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
echo "GENERATED_METAL_SHADER_LIB: ${GENERATED_METAL_SHADER_LIB}"

cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
      -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
      -S . \
      -B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release
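
Running the script (sh build.sh) generates the shader header, configures, builds, and installs everything under sys.prefix/torchao_mps/cmake-out. A small sanity check of the layout the tests expect:

import os
import sys

cmake_out = os.path.join(sys.prefix, "torchao_mps", "cmake-out")
dylib = os.path.join(
    cmake_out, "lib", "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
)
header = os.path.join(
    cmake_out, "include", "torchao", "experimental",
    "kernels", "mps", "src", "metal_shader_lib.h"
)

# Both artifacts should exist after a successful build; the tests hardcode
# the dylib path.
for path in (dylib, header):
    print(path, "->", "ok" if os.path.exists(path) else "missing")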
23 changes: 0 additions & 23 deletions torchao/experimental/ops/mps/setup.py

This file was deleted.

28 changes: 17 additions & 11 deletions torchao/experimental/ops/mps/test/test_lowbit.py
@@ -4,25 +4,31 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized

def parameterized(test_cases):
    def decorator(func):
        def wrapper(self):
            for case in test_cases:
                with self.subTest(case=case):
                    func(self, *case)

        return wrapper

    return decorator

libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib/", libname)

try:
    torch.ops.load_library(libpath)
except:
    print(f"Failed to load library {libpath}")
    raise

for nbit in range(1, 8):
    op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
    assert op is not None
    op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
    assert op is not None


class TestLowBitQuantWeightsLinear(unittest.TestCase):
    cases = [
    CASES = [
        (nbit, *param)
        for nbit in range(1, 8)
        for param in [
@@ -73,7 +79,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
        W = scales * W + zeros
        return torch.mm(A, W.t())

    @parameterized(cases)
    @parameterized.expand(CASES)
    def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
        print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
        A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)
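The home-grown parameterized decorator (one test method looping over subTests) is replaced by the parameterized package, whose expand generates a separately named test method per case, so one failing bit width no longer masks the rest. A minimal illustration of the difference:

import unittest

from parameterized import parameterized


class Demo(unittest.TestCase):
    @parameterized.expand([(1,), (2,), (3,)])
    def test_nbit(self, nbit):
        # expand creates one uniquely named method per tuple, so each case
        # is reported (and can fail) independently.
        self.assertIn(nbit, range(1, 8))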
16 changes: 15 additions & 1 deletion torchao/experimental/ops/mps/test/test_quantizer.py
@@ -11,13 +11,27 @@
import sys

import torch
import torchao_mps_ops
import unittest

from parameterized import parameterized
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
from torchao.experimental.quant_api import _quantize

libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
libpath = os.path.join(sys.prefix, "torchao_mps/cmake-out/lib/", libname)

try:
    torch.ops.load_library(libpath)
except:
    print(f"Failed to load library {libpath}")
    raise

for nbit in range(1, 8):
    op = getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
    assert op is not None
    op = getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
    assert op is not None


class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
    BITWIDTHS = range(1, 8)