diff --git a/Cargo.toml b/Cargo.toml index 0fea042329..ba09b1d4cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ exclude = [ "candle-flash-attn", "candle-kernels", + "candle-metal-kernels", "candle-onnx", ] resolver = "2" @@ -60,7 +61,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 5d5e70a3bd..e7d3ab6a59 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -13,6 +13,7 @@ readme = "README.md" accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } @@ -40,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] -metal = ["dep:metal"] +metal = ["dep:metal", "dep:candle-metal-kernels"] diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index de57c03ac6..3eb7f8b7fa 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; pub enum DeviceLocation { Cpu, Cuda { gpu_id: usize }, - Metal, + Metal { gpu_id: usize }, } #[derive(Debug, Clone)] @@ -146,6 +146,7 @@ impl Device { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), + (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs), _ => false, } } diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 215c28f64f..4f5a390e80 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -14,7 +14,9 @@ impl Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } - _ => todo!(), + crate::DeviceLocation::Metal { gpu_id } => { + format!(", metal:{}", gpu_id) + } }; write!(f, "Tensor[")?; @@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } - crate::DeviceLocation::Metal => todo!(), + crate::DeviceLocation::Metal { gpu_id } => { + format!(", metal:{}", gpu_id) + } }; write!( diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index da61bdb574..36f5f6b17a 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -53,6 +53,8 @@ mod dummy_metal_backend; pub mod error; mod indexer; pub mod layout; +#[cfg(feature = "metal")] +pub mod metal_backend; #[cfg(feature = "mkl")] mod mkl; pub mod npy; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs new file mode 100644 index 0000000000..52cde1b709 --- /dev/null +++ b/candle-core/src/metal_backend.rs @@ -0,0 +1,827 @@ +use crate::backend::{BackendDevice, BackendStorage}; +use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Layout, Result, Shape}; +use candle_metal_kernels; +use candle_metal_kernels::Kernels; +use core::mem; +use half::{bf16, f16}; +use metal; +use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::sync::Arc; + +/// Metal related errors +#[derive(thiserror::Error, Debug)] +pub enum MetalError { + #[error("{0}")] + Message(String), + #[error(transparent)] + KernelError(#[from] candle_metal_kernels::MetalKernelError), + + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec, + rhs_stride: Vec, + mnk: (usize, usize, usize), + }, +} + +impl From for MetalError { + fn from(e: String) -> Self { + MetalError::Message(e) + } +} + +#[derive(Clone)] +pub struct MetalDevice { + device: metal::Device, + command_queue: metal::CommandQueue, + kernels: Arc, +} + +impl std::fmt::Debug for MetalDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MetalDevice({:?})", self.device.registry_id()) + } +} + +impl std::ops::Deref for MetalDevice { + type Target = metal::DeviceRef; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl MetalDevice { + pub fn id(&self) -> NSUInteger { + self.registry_id() + } + + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + pub fn kernels(&self) -> &Kernels { + &self.kernels + } + + pub fn device(&self) -> &metal::Device { + &self.device + } + + pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { + let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + self.device + .new_buffer(size, MTLResourceOptions::StorageModeManaged) + } +} + +#[derive(Debug, Clone)] +pub struct MetalStorage { + buffer: metal::Buffer, + device: MetalDevice, + dtype: DType, +} + +impl BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Ok(self.clone()) + } + + fn dtype(&self) -> DType { + self.dtype + } + + fn device(&self) -> &Self::Device { + &self.device + } + + fn to_cpu_storage(&self) -> Result { + let length = self.buffer.length() as usize; + let size = self.dtype.size_in_bytes(); + if length % size != 0 { + crate::bail!( + "The Metal buffer length is not aligned with dtype {:?}", + self.dtype + ); + } + match self.dtype { + DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), + DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), + DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))), + DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), + DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))), + DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))), + DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))), + } + } + + fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 { + crate::bail!("Not contiguous, non-f32 affine is not implemented yet."); + } + + let mut buffer = device.new_buffer(el, self.dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + el, + &self.buffer, + &mut buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + return Ok(Self { + buffer, + device: device.clone(), + dtype, + }); + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + crate::bail!("powf metal") + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + crate::bail!("elu metal") + } + + fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { + if !(sum_dims.len() == 1 + && sum_dims[0] == layout.shape().rank() - 1 + && layout.is_contiguous() + && layout.start_offset() == 0) + { + crate::bail!("Non contiguous reduce op not supported yet"); + } + let device = self.device.clone(); + let src_stride = layout.stride(); + let src_dims = layout.shape().dims(); + let src_el: usize = src_dims.iter().product(); + // Source dims and strides with the sum dims at the end. + let mut dims = vec![]; + let mut stride = vec![]; + let mut dst_el: usize = 1; + for (dim_idx, &d) in src_dims.iter().enumerate() { + if !sum_dims.contains(&dim_idx) { + dst_el *= d; + dims.push(d); + stride.push(src_stride[dim_idx]); + } + } + for &dim_idx in sum_dims.iter() { + dims.push(src_dims[dim_idx]); + stride.push(src_stride[dim_idx]); + } + + // The reduction loop requires the shared array to be properly initialized and for + // this we want the number of threads to be a power of two. + let (name, check_empty, return_index) = match (op, self.dtype) { + (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), + _ => crate::bail!("Reduce op for non float"), + }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? + } + let dtype = if return_index { DType::U32 } else { self.dtype }; + let mut buffer = device.new_buffer(dst_el, dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_reduce_contiguous( + &device.device, + &command_buffer, + &device.kernels, + name, + src_el, + dst_el, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(Self { + buffer, + device, + dtype, + }) + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + crate::bail!("cmp metal") + } + + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + let device = self.device(); + let shape = layout.shape(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if layout.is_contiguous() { + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32", + (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + crate::bail!( + "TODO Implement the kernel calling cast {:?}-{:?}", + self.dtype, + dtype + ); + } + + command_buffer.commit(); + command_buffer.wait_until_completed(); + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn unary_impl(&self, layout: &Layout) -> Result { + let device = self.device(); + let dtype = self.dtype; + let shape = layout.shape(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if layout.is_contiguous() && layout.start_offset() == 0 { + use candle_metal_kernels::unary::contiguous; + + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + ("ulog", DType::F32) => contiguous::log::FLOAT, + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_unary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + crate::bail!("TODO Implement the kernel calling {}", B::KERNEL); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn binary_impl( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device(); + let dtype = self.dtype; + let shape = lhs_l.shape(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) + && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + { + use candle_metal_kernels::binary::contiguous; + + let kernel_name = match (B::KERNEL, dtype) { + ("add", DType::F32) => contiguous::add::FLOAT, + ("badd", DType::F32) => contiguous::add::FLOAT, + ("sub", DType::F32) => contiguous::sub::FLOAT, + ("bsub", DType::F32) => contiguous::sub::FLOAT, + ("mul", DType::F32) => contiguous::mul::FLOAT, + ("bmul", DType::F32) => contiguous::mul::FLOAT, + ("div", DType::F32) => contiguous::div::FLOAT, + ("bdiv", DType::F32) => contiguous::div::FLOAT, + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + use candle_metal_kernels::binary::strided; + + let kernel_name = match (B::KERNEL, dtype) { + ("badd", DType::F32) => strided::add::FLOAT, + ("bsub", DType::F32) => strided::sub::FLOAT, + ("bmul", DType::F32) => strided::mul::FLOAT, + ("bdiv", DType::F32) => strided::div::FLOAT, + (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + &lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &rhs.buffer, + &rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &mut buffer, + ) + .map_err(MetalError::from)?; + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn where_cond( + &self, + layout: &Layout, + t: &Self, + t_l: &Layout, + f: &Self, + f_l: &Layout, + ) -> Result { + let device = self.device.clone(); + let shape = t_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let dtype = t.dtype; + let mut buffer = self.device.new_buffer(el, dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_where_cond_strided( + &device.device, + &command_buffer, + &device.kernels, + "where_u8_f32", + &dims, + &self.buffer, + ( + layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + ), + &t.buffer, + (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), + &f.buffer, + (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), + &mut buffer, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + Ok(Self { + buffer, + device, + dtype, + }) + } + + fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConv1D, + ) -> Result { + crate::bail!("conv1d metal") + } + + fn conv_transpose1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConvTranspose1D, + ) -> Result { + crate::bail!("conv_transpose1d metal") + } + + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConv2D, + ) -> Result { + crate::bail!("conv2d metal") + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConvTranspose2D, + ) -> Result { + crate::bail!("conv_tranpose2d metal") + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + crate::bail!("avg_pool2d metal") + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + crate::bail!("max_pool2d metal") + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + crate::bail!("upsample_nearest1d metal") + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + crate::bail!("upsample_nearest2d metal") + } + + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + crate::bail!("gather metal") + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + crate::bail!("scatter_add metal") + } + + fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { + if !(src_l.is_contiguous() + && src_l.start_offset() == 0 + && ids_l.is_contiguous() + && ids_l.start_offset() == 0) + { + crate::bail!("Non contiguous index select not implemented"); + } + let left_size: usize = src_l.dims()[..dim].iter().product(); + let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + let ids_el = ids_l.shape().elem_count(); + let dst_el = ids_el * left_size * right_size; + let dtype = self.dtype; + let device = self.device(); + let mut buffer = device.new_buffer(dst_el, dtype); + let out = self.to_cpu_storage()?; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "is_u32_f32", + (left, right) => crate::bail!("index select metal {left:?} {right:?}"), + }; + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_index_select( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + &self.buffer, + &ids.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + crate::bail!("index_add metal") + } + + fn matmul( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + // Create descriptors + use metal::mps::matrix::*; + let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32; + let size = core::mem::size_of::() as NSUInteger; + + let elem_count = b * m * n; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // The a tensor has dims batching, k, n (rhs) + let transpose_left = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + Err(MetalError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })? + }; + let transpose_right = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + Err(MetalError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })? + }; + + let b = b as NSUInteger; + let m = m as NSUInteger; + let n = n as NSUInteger; + let k = k as NSUInteger; + + let left_descriptor = if transpose_left { + MatrixDescriptor::init_single(k, m, m * size, type_id) + } else { + MatrixDescriptor::init_single(m, k, k * size, type_id) + }; + let right_descriptor = if transpose_right { + MatrixDescriptor::init_single(n, k, k * size, type_id) + } else { + MatrixDescriptor::init_single(k, n, n * size, type_id) + }; + let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); + + // Create matrix objects + let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + + let out_buffer = self.device.new_buffer(elem_count, self.dtype); + let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + + let alpha = 1.0f64; + let beta = 0.0f64; + // Create kernel + let matrix_multiplication = MatrixMultiplication::init( + &self.device, + transpose_left, + transpose_right, + m, + n, + k, + alpha, + beta, + ) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + + matrix_multiplication.set_batch_size(b); + + // Encode kernel to command buffer + let command_buffer = self.device.command_queue.new_command_buffer(); + matrix_multiplication.encode_to_command_buffer( + command_buffer, + &left_matrix, + &right_matrix, + &result_matrix, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let src_shape = src_l.shape(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let command_buffer = self.device.command_queue.new_command_buffer(); + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + &src_l.stride(), + src_l.start_offset() * self.dtype.size_in_bytes(), + &mut dst.buffer, + dst_offset, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + Ok(()) + } +} + +impl MetalStorage { + pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + Self { + buffer, + device, + dtype, + } + } + + pub fn buffer(&self) -> &Buffer { + &self.buffer + } +} + +impl BackendDevice for MetalDevice { + type Storage = MetalStorage; + + fn new(ordinal: usize) -> Result { + let device = metal::Device::all().swap_remove(ordinal); + + let command_queue = device.new_command_queue(); + let kernels = Arc::new(Kernels::new()); + Ok(Self { + device, + command_queue, + kernels, + }) + } + + fn set_seed(&self, _seed: u64) -> Result<()> { + crate::bail!("set_seed") + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Metal { + gpu_id: self.registry_id() as usize, + } + } + + fn same_device(&self, rhs: &Self) -> bool { + self.device.registry_id() == rhs.device.registry_id() + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + // TODO Is there a faster way ? + let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; + self.storage_from_cpu_storage(&cpu_storage) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + // TODO Is there a faster way ? + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + self.storage_from_cpu_storage(&cpu_storage) + } + + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { + let option = metal::MTLResourceOptions::StorageModeManaged; + let buffer = match storage { + CpuStorage::U8(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + CpuStorage::U32(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + CpuStorage::I64(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + CpuStorage::BF16(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + CpuStorage::F16(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + CpuStorage::F32(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + CpuStorage::F64(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as NSUInteger, + option, + ), + }; + Ok(Self::Storage { + buffer, + device: self.clone(), + dtype: storage.dtype(), + }) + } + + fn rand_uniform( + &self, + shape: &Shape, + dtype: DType, + mean: f64, + stddev: f64, + ) -> Result { + // TODO is there a better way ? + let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?; + self.storage_from_cpu_storage(&cpu_storage) + } + + fn rand_normal( + &self, + shape: &Shape, + dtype: DType, + mean: f64, + stddev: f64, + ) -> Result { + // TODO is there a better way ? + let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?; + self.storage_from_cpu_storage(&cpu_storage) + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 2a0924b633..ce5858fa81 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1859,7 +1859,14 @@ impl Tensor { (Storage::Cpu(storage), Device::Cuda(cuda)) => { Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) } + (Storage::Cpu(storage), Device::Metal(metal)) => { + Storage::Metal(metal.storage_from_cpu_storage(storage)?) + } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), + (Storage::Metal(storage), Device::Cpu) => { + println!("{storage:?} - {:?}", storage.to_cpu_storage()?); + Storage::Cpu(storage.to_cpu_storage()?) + } (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml new file mode 100644 index 0000000000..186f320906 --- /dev/null +++ b/candle-metal-kernels/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "candle-metal-kernels" +version = "0.3.1" +edition = "2021" + +description = "Metal kernels for Candle" +repository = "https://github.com/huggingface/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" + +[dependencies] +metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +once_cell = "1.18.0" +thiserror = "1" +tracing = "0.1.37" + +[dev-dependencies] +half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +rand = "0.8.5" diff --git a/candle-metal-kernels/README.md b/candle-metal-kernels/README.md new file mode 100644 index 0000000000..ec923e9a2a --- /dev/null +++ b/candle-metal-kernels/README.md @@ -0,0 +1,3 @@ +# candle-metal-kernels + +This crate contains Metal kernels used from candle. \ No newline at end of file diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/examples/affine.rs new file mode 100644 index 0000000000..b8005dc0ae --- /dev/null +++ b/candle-metal-kernels/examples/affine.rs @@ -0,0 +1,75 @@ +use candle_metal_kernels::{call_affine, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_affine_bench(&device, &kernels, &f32_1k); + run_affine_bench(&device, &kernels, &f32_10k); + run_affine_bench(&device, &kernels, &f32_100k); +} + +fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 10000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + let mul: f32 = 1.2345; + let add: f32 = 2.3456; + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_affine( + &device, + command_buffer, + &kernels, + v.len(), + &input, + &mut output, + mul, + add, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + "affine", + v.len(), + iterations, + total_time, + total_time / iterations + ); +} diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/examples/binary.rs new file mode 100644 index 0000000000..af5a8bdc62 --- /dev/null +++ b/candle-metal-kernels/examples/binary.rs @@ -0,0 +1,182 @@ +use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels}; +use half::{bf16, f16}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::>(); + let f16_1k = f16_map(&f32_1k); + let f16_10k = f16_map(&f32_10k); + let f16_100k = f16_map(&f32_100k); + + let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let bf16_1k = bf16_map(&f32_1k); + let bf16_10k = bf16_map(&f32_10k); + let bf16_100k = bf16_map(&f32_100k); + + let f32_ckernels = [ + binary::contiguous::add::FLOAT, + binary::contiguous::sub::FLOAT, + binary::contiguous::mul::FLOAT, + binary::contiguous::div::FLOAT, + ]; + let f32_skernels = [ + binary::strided::add::FLOAT, + binary::strided::sub::FLOAT, + binary::strided::mul::FLOAT, + binary::strided::div::FLOAT, + ]; + let f16_ckernels = [ + binary::contiguous::add::HALF, + binary::contiguous::sub::HALF, + binary::contiguous::mul::HALF, + binary::contiguous::div::HALF, + ]; + let f16_skernels = [ + binary::strided::add::HALF, + binary::strided::sub::HALF, + binary::strided::mul::HALF, + binary::strided::div::HALF, + ]; + let bf16_ckernels = [ + binary::contiguous::add::BFLOAT, + binary::contiguous::sub::BFLOAT, + binary::contiguous::mul::BFLOAT, + binary::contiguous::div::BFLOAT, + ]; + let bf16_skernels = [ + binary::strided::add::BFLOAT, + binary::strided::sub::BFLOAT, + binary::strided::mul::BFLOAT, + binary::strided::div::BFLOAT, + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels); + run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels); + run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels); + + // f16 + run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels); + run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels); + run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels); + + // bf16 + run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels); + run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels); + run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels); +} + +fn run_binary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: [binary::contiguous::Kernel; 4], + strided: [binary::strided::Kernel; 4], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_binary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + for kernel_name in strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_binary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &input, + &strides, + offset, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/examples/cast.rs new file mode 100644 index 0000000000..090f510d16 --- /dev/null +++ b/candle-metal-kernels/examples/cast.rs @@ -0,0 +1,84 @@ +use candle_metal_kernels::{call_cast_contiguous, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let contiguous_kernels = ["cast_u32_f32"]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels); + run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels); + run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels); +} + +fn run_cast_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: &[&'static str], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_cast_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided? +} diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/examples/unary.rs new file mode 100644 index 0000000000..7039c0985a --- /dev/null +++ b/candle-metal-kernels/examples/unary.rs @@ -0,0 +1,197 @@ +use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels}; +use half::{bf16, f16}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::>(); + let f16_1k = f16_map(&f32_1k); + let f16_10k = f16_map(&f32_10k); + let f16_100k = f16_map(&f32_100k); + + let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let bf16_1k = bf16_map(&f32_1k); + let bf16_10k = bf16_map(&f32_10k); + let bf16_100k = bf16_map(&f32_100k); + + let f32_ckernels = [ + unary::contiguous::sin::FLOAT, + unary::contiguous::cos::FLOAT, + unary::contiguous::exp::FLOAT, + unary::contiguous::sqr::FLOAT, + unary::contiguous::sqrt::FLOAT, + unary::contiguous::neg::FLOAT, + unary::contiguous::copy::FLOAT, + ]; + let f32_skernels = [ + unary::strided::sin::FLOAT, + unary::strided::cos::FLOAT, + unary::strided::exp::FLOAT, + unary::strided::sqr::FLOAT, + unary::strided::sqrt::FLOAT, + unary::strided::neg::FLOAT, + unary::strided::copy::FLOAT, + ]; + let f16_ckernels = [ + unary::contiguous::sin::HALF, + unary::contiguous::cos::HALF, + unary::contiguous::exp::HALF, + unary::contiguous::sqr::HALF, + unary::contiguous::sqrt::HALF, + unary::contiguous::neg::HALF, + unary::contiguous::copy::HALF, + ]; + let f16_skernels = [ + unary::strided::sin::HALF, + unary::strided::cos::HALF, + unary::strided::exp::HALF, + unary::strided::sqr::HALF, + unary::strided::sqrt::HALF, + unary::strided::neg::HALF, + unary::strided::copy::HALF, + ]; + let bf16_ckernels = [ + unary::contiguous::sin::BFLOAT, + unary::contiguous::cos::BFLOAT, + unary::contiguous::exp::BFLOAT, + unary::contiguous::sqr::BFLOAT, + unary::contiguous::sqrt::BFLOAT, + unary::contiguous::neg::BFLOAT, + unary::contiguous::copy::BFLOAT, + ]; + let bf16_skernels = [ + unary::strided::sin::BFLOAT, + unary::strided::cos::BFLOAT, + unary::strided::exp::BFLOAT, + unary::strided::sqr::BFLOAT, + unary::strided::sqrt::BFLOAT, + unary::strided::neg::BFLOAT, + unary::strided::copy::BFLOAT, + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels); + run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels); + run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels); + + // f16 + run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels); + run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels); + run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels); + + // bf16 + run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels); + run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels); + run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels); +} + +fn run_unary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: [unary::contiguous::Kernel; 7], + strided: [unary::strided::Kernel; 7], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 10000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_unary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + for kernel_name in strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_unary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &mut output, + 0, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal new file mode 100644 index 0000000000..e5f0a841e0 --- /dev/null +++ b/candle-metal-kernels/src/affine.metal @@ -0,0 +1,43 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +using namespace metal; + +#define AFFINE(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME m = TYPENAME(mul); \ + const TYPENAME a = TYPENAME(add); \ + output[id] = input[id] * m + a; \ +} \ + +AFFINE(affine_float, float) +AFFINE(affine_half, half) + + +#if __METAL_VERSION__ >= 310 +AFFINE(affine_bfloat, bfloat); +#endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal new file mode 100644 index 0000000000..f18cdbb0f8 --- /dev/null +++ b/candle-metal-kernels/src/binary.metal @@ -0,0 +1,72 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +using namespace metal; + +#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *left, \ + device const TYPENAME *right, \ + device TYPENAME *output, \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ +) { \ + if (thread_position_in_grid >= dim) { \ + return; \ + } \ + TYPENAME x = left[thread_position_in_grid]; \ + TYPENAME y = right[thread_position_in_grid]; \ + output[thread_position_in_grid] = OUT_TYPENAME(FN); \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *left_strides, \ + constant size_t *right_strides, \ + device const TYPENAME *left, \ + device const TYPENAME *right, \ + device TYPENAME *output, \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ +) { \ + if (thread_position_in_grid >= dim) { \ + return; \ + } \ + TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \ + output[thread_position_in_grid] = OUT_TYPENAME(FN); \ +} + +#define BINARY_OP(FN, NAME) \ +BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ +BINARY(FN, half, half, NAME##_half, NAME##_half_strided); + +#define BFLOAT_BINARY_OP(FN, NAME) \ +BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); + + +BINARY_OP(x + y, add) +BINARY_OP(x - y, sub) +BINARY_OP(x * y, mul) +BINARY_OP(x / y, div) + +#if __METAL_VERSION__ >= 310 +BFLOAT_BINARY_OP(x + y, add) +BFLOAT_BINARY_OP(x - y, sub) +BFLOAT_BINARY_OP(x * y, mul) +BFLOAT_BINARY_OP(x / y, div) +#endif diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal new file mode 100644 index 0000000000..d178825317 --- /dev/null +++ b/candle-metal-kernels/src/cast.metal @@ -0,0 +1,51 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + + +using namespace metal; + +#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ +) { \ + if (thread_position_in_grid >= dim) { \ + return; \ + } \ + output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint i [[ thread_position_in_grid ]] \ +) { \ + if (i >= dim) { \ + return; \ + } \ + output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ +} \ + +CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) + +#if __METAL_VERSION__ >= 310 +#endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal new file mode 100644 index 0000000000..444fa322fe --- /dev/null +++ b/candle-metal-kernels/src/indexing.metal @@ -0,0 +1,102 @@ +#include +using namespace metal; + +# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint gid [[ thread_position_in_grid ]] \ +) { \ + if (gid >= dst_size) { \ + return; \ + } \ + const size_t id_i = gid / right_size / left_size; \ + const size_t right_rank_i = gid % right_size; \ + const size_t left_rank_i = gid % left_size; \ + /* \ + // Force prevent out of bounds indexing \ + // since there doesn't seem to be a good way to force crash \ + // No need to check for zero we're only allowing unsized. \ + */ \ + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ + const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ + output[gid] = input[src_i]; \ +} + + + +template +void index_add( + device I *ids [[buffer(0)]], + device T *inp [[buffer(1)]], + device T *out [[buffer(2)]], + + constant uint &ids_dim_size, + constant uint &left_size, + constant uint &dst_dim_size, + constant uint &right_size, + + uint gid [[ thread_position_in_grid ]] \ +) { + + if (gid >= left_size * right_size) { + return; + } + + const uint i = gid; + const uint pre = i / right_size; + const uint post = i % right_size; + + for (uint j = 0; j < ids_dim_size; j++) { + const uint idx = ids[j]; + const uint src_i = (pre * ids_dim_size + j) * right_size + post; + const uint dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] += inp[src_i]; + } +} + +#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + device INDEX_TYPENAME *ids [[buffer(0)]], \ + device TYPENAME *inp [[buffer(1)]], \ + device TYPENAME *out [[buffer(2)]], \ + constant uint &ids_dim_size, \ + constant uint &left_size, \ + constant uint &dst_dim_size, \ + constant uint &right_size, \ + uint gid [[ thread_position_in_grid ]] \ +) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \ + + +INDEX_OP(is_u32_f32, uint, float) + + +#if __METAL_VERSION__ >= 310 +IA_OP(bfloat, int64_t, ia_i64_bf16) +IA_OP(bfloat, uint32_t, ia_u32_bf16) +IA_OP(bfloat, uint8_t, ia_u8_bf16) +#endif + +IA_OP(half, uint32_t, ia_u32_f16) +IA_OP(half, uint8_t, ia_u8_f16) + +IA_OP(float, int64_t, ia_i64_f32) +IA_OP(uint8_t, int64_t, ia_i64_u8) +IA_OP(int64_t, int64_t, ia_i64_i64) +IA_OP(uint32_t, int64_t, ia_i64_u32) + +IA_OP(float, uint32_t, ia_u32_f32) +IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int64_t, uint32_t, ia_u32_i64) +IA_OP(uint32_t, uint32_t, ia_u32_u32) + +IA_OP(float, uint8_t, ia_u8_f32) +IA_OP(uint8_t, uint8_t, ia_u8_u8) +IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int64_t, uint8_t, ia_u8_i64) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs new file mode 100644 index 0000000000..5a6bd41bde --- /dev/null +++ b/candle-metal-kernels/src/lib.rs @@ -0,0 +1,675 @@ +use metal::{ + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, + ComputePipelineState, Device, Function, Library, MTLSize, +}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::sync::RwLock; + +const AFFINE: &str = include_str!("affine.metal"); +const INDEXING: &str = include_str!("indexing.metal"); +const UNARY: &str = include_str!("unary.metal"); +const BINARY: &str = include_str!("binary.metal"); +const TERNARY: &str = include_str!("ternary.metal"); +const CAST: &str = include_str!("cast.metal"); +const REDUCE: &str = include_str!("reduce.metal"); + +fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { + let size = length as u64; + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); + let count = (size + width - 1) / width; + let thread_group_count = MTLSize { + width: count, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + (thread_group_count, thread_group_size) +} + +fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { +

