From aed9b290ff75e9c7c825be9ffedab4c93fb769d1 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Thu, 26 Oct 2023 13:59:08 +0200
Subject: [PATCH] Add gemm correctness test and improve the performance test

---
Review notes (text between `---` and the diffstat is not part of the commit):
 * matrix_mul: the output row index is now `idx / columns` (was `idx / rows`).
   Entries are pushed in row-major order and Matrix::entry reads
   `row * columns + column`, so dividing by `rows` is only correct when the
   result is square (which is why the 3x3 and 100x100 checks passed).
 * performance: the "Operations" comment now matches the expression actually
   used for the GFLOPS figure, `M * N * (2 * K + 2)`.
 * progress_bar: the padding width uses `saturating_sub`, so a call with
   `i >= len` prints an unpadded bar instead of underflowing `usize`.
 * NOTE(review): generic parameter lists (`<...>`) look stripped in this copy
   of the patch (e.g. `matrix_mul::(a, b)`, `Vec = vec![]`); they are left as
   found here - confirm against the original commit before applying.
 * All changes are line-for-line, so hunk headers and the diffstat are unchanged.

 examples/mps/matrix-multiplication/main.rs | 227 +++++++++++++++++----
 src/mps.rs                                 |  25 ++-
 2 files changed, 205 insertions(+), 47 deletions(-)

diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs
index d833b65..c4792ec 100644
--- a/examples/mps/matrix-multiplication/main.rs
+++ b/examples/mps/matrix-multiplication/main.rs
@@ -1,6 +1,126 @@
 use metal::mps::*;
 use metal::*;
 use rand::{thread_rng, Rng};
+use std::io::Write;
+use std::ops::{AddAssign, Mul};
+use std::{array, io};
+
+fn main() {
+    correctness();
+    performance();
+}
+
+fn correctness() {
+    // First verify the correctness of the naive solution
+    let a = Matrix::new([1, 2, 6, 24, 120, 720], 3, 2);
+    let b = Matrix::new([1, 2, 3, 5, 8, 13], 2, 3);
+    let result = matrix_mul::(a, b);
+    assert_eq!(
+        result.entries(),
+        &[11, 18, 29, 126, 204, 330, 3720, 6000, 9720]
+    );
+
+    const M: u64 = 100;
+    const N: u64 = 100;
+    const K: u64 = 100;
+    const ITERATIONS: usize = 50;
+
+    let device = Device::system_default().expect("No device found");
+    let command_queue = device.new_command_queue();
+
+    println!("Correctness: ");
+    for i in 0..ITERATIONS {
+        progress_bar(i, ITERATIONS);
+
+        let left = generate_matrix::();
+        let right = generate_matrix::();
+
+        let command_buffer = command_queue.new_command_buffer();
+        let result = encode_gemm(
+            &device,
+            command_buffer,
+            false,
+            false,
+            &left,
+            &right,
+            1.0,
+            0.0,
+        );
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        let expected = matrix_mul(left, right);
+        approx_eq(result.contents(), expected.entries().to_vec());
+    }
+
+    println!(" ✅\n");
+}
+
+fn performance() {
+    const M: u64 = 4096;
+    const N: u64 = 4096;
+    const K: u64 = 4096;
+
+    const ITERATIONS: usize = 50;
+
+    println!("Performance: ");
+    println!("Generating input matrices: (f32 {M}x{K} and f16 {K}x{N})");
+    // Generate random matrices
+    let left = generate_matrix::();
+    let right = generate_matrix::();
+
+    // Setup
+    let device = Device::system_default().expect("No device found");
+    let command_queue = device.new_command_queue();
+
+    let cases = [
+        (false, false, 1.0, 0.0),
+        (true, false, 1.0, 0.0),
+        (false, true, 1.0, 0.0),
+        (false, false, 0.5, 0.0),
+        (false, false, 1.0, 0.5),
+    ];
+    for (t_left, t_right, alpha, beta) in cases {
+        println!("Running with transpose left: {t_left}, transpose right: {t_right}, alpha: {alpha}, beta: {beta}");
+        let mut flops: Vec = vec![];
+
+        let mut total_time = std::time::Duration::new(0, 0);
+        for i in 0..ITERATIONS {
+            progress_bar(i, ITERATIONS);
+
+            let start = std::time::Instant::now();
+            let command_buffer = command_queue.new_command_buffer();
+            let _ = encode_gemm(
+                &device,
+                command_buffer,
+                t_left,
+                t_right,
+                &left,
+                &right,
+                alpha,
+                beta,
+            );
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+
+            let time = std::time::Instant::now() - start;
+
+            total_time += time;
+
+            // Calculate GFLOPS
+            // C <- alpha * AB + beta * C
+            // Operations = M * N * (2 * K + 2)
+            flops.push((M * N * (2 * K + 2)) as f64 / (time.as_secs_f64() * 1e+9f64));
+        }
+        println!(" ✅");
+
+        let avg_gflops = flops.iter().sum::() / flops.len() as f64;
+        println!("Avg GFLOPS: {}", avg_gflops);
+        println!("Total time: {:#?}", total_time);
+        println!("Avg time: {:#?}", total_time / ITERATIONS as u32);
+        println!()
+    }
+}
 
 fn generate_matrix() -> Matrix
 where
