Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to pick up mlx v0.16.0 #115

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,9 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "v0.0.8")
GIT_TAG "v0.0.9")
FetchContent_MakeAvailable(mlx-c)

# TEMPORARY OVERRIDE -- 0.0.8 depends on v0.14.0 but we need v0.15.2 for iOS /
# float16 issues
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may have to add this back, as mlx-c is built for v0.15.2 but we need to build against v0.16.0 because of the Sequoia assert crashes.

FetchContent_Declare(
mlx
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
GIT_TAG v0.15.2)
FetchContent_MakeAvailable(mlx)

# swift-numerics
set(swift_numerics_patch git apply
${CMAKE_CURRENT_SOURCE_DIR}/cmake/swift-numerics.patch)
Expand Down
4 changes: 4 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ let package = Package(
"mlx/mlx/distributed/mpi",
"mlx/mlx/distributed/ops.cpp",
"mlx/mlx/distributed/primitives.cpp",

// the mlx-c side of distributed
"include/mlx/c/distributed.cpp",
"include/mlx/c/distributed_group.cpp",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now exposed on the mlx-c layer -- need to discuss how we expose this in swift and specifically how the build will work.

],

cSettings: [
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
128 changes: 128 additions & 0 deletions Source/Cmlx/mlx-generated/hadamard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
namespace mlx::core::metal {

// Returns the Metal shader source for the Hadamard kernels as one string
// literal. The text is returned verbatim; nothing is compiled or evaluated
// here.
//
// The embedded source defines:
//   - radix_func<R>: in-place butterfly passes (a + b, a - b) over `x`,
//     repeated log2(R) times with the stride `h` doubling each pass.
//   - hadamard_n<T, N, max_radix, read_width>: copies the input into a
//     `threadgroup T buf[N]` staging buffer, runs `num_steps` passes of
//     radix_func<max_radix> plus one radix_func<final_radix> pass when
//     log2(N) is not a multiple of log2(max_radix), then writes
//     `buf[...] * scale` to `out`.
//   - hadamard_m<T, N, M, read_width>: loads `read_width` rows of M values,
//     calls hadamard_radix_m on each row, and writes them back times `scale`.
//
// NOTE(review): hadamard_radix_m is called by hadamard_m but is not defined
// anywhere in this string -- presumably it is supplied by whoever assembles
// the final kernel source; confirm against the caller.
// NOTE(review): this file lives under mlx-generated/, so it is presumably
// regenerated by tools/update-mlx.sh rather than edited by hand -- confirm
// before changing the literal below.
const char* hadamard() {
return R"preamble(

using namespace metal;
template <short R>
METAL_FUNC void radix_func(thread float* x) {
constexpr short logR = __builtin_ctz(R);
short h = 1;
#pragma clang loop unroll(full)
for (short s = 0; s < logR; s++) {
#pragma clang loop unroll(full)
for (short i = 0; i < R / 2; i++) {
short k = i & (h - 1);
short j = ((i - k) << 1) + k;
float a = x[j];
float b = x[j + h];
x[j] = a + b;
x[j + h] = a - b;
}
h <<= 1;
}
}
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
constexpr short num_threads = N / max_radix;
constexpr short logN = __builtin_ctz(N);
constexpr short logR = __builtin_ctz(max_radix);
constexpr short num_steps = logN / logR;
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
threadgroup T buf[N];
#pragma clang loop unroll(full)
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float x[max_radix];
short h = 1;
#pragma clang loop unroll(full)
for (short s = 0; s < num_steps; s++) {
short k = i & (h - 1);
short j = ((i - k) << logR) + k;
#pragma clang loop unroll(full)
for (short r = 0; r < max_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<max_radix>(x);
#pragma clang loop unroll(full)
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = x[r];
}
h <<= logR;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (final_radix > 1) {
#pragma clang loop unroll(full)
for (int t = 0; t < max_radix / final_radix; t++) {
short index = i + t * num_threads;
short k = index & (h - 1);
short j = ((index - k) << logFinal) + k;
#pragma clang loop unroll(full)
for (short r = 0; r < final_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<final_radix>(x);
#pragma clang loop unroll(full)
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = x[r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
#pragma clang loop unroll(full)
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = buf[index + r] * scale;
}
}
}
template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
int index = elem.x * grid.y + elem.y;
short i = index % (N / read_width);
int batch_idx = index / (N / read_width) * M * N;
float x[read_width][M];
#pragma clang loop unroll(full)
for (short c = 0; c < M; c++) {
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
x[r][c] = in[batch_idx + c * N + i * read_width + r];
}
}
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
hadamard_radix_m(x[r]);
}
#pragma clang loop unroll(full)
for (short c = 0; c < M; c++) {
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
}
}
}
)preamble";
}

} // namespace mlx::core::metal
1 change: 1 addition & 0 deletions Source/MLX/Documentation.docc/free-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,4 @@ operations as methods for convenience.

- ``diag(_:k:stream:)``
- ``diagonal(_:offset:axis1:axis2:stream:)``
- ``view(_:dtype:stream:)``
16 changes: 16 additions & 0 deletions Source/MLX/MLXArray+Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2658,4 +2658,20 @@ extension MLXArray {
MLXArray(mlx_var_all(ctx, keepDims, ddof.int32, stream.ctx))
}

/// Return this array viewed as a different element type.
///
/// If this array's type and the requested type do not have the same size,
/// the output array will change along the last axis.
///
/// Note: a view does not imply that the input and output arrays share
/// their underlying data. It only guarantees that the binary
/// representation of each element (or group of elements) is the same.
///
/// - Parameters:
///   - dtype: type to change to
///   - stream: stream or device to evaluate on
/// - Returns: array with the new type
public func view(dtype: DType, stream: StreamOrDevice = .default) -> MLXArray {
    let viewed = mlx_view(ctx, dtype.cmlxDtype, stream.ctx)
    return MLXArray(viewed)
}
}
19 changes: 19 additions & 0 deletions Source/MLX/Ops+Array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1715,3 +1715,22 @@ public func variance(
) -> MLXArray {
MLXArray(mlx_var_all(array.ctx, keepDims, ddof.int32, stream.ctx))
}

/// Return `array` viewed as a different element type.
///
/// If the input array's type and the requested type do not have the same
/// size, the output array will change along the last axis.
///
/// Note: a view does not imply that the input and output arrays share
/// their underlying data. It only guarantees that the binary
/// representation of each element (or group of elements) is the same.
///
/// - Parameters:
///   - array: input array
///   - dtype: type to change to
///   - stream: stream or device to evaluate on
/// - Returns: array with the new type
///
/// ### See Also
/// - ``MLXArray/view(dtype:stream:)``
public func view(_ array: MLXArray, dtype: DType, stream: StreamOrDevice = .default) -> MLXArray {
    let viewed = mlx_view(array.ctx, dtype.cmlxDtype, stream.ctx)
    return MLXArray(viewed)
}
4 changes: 3 additions & 1 deletion tools/update-mlx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ cmake ../Source/Cmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0

# NOTE:
# until mlx supports overriding the METAL_VERSION you will need to edit
# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION.
# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION
# to "3.0"
#
# Also, the kernels in Plugins/PrepareMetalShaders/main.swift need to be kept in sync.

Expand All @@ -34,6 +35,7 @@ make \
fft \
gather \
gemm \
hadamard \
quantized \
reduce \
reduce_utils \
Expand Down