::set_param(encoder, position, data) +} +trait EncoderParam { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); +} +macro_rules! primitive { + ($type:ty) => { + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } + }; +} +primitive!(usize); +primitive!(u32); +primitive!(f32); + +impl EncoderParam for &[T] { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + (core::mem::size_of::() * data.len()) as u64, + data.as_ptr() as *const T as *const c_void, + ); + } +} + +impl EncoderParam for &Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} +impl EncoderParam for (&Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} +impl EncoderParam for &mut Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} +impl EncoderParam for (&mut Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +macro_rules! set_params { + ($encoder:ident, ($($param:expr),+)) => ( + let mut _index = 0; + $( + set_param($encoder, _index, $param); + _index += 1; + )* + ); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Source { + Affine, + Indexing, + Unary, + Binary, + Ternary, + Cast, + Reduce, +} + +macro_rules! ops{ + ($($name:ident),+) => { + + pub mod contiguous { + #[derive(Clone, Copy)] + pub struct Kernel(pub(crate) &'static str); + impl std::fmt::Display for Kernel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + } + )+ + } + + pub mod strided { + #[derive(Clone, Copy)] + pub struct Kernel(pub(crate) &'static str); + impl std::fmt::Display for Kernel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + } + )+ + } + }; +} + +pub mod unary { + ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); +} +pub mod binary { + ops!(add, sub, mul, div); +} + +#[derive(thiserror::Error, Debug)] +pub enum MetalKernelError { + #[error("Could not lock kernel map: {0}")] + LockError(String), + #[error("Error while loading library: {0}")] + LoadLibraryError(String), + #[error("Error while loading function: {0}")] + LoadFunctionError(String), +} + +impl From> for MetalKernelError { + fn from(e: std::sync::PoisonError) -> Self { + Self::LockError(e.to_string()) + } +} + +type KernelMap = HashMap<&'static str, T>; +type Libraries = HashMap; +type Functions = KernelMap; + +#[derive(Debug, Default)] +pub struct Kernels { + libraries: RwLock, + funcs: RwLock, +} + +impl Kernels { + pub fn new() -> Self { + let libraries = RwLock::new(Libraries::new()); + let funcs = RwLock::new(Functions::new()); + Self { libraries, funcs } + } + + fn get_library_source(&self, source: Source) -> &'static str { + match source { + Source::Affine => AFFINE, + Source::Unary => UNARY, + Source::Binary => BINARY, + Source::Ternary => TERNARY, + Source::Indexing => INDEXING, + Source::Cast => CAST, + Source::Reduce => REDUCE, + } + } + + pub fn load_library( + &self, + device: &Device, + source: Source, + ) -> Result { + let mut libraries = self.libraries.write()?; + if let Some(lib) = libraries.get(&source) { + Ok(lib.clone()) + } else { + let source_content = self.get_library_source(source); + let lib = device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + libraries.insert(source, lib.clone()); + Ok(lib) + } + } + + pub fn load_function( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result { + let mut funcs = self.funcs.write()?; + if let Some(func) = funcs.get(name) { + Ok(func.clone()) + } else { + let func = self + .load_library(device, source)? + .get_function(name, None) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + funcs.insert(name, func.clone()); + Ok(func) + } + } +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: unary::strided::Kernel, + shape: &[usize], + input: &Buffer, + strides: &[usize], + offset: usize, + output: &mut Buffer, + output_offset: usize, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Unary, name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let num_dims: usize = shape.len(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + set_params!( + encoder, + ( + length, + num_dims, + shape, + strides, + (input, offset), + (output, output_offset) + ) + ); + + let width: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_binary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: binary::contiguous::Kernel, + length: usize, + left: &Buffer, + right: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, left, right, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_binary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: binary::strided::Kernel, + shape: &[usize], + left_input: &Buffer, + left_strides: &[usize], + left_offset: usize, + right_input: &Buffer, + right_strides: &[usize], + right_offset: usize, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Binary, name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let num_dims: usize = shape.len(); + let encoder = command_buffer.new_compute_command_encoder(); + let width: usize = shape.iter().product(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + + set_params!( + encoder, + ( + length, + num_dims, + shape, + left_strides, + right_strides, + (left_input, left_offset), + (right_input, right_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Cast, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_reduce_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + out_length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Reduce, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let elements_to_sum = length / out_length; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, elements_to_sum, input, output)); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + (elements_to_sum as u64 + 2 - 1) / 2, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_last_softmax( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Reduce, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, elements_to_sum, input, output)); + + let out_length = length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_affine( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + size: usize, + input: &Buffer, + output: &mut Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Affine, "affine_float")?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, add, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_where_cond_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + cond: &Buffer, + (cond_stride, cond_offset): (&[usize], usize), + left: &Buffer, + (left_stride, left_offset): (&[usize], usize), + right: &Buffer, + (right_stride, right_offset): (&[usize], usize), + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Ternary, name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let size: usize = shape.iter().product(); + let rank = shape.len(); + + set_params!( + encoder, + ( + size, + rank, + shape, + cond_stride, + left_stride, + right_stride, + (cond, cond_offset), + (left, left_offset), + (right, right_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_index_select( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + input: &Buffer, + ids: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let func = kernels.load_function(device, Source::Indexing, name)?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + input, + ids, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[cfg(test)] +mod tests; diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal new file mode 100644 index 0000000000..c6984474b2 --- /dev/null +++ b/candle-metal-kernels/src/reduce.metal @@ -0,0 +1,139 @@ +#include +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +constant int THREADGROUP_SIZE = 256; + +# define REDUCE(FN, NAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint blockDim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = 0; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + */ \ + TYPENAME x = shared_memory[tid]; \ + TYPENAME y = src[idx]; \ + shared_memory[tid] = FN; \ + idx += blockDim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = blockDim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + TYPENAME x = shared_memory[tid]; \ + TYPENAME y = shared_memory[tid + s]; \ + shared_memory[tid] = FN; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + dst[dst_id] = shared_memory[0]; \ +} \ + +kernel void softmax_float( + constant size_t &src_numel, + constant size_t &el_to_sum_per_block, + device const float *src, + device float *dst, + uint id [[ thread_position_in_grid ]], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint blockDim [[ threads_per_threadgroup ]] +) { + + threadgroup float shared_memory[THREADGROUP_SIZE]; + + shared_memory[tid] = -INFINITY; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + shared_memory[tid] = max(shared_memory[tid], src[idx]); + idx += blockDim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); + } + threadgroup_barrier(mem_flags::mem_none); + } + + float max = shared_memory[0]; + + shared_memory[tid] = 0; + + // Restart + idx = start_idx + tid; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + const float val = exp(src[idx] - max); + dst[idx] = val; + shared_memory[tid] += val; + idx += blockDim; + } + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + const float inv_acc = 1/shared_memory[0]; + idx = start_idx + tid; + while (idx < stop_idx) { + dst[idx] *= inv_acc; + idx += blockDim; + } +} + + +REDUCE(x + y, fast_sum_float, float) +REDUCE(x * y, fast_mul_float, float) +REDUCE(max(x, y), fast_max_float, float) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal new file mode 100644 index 0000000000..0945b355cf --- /dev/null +++ b/candle-metal-kernels/src/ternary.metal @@ -0,0 +1,57 @@ +#include +# +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + + +#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID_TYPENAME *ids, \ + device const TYPENAME *t, \ + device const TYPENAME *f, \ + device TYPENAME *out ,\ + uint i [[ thread_position_in_grid ]] \ +) { \ + uint strided_i = get_strided_index(i, num_dims, dims, strides); \ + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ + out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ +} \ + +// WHERE_OP(float, int64_t, where_i64_f32) +// WHERE_OP(double, int64_t, where_i64_f64) +// WHERE_OP(uint8_t, int64_t, where_i64_u8) +// WHERE_OP(uint32_t, int64_t, where_i64_u32) +// WHERE_OP(int64_t, int64_t, where_i64_i64) +// +// WHERE_OP(float, uint32_t, where_u32_f32) +// WHERE_OP(double, uint32_t, where_u32_f64) +// WHERE_OP(uint8_t, uint32_t, where_u32_u8) +// WHERE_OP(uint32_t, uint32_t, where_u32_u32) +// WHERE_OP(int64_t, uint32_t, where_u32_i64) + +WHERE_OP(float, uint8_t, where_u8_f32) +// WHERE_OP(double, uint8_t, where_u8_f64) +// WHERE_OP(uint8_t, uint8_t, where_u8_u8) +// WHERE_OP(uint32_t, uint8_t, where_u8_u32) +// WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs new file mode 100644 index 0000000000..2330d48d26 --- /dev/null +++ b/candle-metal-kernels/src/tests.rs @@ -0,0 +1,616 @@ +use super::*; +use half::f16; +use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; + +fn new_buffer(device: &Device, data: &[T]) -> Buffer { + let options = MTLResourceOptions::StorageModeManaged; + let ptr = data.as_ptr() as *const core::ffi::c_void; + let size = (data.len() * std::mem::size_of::()) as u64; + device.new_buffer_with_data(ptr, size, options) +} + +fn device() -> Device { + Device::system_default().unwrap() +} + +fn approx(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t * b) / b).collect() +} + +fn approx_f16(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + +fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + call_unary_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) +} + +fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let left = new_buffer(&device, x); + let right = new_buffer(&device, y); + let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + call_binary_contiguous( + &device, + command_buffer, + &kernels, + name, + x.len(), + &left, + &right, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(x.len()) +} + +fn run_strided( + v: &[T], + kernel: unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + offset: usize, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + let kernels = Kernels::new(); + call_unary_strided( + &device, + command_buffer, + &kernels, + kernel, + shape, + &input, + strides, + offset, + &mut output, + 0, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) +} + +#[test] +fn cos_f32() { + let v = vec![1.0f32, 2.0, 3.0]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); + assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Transposed + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![1, 3]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Very large + let v = vec![1.0f32; 10_000]; + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_strided_random() { + let v: Vec<_> = (0..10_000).map(|_| rand::random::()).collect(); + let shape = vec![5_000, 2]; + let strides = vec![1, 5_000]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4)); + assert_eq!( + approx(vec![results[1]], 4), + approx(vec![expected[5_000]], 4) + ); + assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4)); + assert_eq!( + approx(vec![results[3]], 4), + approx(vec![expected[5_001]], 4) + ); + assert_eq!( + approx(vec![results[5_000]], 4), + approx(vec![expected[2_500]], 4) + ); +} + +#[test] +fn binary_add_f32() { + let left = vec![1.0f32, 2.0, 3.0]; + let right = vec![2.0f32, 3.1, 4.2]; + let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); + let expected: Vec<_> = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| x + y) + .collect(); + assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); + assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); +} + +fn cast(v: &[T], name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + + call_cast_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) +} + +#[test] +fn cast_u32_f32() { + let v = vec![1u32, 2, 3]; + let results = cast(&v, "cast_u32_f32"); + let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); + assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); + assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + + let size = v.len(); + + call_affine( + &device, + command_buffer, + &kernels, + size, + &input, + &mut output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) +} + +#[test] +fn affine() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + + let input = [1.0f32; 40_000]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6; 40_000]); +} + +#[test] +fn index_select() { + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [2, 5]; + let ids = [0u32, 1, 0]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 1, 0]; + let dim = 1; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); +} + +fn run_index_select( + embeddings: &[T], + shape: &[usize], + ids: &[I], + dim: usize, +) -> Vec { + let device = Device::system_default().expect("no device found"); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let embeddings_buffer = new_buffer(&device, &embeddings); + let ids_buffer = new_buffer(&device, &ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; + let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let kernels = Kernels::new(); + call_index_select( + &device, + &command_buffer, + &kernels, + "is_u32_f32", + shape, + ids.len(), + dim, + &embeddings_buffer, + &ids_buffer, + &mut dst_buffer, + ) + .unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + dst_buffer.read_to_vec::(dst_el) +} + +#[test] +fn index_add() { + let device = Device::system_default().expect("no device found"); + + let options = CompileOptions::new(); + let library = device.new_library_with_source(INDEXING, &options).unwrap(); + + let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let right = [1.0f32; 15]; + let index = [0u32, 4, 2]; + let ids_dim_size = index.len() as u32; + let dst_dim_size: u32 = 15; + let left_size: u32 = 3; + let right_size: u32 = 3; + + let function = library.get_function("ia_u32_f32", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + let index_buffer = new_buffer(&device, &index); + let inputs_buffer = new_buffer(&device, &left); + let outputs_buffer = new_buffer(&device, &right); + + set_params!( + encoder, + ( + &index_buffer, + &inputs_buffer, + &outputs_buffer, + ids_dim_size, + left_size, + dst_dim_size, + right_size + ) + ); + + let grid_size = MTLSize { + width: right.len() as NSUInteger, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: pipeline.max_total_threads_per_threadgroup(), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = vec![ + 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, + ]; + let result = outputs_buffer.read_to_vec::(right.len()); + assert_eq!(result, expected); +} + +#[test] +fn cos_f16() { + let v: Vec = [1.0f32, 2.0, 3.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let results = run(&v, unary::contiguous::cos::HALF); + let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); + assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); +} + +fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + + let options = MTLResourceOptions::StorageModeManaged; + let mut output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); + call_reduce_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + out_length, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(out_length) +} + +fn run_softmax(v: &[T], last_dim: usize, name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + call_last_softmax( + &device, + command_buffer, + &kernels, + name, + v.len(), + last_dim, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) +} + +#[test] +fn reduce_sum() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 1; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![21.0]); +} + +#[test] +fn reduce_sum2() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 2; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![6.0, 15.0]); +} + +#[test] +fn softmax() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 3; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] + ); +} + +fn run_where_cond( + shape: &[usize], + cond: &[I], + (cond_stride, cond_offset): (Vec, usize), + left_true: &[T], + (left_stride, left_offset): (Vec, usize), + right_false: &[T], + (_right_stride, _right_offset): (Vec, usize), + name: &'static str, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let length = cond.len(); + let cond = device.new_buffer_with_data( + cond.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(cond) as u64, + options, + ); + let left = device.new_buffer_with_data( + left_true.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + let right = device.new_buffer_with_data( + right_false.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + + let mut output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + call_where_cond_strided( + &device, + command_buffer, + &kernels, + name, + shape, + &cond, + (&cond_stride, cond_offset), + &left, + (&left_stride, left_offset), + &right, + (&cond_stride, cond_offset), + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(length) +} + +#[test] +fn where_cond() { + let shape = vec![6]; + let cond = vec![0u8, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u8_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); +} diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal new file mode 100644 index 0000000000..eb6424e81c --- /dev/null +++ b/candle-metal-kernels/src/unary.metal @@ -0,0 +1,80 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T neg(T in){ return -in; } +template METAL_FUNC T id(T in){ return in; } + + +using namespace metal; + +#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ +) { \ + if (thread_position_in_grid >= dim) { \ + return; \ + } \ + output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ +) { \ + if (thread_position_in_grid >= dim) { \ + return; \ + } \ + output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ +} + +#define UNARY_OP(NAME) \ +UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ +UNARY(NAME, half, NAME##_half, NAME##_half_strided); + +#define BFLOAT_UNARY_OP(NAME) \ +UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); + + +UNARY_OP(cos) +UNARY_OP(sin) +UNARY_OP(sqr) +UNARY_OP(sqrt) +UNARY_OP(neg) +UNARY_OP(exp) +UNARY_OP(log) +UNARY(id, float, copy_float, copy_float_strided) +UNARY(id, half, copy_half, copy_half_strided) + +#if __METAL_VERSION__ >= 310 +BFLOAT_UNARY_OP(cos) +BFLOAT_UNARY_OP(sin) +BFLOAT_UNARY_OP(sqr) +BFLOAT_UNARY_OP(sqrt) +BFLOAT_UNARY_OP(neg) +BFLOAT_UNARY_OP(exp) +BFLOAT_UNARY_OP(log) + +UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) +#endif