@@ -9,59 +129,78 @@ where
 {
     let mut rng = thread_rng();
     Matrix::new(
-        (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())).collect(),
+        (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())),
         ROWS as NSUInteger,
         COLS as NSUInteger,
     )
 }
 
-fn main() {
-    const M: u64 = 4096;
-    const N: u64 = 4096;
-    const K: u64 = 4096;
-    const RUNS: u64 = 100;
+// Naive matrix multiplication for testing
+fn matrix_mul(a: Matrix, b: Matrix) -> Matrix
+where
+    T::Type: AddAssign + Mul + Copy,
+{
+    assert_eq!(a.columns(), b.rows());
+    let sum_count = a.columns() as usize;
+    let rows = a.rows() as usize;
+    let columns = b.columns() as usize;
+    let size = rows * columns;
 
-    let transpose_left = false;
-    let transpose_right = false;
-    let alpha = 1.0;
-    let beta = 0.0;
+    let mut entries = Vec::with_capacity(size);
 
-    // Generate random matrices
-    let left = generate_matrix::();
-    let right = generate_matrix::();
+    for idx in 0..size {
+        let i = idx / columns; // row-major output: idx = i * columns + j
+        let j = idx % columns;
 
-    // Setup
-    let device = Device::system_default().expect("No device found");
-    let command_queue = device.new_command_queue();
-    let mut total_time = std::time::Duration::new(0, 0);
+        let mut sum = T::from_f64(0.0);
+        for di in 0..sum_count {
+            sum += a.entry(i, di) * b.entry(di, j);
+        }
+        entries.push(sum);
+    }
 
-    for _ in 0..RUNS {
-        let command_buffer = command_queue.new_command_buffer();
-        let start = std::time::Instant::now();
-        let _ = encode_gemm(
-            &device,
-            command_buffer,
-            transpose_left,
-            transpose_right,
-            &left,
-            &right,
-            alpha,
-            beta,
-        );
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        let time = std::time::Instant::now() - start;
-        total_time += time;
+    Matrix::new(entries, a.rows(), b.columns())
+}
+
+fn euclidean_distance(a: Vec, b: Vec) -> f64
+where
+    T: Into + Clone + Copy,
+{
+    assert_eq!(a.len(), b.len(), "Lengths not equal");
+
+    let mut sum = 0.0;
+
+    for i in 0..a.len() {
+        sum += (a[i].into() - b[i].into()).powi(2);
     }
 
-    // Calculate GFLOPS
-    // C <- alpha * AB + beta * C
-    // Operations = M * N * (K+2) + M * N * K
-    let ops_count = M * N * (2 * K + 2);
-    let ops_count = (ops_count * RUNS) as f64;
-    let gflops = ops_count / (total_time.as_secs_f64() * 1000e+3f64);
-    // TODO: Something is wrong here hehe
-    println!("GFLOPS: {}", gflops);
-    println!("Total time: {:?}", total_time);
-    println!("Avg time: {:?}", total_time / RUNS as u32);
+    sum.sqrt()
+}
+
+fn approx_eq(a: Vec, b: Vec)
+where
+    T: Into + Clone + Copy,
+{
+    assert_eq!(a.len(), b.len(), "Lengths not equal");
+
+    let avg_magnitude = 0.004f64;
+    let avg_deviation = (a.len() as f64).sqrt();
+    let tolerance = avg_magnitude.max(avg_deviation * 3e-7);
+
+    let distance = euclidean_distance(a, b);
+    assert!(
+        distance < tolerance,
+        "Distance not less than tolerance: {} < {} ",
+        distance,
+        tolerance
+    );
+}
+
+fn progress_bar(i: usize, len: usize) {
+    print!("\r");
+    print!("[");
+    print!("{}", "=".repeat(i));
+    print!("{}", " ".repeat(len.saturating_sub(i + 1)));
+    print!("]");
+    io::stdout().flush().unwrap();
 }
diff --git a/src/mps.rs b/src/mps.rs
index 86fb696..b984284 100644
--- a/src/mps.rs
+++ b/src/mps.rs
@@ -908,7 +908,12 @@ pub struct Matrix {
 }
 
 impl Matrix {
-    pub fn new(entries: Vec, rows: NSUInteger, columns: NSUInteger) -> Self {
+    pub fn new>(
+        entries: E,
+        rows: NSUInteger,
+        columns: NSUInteger,
+    ) -> Matrix {
+        let entries: Vec = entries.into_iter().collect();
         assert_eq!(entries.len(), rows as usize * columns as usize);
         Self {
             entries,
@@ -916,8 +921,22 @@ impl Matrix {
             columns,
         }
     }
-    pub fn entries(&self) -> Vec {
-        self.entries.clone()
+    pub fn entries(&self) -> &[T::Type] {
+        &self.entries
+    }
+
+    pub fn entry(&self, row: usize, column: usize) -> T::Type {
+        assert!(row < self.rows as usize);
+        assert!(column < self.columns as usize);
+        self.entries[row * self.columns as usize + column]
+    }
+
+    pub fn rows(&self) -> NSUInteger {
+        self.rows
+    }
+
+    pub fn columns(&self) -> NSUInteger {
+        self.columns
     }
 }
 