diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 70afd79c9..248fb526c 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -52,7 +52,7 @@ std::string get_kernel_name( } break; } - kname << op << type_to_name(a); + kname << "_" << op << type_to_name(a); return kname.str(); } diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index b579ad71d..32daed2fd 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -81,7 +81,7 @@ void add_binary_kernels( for (auto& [name, func] : kernel_types) { std::string template_def; template_def = get_template_definition( - name + lib_name, + name + "_" + lib_name, func, get_type_string(in_type), get_type_string(out_type), @@ -89,7 +89,7 @@ void add_binary_kernels( kernel_source << template_def; } kernel_source << get_template_definition( - "gn4" + lib_name, + "gn4_" + lib_name, "binary_g", get_type_string(in_type), get_type_string(out_type), @@ -103,7 +103,7 @@ MTL::ComputePipelineState* get_binary_kernel( Dtype in_type, Dtype out_type, const std::string op) { - std::string lib_name = kernel_name.substr(2); + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name); if (lib == nullptr) { std::ostringstream kernel_source; @@ -120,7 +120,7 @@ MTL::ComputePipelineState* get_binary_two_kernel( Dtype in_type, Dtype out_type, const std::string op) { - std::string lib_name = kernel_name.substr(2); + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); auto lib = d.get_library(lib_name); if (lib == nullptr) { std::ostringstream kernel_source; diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index 19f4a807f..5c437bd2a 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,19 +9,19 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ - instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ - instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \ - instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \ - instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ - instantiate_kernel("gn4" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ - instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ - instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \ +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ + instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ + instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ + instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \ + instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \ #define instantiate_binary_integer(op) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index 114713496..f062439ec 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,19 +7,19 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ - instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ - instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \ - instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \ - instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ - instantiate_kernel("gn4" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ - instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ - instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \ +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ + instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ + instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ + instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \ + instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \ #define instantiate_binary_float(op) \ instantiate_binary_all(op, float16, half, half) \