
Commit

Merge pull request #1533 from huggingface/ivarflakstad/metal-prng
ivarflakstad authored Jan 22, 2024
2 parents 1cf3436 + 80b1c68 commit fd7c856
Showing 7 changed files with 540 additions and 19 deletions.
3 changes: 2 additions & 1 deletion candle-core/benches/bench_main.rs
@@ -2,7 +2,8 @@ mod benchmarks;

use criterion::criterion_main;
criterion_main!(
benchmarks::matmul::benches,
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
);
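
For context, criterion_main! only generates the bench binary's main function and runs each group handed to it, so the new random group slots in beside the existing ones. A minimal, self-contained sketch of that wiring (the noop benchmark below is illustrative only, not part of this PR):

use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn bench_noop(c: &mut Criterion) {
    // Trivial benchmark just to show how a group is registered and wired in.
    c.bench_function("noop", |b| b.iter(|| black_box(1 + 1)));
}

criterion_group!(benches, bench_noop);
criterion_main!(benches);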
1 change: 1 addition & 0 deletions candle-core/benches/benchmarks/mod.rs
@@ -1,5 +1,6 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond;

use candle_core::{Device, Result};
63 changes: 63 additions & 0 deletions candle-core/benches/benchmarks/random.rs
@@ -0,0 +1,63 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}

fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}

fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;

let rows = 2048;
let cols = 2048;

let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();

let flops = b * rows * cols * dtype.size_in_bytes();

let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();

let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();

let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}

criterion_group!(benches, criterion_benchmark);
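
Outside of Criterion, the operations this benchmark times are just the public tensor constructors; a standalone sketch of the same calls, assuming the public candle_core API and falling back to CPU when no Metal device is available:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Assumption: Metal ordinal 0; fall back to CPU so the sketch runs anywhere.
    let device = Device::new_metal(0).unwrap_or(Device::Cpu);
    let t = Tensor::zeros((1, 2048, 2048), DType::F32, &device)?;
    // The same calls the benchmark wraps: uniform over [-1, 123) and N(100, 15).
    let _uniform = t.rand_like(-1.0, 123.0)?;
    let _normal = t.randn_like(100.0, 15.0)?;
    Ok(())
}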
86 changes: 72 additions & 14 deletions candle-core/src/metal_backend.rs
@@ -7,8 +7,9 @@ use candle_metal_kernels::Kernels;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, RwLock, TryLockError};
use std::sync::{Arc, Mutex, RwLock, TryLockError};

/// Simple way to catch lock error without
/// depending on T
@@ -101,6 +102,8 @@ pub struct MetalDevice {
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1).
buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<Buffer>>,
}

impl std::fmt::Debug for MetalDevice {
@@ -225,7 +228,7 @@ impl MetalDevice {
// The slice might not live long enough for metal
// To actually fill the GPU buffer.
// Putting this wait forces the GPU buffer to be filled
// with the actual data allowing the CPU storage todo
// with the actual data allowing the CPU storage to do
// deallocate properly.
self.wait_until_completed()?;
Ok(real)
@@ -1554,6 +1557,11 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?,
_ => 10,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self {
device,
command_queue,
@@ -1562,13 +1570,10 @@
compute_per_buffer,
buffers,
kernels,
seed,
})
}

fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("Metal set_seed not implemented")
}

fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize,
@@ -1608,12 +1613,31 @@
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
min: f64,
max: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
let name = match dtype {
DType::F32 => "rand_uniform_f32",
DType::F16 => "rand_uniform_f16",
DType::BF16 => "rand_uniform_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_uniform(
&self.device,
&command_buffer,
&self.kernels,
name,
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;

Ok(Self::Storage::new(buffer, self.clone(), dtype))
}

fn rand_normal(
@@ -1623,9 +1647,43 @@
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
let name = match dtype {
DType::F32 => "rand_normal_f32",
DType::F16 => "rand_normal_f16",
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_normal not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;

Ok(Self::Storage::new(buffer, self.clone(), dtype))
}

fn set_seed(&self, seed: u64) -> Result<()> {
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;

let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));

Ok(())
}
}
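
Taken together, the new backend methods let Metal devices be seeded and sampled through the same device-level API as the other backends; a small usage sketch, assuming the public candle_core interface (the device ordinal and shape are arbitrary):

use candle_core::{Device, Result, Tensor};

fn sample_on_metal() -> Result<()> {
    let device = Device::new_metal(0)?; // requires building with the metal feature
    device.set_seed(299792458)?; // must fit in u32 for the Metal backend
    // Both calls now dispatch the new Metal kernels instead of sampling on the
    // CPU and copying the result to the device.
    let uniform = Tensor::rand(0f32, 1f32, (2, 3), &device)?;
    let normal = Tensor::randn(0f32, 1f32, (2, 3), &device)?;
    println!("{uniform}\n{normal}");
    Ok(())
}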

76 changes: 74 additions & 2 deletions candle-metal-kernels/src/lib.rs
@@ -12,8 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");

@@ -62,10 +63,12 @@ macro_rules! primitive {
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i64);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);

impl<T> EncoderParam for &[T] {
@@ -120,6 +123,7 @@ pub enum Source {
Reduce,
Mfa,
Conv,
Random,
Quantized,
}

@@ -241,6 +245,7 @@ impl Kernels {
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
}
@@ -1527,6 +1532,73 @@ pub fn call_upsample_nearest_2d(
Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
min: f32,
max: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
return Err(MetalKernelError::LoadLibraryError(
"min must be less than max".to_string(),
));
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();

let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);

encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (length, min, max, seed, buffer));

encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
mean: f32,
stddev: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();

let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);

encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (length, mean, stddev, seed, buffer));

encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

Ok(())
}
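
Both helpers dispatch over roughly half the output length, which suggests each GPU thread writes a pair of values (for the normal case this matches the paired outputs of a Box-Muller transform; random.metal is not part of this excerpt, so that is an assumption). A hypothetical helper mirroring the rounding used above:

fn pair_count(length: usize) -> usize {
    // Each thread produces two values, so odd lengths need one extra thread.
    length / 2 + (length % 2 != 0) as usize
}

fn main() {
    assert_eq!(pair_count(6), 3);
    assert_eq!(pair_count(5), 3); // the extra thread covers the trailing element
}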

#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
Q4_0,