from algorithm import vectorize, parallelize, unroll, vectorize_unroll, unswitch
from algorithm import Static1DTileUnitFunc as Tile1DFunc
from memory import memset_zero, memcpy, stack_allocation, memset
import benchmark
from sys import argv
from sys.info import simdwidthof  # needed for the simdwidthof calls below
from random import rand
from testing import *
import math.limit

alias BufferPtrFloat32 = DTypePointer[DType.float32]
# alias nelts = simdwidthof[DType.float32]()
alias workers = 12


# batch matmul with tiling V1
@always_inline
fn tile_parallel_v1[tiled_fn: Tile1DFunc, tile: Int](end: Int):
    fn row(i: Int):
        let io = i * tile
        tiled_fn[tile](io)

    parallelize[row](end // tile, workers)

    # deal with tail elements: reprocess the last full tile, so the trailing
    # rows overlap rows that were already computed (results are just recomputed)
    if end % tile != 0:
        tiled_fn[tile](end - tile)


@always_inline
fn batch_matmul_tiling_v1[
    n: Int
](
    C: StaticTuple[n, BufferPtrFloat32],
    A: BufferPtrFloat32,
    B: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
):
    alias nelts = simdwidthof[DType.float32]()
    alias tile_j = 2**n * nelts
    alias tile_i = 8 // n
    alias stack_size = tile_i * nelts

    @parameter
    fn calc_tiles_row[tile_i: Int](io: Int):
        var accumulator = StaticTuple[n, BufferPtrFloat32]()

        @parameter
        fn _init[k: Int]():
            accumulator[k] = stack_allocation[stack_size, DType.float32]()
            memset_zero(accumulator[k], stack_size)

        unroll[n, _init]()

        var temp_a = stack_allocation[tile_j, DType.float32]()

        @parameter
        fn calc_cols(jo: Int):
            @parameter
            fn copy_a[nelts: Int](j: Int):
                temp_a.simd_store[nelts](j, A.simd_load[nelts](jo + j))

            vectorize_unroll[nelts, tile_j // nelts, copy_a](tile_j)

            @parameter
            fn calc_row[i: Int]():
                @parameter
                fn calc_col[nelts: Int](j: Int):
                    @parameter
                    fn _multiply[k: Int]():
                        accumulator[k].simd_store[nelts](
                            i * nelts,
                            accumulator[k].simd_load[nelts](i * nelts)
                            + temp_a.simd_load[nelts](j)
                            * B[k].simd_load[nelts]((io + i) * cols + jo + j),
                        )

                    unroll[n, _multiply]()

                vectorize_unroll[nelts, tile_j // nelts, calc_col](tile_j)

            unroll[tile_i, calc_row]()

        for jo in range(0, cols - cols % tile_j, tile_j):
            calc_cols(jo)

        # deal with tail elements
        if cols % tile_j != 0:
            let temp = cols - cols % tile_j

            @unroll
            for i in range(tile_i):

                @parameter
                fn calc_tail_col[_nelts: Int](jo: Int):
                    let j = temp + jo

                    @parameter
                    fn _multiply[k: Int]():
                        accumulator[k].simd_store[_nelts](
                            i * nelts,
                            accumulator[k].simd_load[_nelts](i * nelts)
                            + A.simd_load[_nelts](j)
                            * B[k].simd_load[_nelts]((io + i) * cols + j),
                        )

                    unroll[n, _multiply]()

                vectorize[nelts, calc_tail_col](cols % tile_j)

        @parameter
        fn copy_values[i: Int]():
            @parameter
            fn _reduce[k: Int]():
                C[k].store(
                    io + i, accumulator[k].simd_load[nelts](i * nelts).reduce_add()
                )

            unroll[n, _reduce]()

        unroll[tile_i, copy_values]()

    tile_parallel_v1[calc_tiles_row, tile_i](rows)


# batch matmul with tiling V2
@always_inline
fn tile_parallel_v2[tiled_fn: Tile1DFunc, tile: Int](end: Int):
    fn row(i: Int):
        let io = i * tile
        tiled_fn[tile](io)

    parallelize[row](end // tile, workers)

    # deal with tail elements
    # for i in range(end - end % tile, end, 1):
    #     tiled_fn[1](i)
    if end % tile != 0:
        tiled_fn[tile](end - tile)


@always_inline
fn batch_matmul_tiling_v2[
    n: Int, _j: Int = 2**n, _i: Int = 8 // n
](
    C: StaticTuple[n, BufferPtrFloat32],
    A: BufferPtrFloat32,
    B: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
):
    alias nelts = simdwidthof[DType.float32]()
    # TODO: This is probably the reason why we get different results in speedup
    alias tile_j = _j * nelts
    alias tile_i = _i
    # alias tiles_i = (8, 4, 1)
    # alias tiles_j = (1, 4, 8)
    # alias tile_j = tiles_j.get[n - 1, Int]() * nelts
    # alias tile_i = tiles_i.get[n - 1, Int]()
    alias stack_size = tile_i * nelts
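    # Tiling scheme (descriptive note added for clarity): each parallel task
    # covers tile_i consecutive rows. Columns are swept in chunks of tile_j,
    # and every row in the tile keeps a stack-allocated SIMD accumulator of
    # width nelts that is horizontally reduced into C once the sweep is done.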
    @parameter
    fn calc_tiles_row[tile_i: Int](io: Int):
        var accumulator = StaticTuple[n, BufferPtrFloat32]()

        @parameter
        fn _init[k: Int]():
            accumulator[k] = stack_allocation[stack_size, DType.float32]()
            memset_zero(accumulator[k], stack_size)

        unroll[n, _init]()

        @parameter
        fn calc_cols[tile_j_unroll: Int](jo: Int, tile_j: Int):
            @parameter
            fn _batch[k: Int]():
                @parameter
                fn calc_row[i: Int]():
                    let row_offset_c = i * nelts
                    let row_offset_b = (io + i) * cols

                    @parameter
                    fn calc_col[_nelts: Int](j: Int):
                        accumulator[k].simd_store[_nelts](
                            row_offset_c,
                            accumulator[k].simd_load[_nelts](row_offset_c)
                            + A.simd_load[_nelts](jo + j)
                            * B[k].simd_load[_nelts](row_offset_b + jo + j),
                        )

                    vectorize_unroll[nelts, tile_j_unroll, calc_col](tile_j)

                unroll[tile_i, calc_row]()

            unroll[n, _batch]()

        for jo in range(0, cols - cols % tile_j, tile_j):
            calc_cols[tile_j // nelts](jo, tile_j)

        calc_cols[1](cols - cols % tile_j, cols % tile_j)

        @parameter
        fn _copy_values[k: Int]():
            @parameter
            fn _reduce[i: Int]():
                C[k].store(
                    io + i, accumulator[k].simd_load[nelts](i * nelts).reduce_add()
                )

            unroll[tile_i, _reduce]()

        unroll[n, _copy_values]()

    tile_parallel_v2[calc_tiles_row, tile_i](rows)


# normal batch matmul
@always_inline
fn batch_matmul[
    n: Int
](
    C: StaticTuple[n, BufferPtrFloat32],
    A: BufferPtrFloat32,
    B: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
):
    alias nelts = 4 * simdwidthof[DType.float32]()
    # alias nelts = 2 * simdwidthof[DType.float32]()

    @parameter
    fn compute_row(i: Int):
        var tmp = StaticTuple[n, SIMD[DType.float32, nelts]]()

        @parameter
        fn init[k: Int]():
            tmp[k] = SIMD[DType.float32, nelts](0)

        unroll[n, init]()

        let row_offset = i * cols

        @parameter
        fn dot[_nelts: Int](j: Int):
            if _nelts < nelts:
                # take care of tail array elements with length < nelts
                let a = A.simd_load[_nelts](j)

                @parameter
                fn _multiply_tail[k: Int]():
                    tmp[k][0] += (
                        a * B[k].simd_load[_nelts](row_offset + j)
                    ).reduce_add()

                unroll[n, _multiply_tail]()
            else:
                let a = A.simd_load[nelts](j)

                @parameter
                fn _multiply[k: Int]():
                    tmp[k] += a * B[k].simd_load[nelts](row_offset + j)

                unroll[n, _multiply]()

        vectorize[nelts, dot](cols)

        @parameter
        fn _reduce[k: Int]():
            C[k].store(i, tmp[k].reduce_add())

        unroll[n, _reduce]()

    parallelize[compute_row](rows, workers)


fn benchmark_matmul[
    n: Int
](
    c: StaticTuple[n, BufferPtrFloat32],
    a: BufferPtrFloat32,
    b: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
):
    @parameter
    fn matmul_no_tiling():
        batch_matmul[n](c, a, b, rows, cols)

    @parameter
    fn matmul_tiling_v1():
        batch_matmul_tiling_v1[n](c, a, b, rows, cols)

    @parameter
    fn matmul_tiling_v2():
        batch_matmul_tiling_v2[n](c, a, b, rows, cols)

    print("\n---Benchmarking Batch Matmul size", n, "\n")
    let report_1 = benchmark.run[matmul_no_tiling](10, 10_000)
    let report_2 = benchmark.run[matmul_tiling_v1](10, 10_000)
    let report_3 = benchmark.run[matmul_tiling_v2](10, 10_000)
    print("matmul_no_tiling: ", report_1.mean["ms"]())
    # report_1.print()
    print("matmul_tiling_v1: ", report_2.mean["ms"]())
    # report_2.print()
    print("matmul_tiling_v2: ", report_3.mean["ms"]())
    # report_3.print()
    print("Speedup v1: ", report_1.mean["ms"]() / report_2.mean["ms"]())
    print("Speedup v2: ", report_1.mean["ms"]() / report_3.mean["ms"]())


fn test_cache_misses[
    n: Int
](
    c: StaticTuple[n, BufferPtrFloat32],
    a: BufferPtrFloat32,
    b: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
    version: Int = 1,
):
    for i in range(1_000):
        if version == 0:
            batch_matmul[n](c, a, b, rows, cols)
        elif version == 1:
            batch_matmul_tiling_v1[n](c, a, b, rows, cols)
        else:
            batch_matmul_tiling_v2[n](c, a, b, rows, cols)

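
# Added reference sketch (not part of the original kernels or benchmarks):
# every kernel above computes a batched matrix-vector product,
# C[k][i] = dot(A, B[k][i, :]), which this scalar version spells out without
# SIMD, tiling, or threading. Handy as a correctness oracle for small sizes.
@always_inline
fn batch_matmul_naive[
    n: Int
](
    C: StaticTuple[n, BufferPtrFloat32],
    A: BufferPtrFloat32,
    B: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
):
    @parameter
    fn _batch[k: Int]():
        for i in range(rows):
            var acc: Float32 = 0
            for j in range(cols):
                acc += A.load(j) * B[k].load(i * cols + j)
            C[k].store(i, acc)

    unroll[n, _batch]()
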

fn test_matmul(rows: Int, cols: Int):
    print("Testing Batch Matmul size", "rows:", rows, "cols:", cols)
    let a = BufferPtrFloat32.alloc(cols)
    let b = BufferPtrFloat32.alloc(cols * rows)
    let b_2 = BufferPtrFloat32.alloc(cols * rows)
    let c_1 = BufferPtrFloat32.alloc(rows)
    let c_1_2 = BufferPtrFloat32.alloc(rows)
    let c_2 = BufferPtrFloat32.alloc(rows)
    let c_2_2 = BufferPtrFloat32.alloc(rows)
    memset_zero(c_1, rows)
    memset_zero(c_2, rows)
    memset_zero(c_1_2, rows)
    memset_zero(c_2_2, rows)

    # initialize
    rand[DType.float32](a, cols)
    rand[DType.float32](b, cols * rows)
    rand[DType.float32](b_2, cols * rows)

    batch_matmul[1](c_1, a, b, rows, cols)
    batch_matmul_tiling_v2[1](c_2, a, b, rows, cols)

    var flag = True
    for i in range(rows):
        flag = flag and assert_almost_equal(c_1.load(i), c_2.load(i))
    print("test_matmul size 1: ", flag)

    memset_zero(c_1, rows)
    memset_zero(c_2, rows)

    batch_matmul[2](StaticTuple[2](c_1, c_1_2), a, StaticTuple[2](b, b_2), rows, cols)
    batch_matmul_tiling_v2[2](
        StaticTuple[2](c_2, c_2_2), a, StaticTuple[2](b, b_2), rows, cols
    )

    flag = True
    for i in range(rows):
        flag = flag and assert_almost_equal(c_1.load(i), c_2.load(i))
        flag = flag and assert_almost_equal(c_1_2.load(i), c_2_2.load(i))
    print("test_matmul size 2: ", flag)

    _ = (a, b, b_2, c_1, c_1_2, c_2, c_2_2)


fn benchmark_matmul_tiles[
    n: Int, _j: Int, _i: Int
](
    c: StaticTuple[n, BufferPtrFloat32],
    a: BufferPtrFloat32,
    b: StaticTuple[n, BufferPtrFloat32],
    rows: Int,
    cols: Int,
) -> Float64:
    @parameter
    fn matmul_tiling_v2():
        batch_matmul_tiling_v2[n, _j, _i](c, a, b, rows, cols)

    let report_1 = benchmark.run[matmul_tiling_v2](10, 1_000)
    return report_1.mean["ms"]()


fn test_matmul_tiles_speeds():
    let sizes = (
        (256, 257),
        (288, 288),
        (288, 526),
        (768, 288),
        (512, 512),
        (517, 517),
        (32000, 288),
        (288, 32000),
        (2000, 3210),
        (5300, 2000),
    )  # 10

    @parameter
    fn test_tiles_matmul[i: Int]():
        let rows_cols = sizes.get[i, Tuple[Int, Int]]()
        let rows = rows_cols.get[0, Int]()
        let cols = rows_cols.get[1, Int]()

        let a = BufferPtrFloat32.alloc(cols)
        let b = BufferPtrFloat32.alloc(cols * rows)
        let c = BufferPtrFloat32.alloc(rows)
        let b_2 = BufferPtrFloat32.alloc(cols * rows)
        let b_3 = BufferPtrFloat32.alloc(cols * rows)
        let c_2 = BufferPtrFloat32.alloc(rows)
        let c_3 = BufferPtrFloat32.alloc(rows)

        print("\nBenchmarking Batch Matmul size", "rows:", rows, "cols:", cols)

        alias i_j_values = (1, 2, 4, 6, 8, 10, 12)
        var best_speed = limit.max_finite[DType.float64]()
        var best_j_i = (0, 0)
        var speeds = StaticTuple[49, Tuple[Float64, Int, Int]]()

        @parameter
        fn outer_test_values[j: Int]():
            @parameter
            fn test_values[k: Int]():
                # let temp = benchmark_matmul_tiles[
                #     3, i_j_values.get[j, Int](), i_j_values.get[k, Int]()
                # ](
                #     StaticTuple[3](c, c_2, c_3),
                #     a,
                #     StaticTuple[3](b, b_2, b_3),
                #     rows,
                #     cols,
                # )
                # let temp = benchmark_matmul_tiles[
                #     2, i_j_values.get[j, Int](), i_j_values.get[k, Int]()
                # ](StaticTuple[2](c, c_2), a, StaticTuple[2](b, b_2), rows, cols)
                let temp = benchmark_matmul_tiles[
                    1, i_j_values.get[j, Int](), i_j_values.get[k, Int]()
                ](c, a, b, rows, cols)
                speeds[j * 7 + k] = (
                    temp,
                    i_j_values.get[j, Int](),
                    i_j_values.get[k, Int](),
                )
                if temp < best_speed:
                    best_speed = temp
                    best_j_i = (i_j_values.get[j, Int](), i_j_values.get[k, Int]())

            unroll[7, test_values]()

        unroll[7, outer_test_values]()

        for i in range(49):
            print(
                "speed: ",
                speeds[i].get[0, Float64](),
                "j_i: ",
                speeds[i].get[1, Int](),
                speeds[i].get[2, Int](),
            )
        print(
            "Best speed: ",
            best_speed,
            "j_i: ",
            best_j_i.get[0, Int](),
            best_j_i.get[1, Int](),
        )
        _ = (a, b, b_2, b_3, c, c_2, c_3)
    unroll[10, test_tiles_matmul]()


fn main():
    let args = argv()
    let sizes = (
        (256, 257),
        (288, 288),
        (288, 526),
        (768, 288),
        (512, 512),
        (517, 517),
        (32000, 288),
        (288, 32000),
        (2000, 3210),
        (5300, 2000),
    )  # 10

    # test_matmul_tiles_speeds()
    if len(args) > 1:
        var batch_size = 1
        var version = 2
        var rows = 288
        var cols = 288
        var to_print = False
        try:
            for i in range(len(args)):
                if args[i] == "-s":
                    batch_size = atol(args[i + 1])
                if args[i] == "-v":
                    version = atol(args[i + 1])
                if args[i] == "-rc":
                    rows = atol(args[i + 1])
                    cols = atol(args[i + 2])
                if args[i] == "-p":
                    to_print = True
        except e:
            print("Invalid arguments", e)

        if to_print:
            print(
                "version:",
                version,
                ",batch_size: ",
                batch_size,
                ",rows: ",
                rows,
                ",cols: ",
                cols,
            )

        let a = BufferPtrFloat32.alloc(cols)
        let b = BufferPtrFloat32.alloc(cols * rows)
        let c = BufferPtrFloat32.alloc(rows)
        let b_2 = BufferPtrFloat32.alloc(cols * rows)
        let b_3 = BufferPtrFloat32.alloc(cols * rows)
        let c_2 = BufferPtrFloat32.alloc(rows)
        let c_3 = BufferPtrFloat32.alloc(rows)

        if batch_size == 1:
            test_cache_misses[1](c, a, b, rows, cols, version)
        elif batch_size == 2:
            test_cache_misses[2](
                StaticTuple[2](c, c_2), a, StaticTuple[2](b, b_2), rows, cols, version
            )
        elif batch_size >= 3:
            test_cache_misses[3](
                StaticTuple[3](c, c_2, c_3),
                a,
                StaticTuple[3](b, b_2, b_3),
                rows,
                cols,
                version,
            )
        _ = (a, b, b_2, b_3, c, c_2, c_3)
    else:

        @parameter
        fn test_matmul_sizes[i: Int]():
            let rows_cols = sizes.get[i, Tuple[Int, Int]]()
            test_matmul(
                rows_cols.get[0, Int](),
                rows_cols.get[1, Int](),
            )

        unroll[10, test_matmul_sizes]()

        print("Running benchmark...")

        @parameter
        fn benchmark_matmuls[i: Int]():
            let rows_cols = sizes.get[i, Tuple[Int, Int]]()
            let rows = rows_cols.get[0, Int]()
            let cols = rows_cols.get[1, Int]()

            let a = BufferPtrFloat32.alloc(cols)
            let b = BufferPtrFloat32.alloc(cols * rows)
            let c = BufferPtrFloat32.alloc(rows)
            let b_2 = BufferPtrFloat32.alloc(cols * rows)
            let b_3 = BufferPtrFloat32.alloc(cols * rows)
            let c_2 = BufferPtrFloat32.alloc(rows)
            let c_3 = BufferPtrFloat32.alloc(rows)

            print("\nBenchmarking Batch Matmul size", "rows:", rows, "cols:", cols)

            benchmark_matmul[1](c, a, b, rows, cols)
            benchmark_matmul[2](
                StaticTuple[2](c, c_2), a, StaticTuple[2](b, b_2), rows, cols
            )
            benchmark_matmul[3](
                StaticTuple[3](c, c_2, c_3),
                a,
                StaticTuple[3](b, b_2, b_3),
                rows,
                cols,
            )
            _ = (a, b, b_2, b_3, c, c_2, c_3)

        unroll[10, benchmark_matmuls]()
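
# Example invocations (usage notes added; the file name below is a placeholder):
#   mojo matmul.mojo                           # correctness tests + benchmarks over all sizes
#   mojo matmul.mojo -s 2 -v 1 -rc 288 288 -p  # cache-miss run: batch 2, tiling v1, 288x288, print config
# Flags parsed in main():
#   -s  batch size (1, 2, or >= 3)
#   -v  kernel version: 0 = no tiling, 1 = tiling v1, anything else = tiling v2
#   -rc rows cols
#   -p  print the parsed configuration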