Skip to content

Commit

Permalink
Faster Metal unary and binary for general case (#1431)
Browse files Browse the repository at this point in the history
* faster unary and binary for general case

* update ternary + jit fix

* fix jit

* unary work per thread
  • Loading branch information
awni authored Sep 25, 2024
1 parent afc9c0e commit 4f9f9eb
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 93 deletions.
29 changes: 18 additions & 11 deletions mlx/backend/metal/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@

namespace mlx::core {

constexpr int MAX_BINARY_SPECIALIZED_DIMS = 3;

std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim) {
int ndim,
int work_per_thread) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
Expand All @@ -43,14 +42,17 @@ std::string get_kernel_name(
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
if (ndim <= 3) {
kname << ndim;
} else {
kname << "n";
if (work_per_thread > 1) {
kname << work_per_thread;
}
}
break;
}
kname << op << type_to_name(a);
kname << "_" << op << type_to_name(a);
return kname.str();
}

Expand Down Expand Up @@ -85,7 +87,11 @@ void binary_op_gpu_inplace(
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();

bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto ndim = shape.size();
int work_per_thread =
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device);

auto kernel = outputs.size() == 2
Expand All @@ -110,14 +116,19 @@ void binary_op_gpu_inplace(
}

if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);

if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++);
compute_encoder->setBytes(
strides_a.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(
Expand All @@ -126,10 +137,6 @@ void binary_op_gpu_inplace(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}

// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
Expand Down
34 changes: 22 additions & 12 deletions mlx/backend/metal/jit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,19 @@ MTL::ComputePipelineState* get_unary_kernel(
const std::string& kernel_name,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(1);
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto u2_def = get_template_definition(
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< u_def << u2_def << g_def;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
kernel_source << get_template_definition(
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
Expand Down Expand Up @@ -81,13 +82,20 @@ void add_binary_kernels(
for (auto& [name, func] : kernel_types) {
std::string template_def;
template_def = get_template_definition(
name + lib_name,
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
}
kernel_source << get_template_definition(
"gn4_" + lib_name,
"binary_g",
get_type_string(in_type),
get_type_string(out_type),
op,
4);
}

MTL::ComputePipelineState* get_binary_kernel(
Expand All @@ -96,7 +104,7 @@ MTL::ComputePipelineState* get_binary_kernel(
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
Expand All @@ -113,7 +121,7 @@ MTL::ComputePipelineState* get_binary_two_kernel(
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(2);
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
Expand Down Expand Up @@ -149,6 +157,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
}
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
Expand Down
16 changes: 12 additions & 4 deletions mlx/backend/metal/kernels/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ template <typename T, typename U, typename Op>
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}

template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
Expand All @@ -124,8 +124,16 @@ template <typename T, typename U, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;
idx.y += b_xstride;
}
}
25 changes: 13 additions & 12 deletions mlx/backend/metal/kernels/binary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"

#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \

#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
Expand Down
20 changes: 14 additions & 6 deletions mlx/backend/metal/kernels/binary_two.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ template <typename T, typename U, typename Op>
d[out_idx] = out[1];
}

template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
Expand All @@ -155,10 +155,18 @@ template <typename T, typename U, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
auto idx = elem_to_loc_2_nd(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx++] = out[1];
idx.x += a_xstride;
idx.y += b_xstride;
}
}
25 changes: 13 additions & 12 deletions mlx/backend/metal/kernels/binary_two.metal
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"

#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \

#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
Expand Down
24 changes: 19 additions & 5 deletions mlx/backend/metal/kernels/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ template <typename T, typename Op>
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}

template <typename T, typename Op>
template <typename T, typename Op, int N = 1>
[[kernel]] void ternary_g(
device const bool* a,
device const T* b,
Expand All @@ -88,9 +88,23 @@ template <typename T, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
auto idx = elem_to_loc_3_nd(
{N * index.x, index.y, index.z},
shape,
a_strides,
b_strides,
c_strides,
ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
auto c_xstride = c_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
idx.x += a_xstride;
idx.y += b_xstride;
idx.z += c_xstride;
}
}
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels/ternary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
Expand Down
17 changes: 13 additions & 4 deletions mlx/backend/metal/kernels/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,23 @@ template <typename T, typename Op>
out[offset] = Op()(in[offset]);
}

template <typename T, typename Op>
template <typename T, typename Op, int N = 1>
[[kernel]] void unary_g(
device const T* in,
device T* out,
constant const int* in_shape,
constant const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]) {
auto idx = elem_to_loc(index, in_shape, in_strides, ndim);
out[index] = Op()(in[idx]);
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx =
elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
auto xshape = in_shape[ndim - 1];
auto xstride = in_strides[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
out[out_idx++] = Op()(in[idx]);
idx += xstride;
}
}
9 changes: 5 additions & 4 deletions mlx/backend/metal/kernels/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"

#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v" #op #tname, unary_v, type, op) \
instantiate_kernel("v2" #op #tname, unary_v2, type, op) \
instantiate_kernel("g" #op #tname, unary_g, type, op)
#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) \
instantiate_kernel("g_" #op #tname, unary_g, type, op)

#define instantiate_unary_float(op) \
instantiate_unary_all(op, float16, half) \
Expand Down
Loading

0 comments on commit 4f9f9eb

Please sign in to comment.