From cfd241b7d5e4513f83aa9269a2d18bd58513c766 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Mon, 25 Mar 2024 16:13:58 -0400 Subject: [PATCH 1/9] Rewrite syrk algorithms --- src/level3/syrk.py | 1254 ++------------------------------------------ 1 file changed, 35 insertions(+), 1219 deletions(-) diff --git a/src/level3/syrk.py b/src/level3/syrk.py index 98c538b..63ee374 100644 --- a/src/level3/syrk.py +++ b/src/level3/syrk.py @@ -1,1236 +1,52 @@ from __future__ import annotations -from os import abort -import sys -import getopt -from exo import proc -from exo.platforms.x86 import * -from exo.platforms.neon import * from exo import * -from exo.syntax import * - from exo.stdlib.scheduling import * - -from kernels.gemm_kernels import GEBP_kernel, Microkernel -from format_options import * -from stdlib import * +from exo.API_cursors import * +from exo.libs.memories import DRAM_STATIC import exo_blas_config as C +from stdlib import * +from codegen_helpers import * +from blaslib import * -class SYRK: - """ - TODO: Add Beta and Alpha - """ - - def __init__( - self, - machine: "MachineParameters", - precision: str, - K_blk: int, - M_blk: int, - M_blk_small: int, - M_r: int, - N_r: int, - e_reg: int, - ): - - # Precision - self.precision = precision - self.prefix = "s" if precision == "f32" else "d" - # print(M_r, N_r) - - # Generate kernels - self.microkernel = Microkernel(machine, M_r, N_r, K_blk, self.precision) - self.gebp_kernel = GEBP_kernel(self.microkernel, M_blk, M_blk, self.precision) - - # Blocking dimensions - self.K_blk = K_blk - self.M_blk = M_blk - self.M_blk_small = M_blk_small - self.e_reg = e_reg - - # Machine - self.machine = machine - - ### SYRK procedures - @proc - def syrk_lower_notranspose_noalpha( - N: size, - K: size, - A1: f32[N, K] @ DRAM, - A2: f32[K, N] @ DRAM, - C: f32[N, N] @ DRAM, - ): - # C = A*A**T + C - assert N >= 1 - assert K >= 1 - assert stride(A1, 1) == 1 - assert stride(A2, 1) == 1 - assert stride(C, 1) == 1 - - for i in seq(0, N): - for j in seq(0, i + 1): - for k in seq(0, K): - C[i, j] += A1[i, k] * A2[k, j] - - syrk_lower_notranspose_noalpha = self.specialize_syrk( - syrk_lower_notranspose_noalpha, self.precision, ["A1", "A2", "C"] - ) - syrk_lower_notranspose_noalpha = rename( - syrk_lower_notranspose_noalpha, - f"{self.prefix}{syrk_lower_notranspose_noalpha.name()}", - ) - - @proc - def syrk_lower_transpose_noalpha( - N: size, - K: size, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - C: f32[N, N] @ DRAM, - ): - # C = A**T*A + C - assert N >= 1 - assert K >= 1 - assert stride(A1, 1) == 1 - assert stride(A2, 1) == 1 - assert stride(C, 1) == 1 - assert N == K - for i in seq(0, N): - for j in seq(0, i + 1): - for k in seq(0, K): - C[i, j] += A1[k, i] * A2[k, j] - - syrk_lower_transpose_noalpha = self.specialize_syrk( - syrk_lower_transpose_noalpha, self.precision, ["A1", "A2", "C"] - ) - syrk_lower_transpose_noalpha = rename( - syrk_lower_transpose_noalpha, - f"{self.prefix}{syrk_lower_transpose_noalpha.name()}", - ) - - @proc - def syrk_lower_notranspose_alpha( - N: size, - K: size, - A1: f32[N, K] @ DRAM, - alpha: f32[1], - A2: f32[N, K] @ DRAM, - C: f32[N, N] @ DRAM, - ): - - for i in seq(0, N): - for j in seq(0, i + 1): - temp: f32[1] - temp[0] = 0.0 - for k in seq(0, K): - temp[0] += A1[i, k] * A2[j, k] - C[i, j] += alpha[0] * temp[0] - - syrk_lower_notranspose_alpha = self.specialize_syrk( - syrk_lower_notranspose_alpha, - self.precision, - ["A1", "A2", "C", "alpha", "temp"], - ) - syrk_lower_notranspose_alpha = rename( - syrk_lower_notranspose_alpha, - f"{self.prefix}{syrk_lower_notranspose_alpha.name()}", - ) - - @proc - def syrk_lower_transpose_alpha( - N: size, - K: size, - A1: f32[K, N] @ DRAM, - alpha: f32[1], - A2: f32[K, N] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert N == K - temp: f32[K, N] - for j in seq(0, N): - for k in seq(0, K): - temp[j, k] = A1[j, k] * alpha[0] - for i in seq(0, N): - for j in seq(0, i + 1): - for k in seq(0, K): - C[i, j] += temp[k, i] * A2[k, j] - - syrk_lower_transpose_alpha = self.specialize_syrk( - syrk_lower_transpose_alpha, - self.precision, - ["A1", "A2", "C", "alpha", "temp"], - ) - syrk_lower_transpose_alpha = rename( - syrk_lower_transpose_alpha, - f"{self.prefix}{syrk_lower_transpose_alpha.name()}", - ) - - @proc - def syrk_upper_notranspose_noalpha( - N: size, - K: size, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - C: f32[N, N] @ DRAM, - ): - for j in seq(0, N): - for k in seq(0, K): - for i in seq(0, j + 1): - C[i, j] += A1[i, k] * A2[j, k] - - syrk_upper_notranspose_noalpha = self.specialize_syrk( - syrk_upper_notranspose_noalpha, self.precision, ["A1", "A2", "C"] - ) - syrk_upper_notranspose_noalpha = rename( - syrk_upper_notranspose_noalpha, - f"{self.prefix}{syrk_upper_notranspose_noalpha.name()}", - ) - - @proc - def syrk_upper_transpose_noalpha( - N: size, - K: size, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert K == N - for j in seq(0, N): - for k in seq(0, K): - for i in seq(0, j + 1): - C[i, j] += A1[k, i] * A2[k, j] - - syrk_upper_transpose_noalpha = self.specialize_syrk( - syrk_upper_transpose_noalpha, self.precision, ["A1", "A2", "C"] - ) - syrk_upper_transpose_noalpha = rename( - syrk_upper_transpose_noalpha, - f"{self.prefix}{syrk_upper_transpose_noalpha.name()}", - ) - - @proc - def syrk_upper_notranspose_alpha( - N: size, - K: size, - A1: f32[N, K] @ DRAM, - alpha: f32[1] @ DRAM, - A2: f32[N, K] @ DRAM, - C: f32[N, N] @ DRAM, - ): - for j in seq(0, N): - for k in seq(0, K): - for i in seq(0, j + 1): - C[i, j] += alpha[0] * A1[i, k] * A2[j, k] - - syrk_upper_notranspose_alpha = self.specialize_syrk( - syrk_upper_notranspose_alpha, self.precision, ["A1", "A2", "C", "alpha"] - ) - syrk_upper_notranspose_alpha = rename( - syrk_upper_notranspose_alpha, - f"{self.prefix}{syrk_upper_notranspose_alpha.name()}", - ) - - @proc - def syrk_upper_transpose_alpha( - N: size, - K: size, - A1: f32[K, N] @ DRAM, - alpha: f32[1] @ DRAM, - A2: f32[K, N] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert K == N - for j in seq(0, N): - for k in seq(0, K): - for i in seq(0, j + 1): - C[i, j] += A1[k, i] * A2[k, j] * alpha[0] - - syrk_upper_transpose_alpha = self.specialize_syrk( - syrk_upper_transpose_alpha, self.precision, ["A1", "A2", "C", "alpha"] - ) - syrk_upper_transpose_alpha = rename( - syrk_upper_transpose_alpha, - f"{self.prefix}{syrk_upper_transpose_alpha.name()}", - ) - - ### Diagonal handlers - @proc - def diag_handler_lower_notranspose( - N: size, - K: size, - A1: [f32][N, K] @ DRAM, - A2: [f32][K, N] @ DRAM, - C: [f32][N, N] @ DRAM, - ): - # C = A*A**T + C - assert N >= 1 - assert K >= 1 - assert stride(A1, 1) == 1 - assert stride(A2, 1) == 1 - assert stride(C, 1) == 1 - - for i in seq(0, N): - for j in seq(0, i): - for k in seq(0, K): - C[i, j] += A1[i, k] * A2[k, j] - - diag_handler_lower_notranspose = set_precision( - diag_handler_lower_notranspose, "A1", self.precision - ) - diag_handler_lower_notranspose = set_precision( - diag_handler_lower_notranspose, "A2", self.precision - ) - diag_handler_lower_notranspose = set_precision( - diag_handler_lower_notranspose, "C", self.precision - ) - diag_handler_lower_notranspose = rename( - diag_handler_lower_notranspose, - f"{self.prefix}_{diag_handler_lower_notranspose.name()}", - ) - - ### Scaling procedures - @proc - def syrk_apply_scalar_lower( - M: size, N: size, scalar: f32[1], P: f32[M, M] @ DRAM - ): - for i in seq(0, M): - for j in seq(0, i + 1): - P[i, j] = P[i, j] * scalar[0] - - syrk_apply_scalar_lower = set_precision( - syrk_apply_scalar_lower, "scalar", self.precision - ) - syrk_apply_scalar_lower = set_precision( - syrk_apply_scalar_lower, "P", self.precision - ) - - @proc - def syrk_apply_scalar_upper( - M: size, N: size, scalar: f32[1], P: f32[M, M] @ DRAM - ): - for i in seq(0, M): - for j in seq(0, M - i): - P[i, j] = P[i, j] * scalar[0] - - syrk_apply_scalar_upper = set_precision( - syrk_apply_scalar_upper, "scalar", self.precision - ) - syrk_apply_scalar_upper = set_precision( - syrk_apply_scalar_upper, "P", self.precision - ) - - @proc - def syrk_apply_scalar_lower_no_overwrite( - M: size, N: size, scalar: f32[1], P: f32[M, N] @ DRAM, Q: f32[M, N] @ DRAM - ): - for i in seq(0, M): - for j in seq(0, N): - Q[i, j] = P[i, j] * scalar[0] - - syrk_apply_scalar_lower_no_overwrite = set_precision( - syrk_apply_scalar_lower_no_overwrite, "scalar", self.precision - ) - syrk_apply_scalar_lower_no_overwrite = set_precision( - syrk_apply_scalar_lower_no_overwrite, "P", self.precision - ) - syrk_apply_scalar_lower_no_overwrite = set_precision( - syrk_apply_scalar_lower_no_overwrite, "Q", self.precision - ) - - ### Alpha and Beta scaling procedures - # TODO: fix a lot here - apply_alpha = self.schedule_apply_scalar( - syrk_apply_scalar_lower_no_overwrite, - machine, - ["Q", "P"], - f"{self.prefix}_apply_alpha_{M_blk}_{K_blk}", - False, - ) - apply_beta_lower = self.schedule_apply_scalar( - syrk_apply_scalar_lower, - machine, - ["P"], - f"{self.prefix}_apply_beta_{M_blk}_{K_blk}", - True, - ) - apply_beta_upper = rename( - syrk_apply_scalar_upper, self.prefix + syrk_apply_scalar_upper.name() - ) - - ### Create scheduled procedures - ( - self.gepp_syrk_scheduled_lower_notranspose, - self.gepp_syrk_base_lower_notranspose, - ) = self.generate_syrk_gepp_lower_notranspose_noalpha( - syrk_lower_notranspose_noalpha, diag_handler_lower_notranspose - ) - syrk_scheduled_lower_notranspose_noalpha = ( - self.schedule_syrk_lower_notranspose_noalpha(syrk_lower_notranspose_noalpha) - ) - - ### Entry points - @proc - def exo_syrk_lower_notranspose_noalpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - syrk_lower_notranspose_noalpha(N, K, A1, A2, C) - - exo_syrk_lower_notranspose_noalpha_nobeta = call_eqv( - exo_syrk_lower_notranspose_noalpha_nobeta, - f"{syrk_lower_notranspose_noalpha.name()}", - syrk_scheduled_lower_notranspose_noalpha, - ) - - exo_syrk_lower_notranspose_noalpha_nobeta = self.specialize_syrk( - exo_syrk_lower_notranspose_noalpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_lower_notranspose_alpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - syrk_lower_notranspose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_lower_notranspose_alpha_nobeta = self.specialize_syrk( - exo_syrk_lower_notranspose_alpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_lower_notranspose_alpha_beta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - apply_beta_lower(N, N, beta, C) - syrk_lower_notranspose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_lower_notranspose_alpha_beta = self.specialize_syrk( - exo_syrk_lower_notranspose_alpha_beta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_lower_transpose_noalpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert K == N - syrk_lower_transpose_noalpha(N, K, A1, A2, C) - - exo_syrk_lower_transpose_noalpha_nobeta = self.specialize_syrk( - exo_syrk_lower_transpose_noalpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_lower_transpose_alpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert K == N - syrk_lower_transpose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_lower_transpose_alpha_nobeta = self.specialize_syrk( - exo_syrk_lower_transpose_alpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_lower_transpose_alpha_beta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert K == N - apply_beta_lower(N, N, beta, C) - syrk_lower_transpose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_lower_transpose_alpha_beta = self.specialize_syrk( - exo_syrk_lower_transpose_alpha_beta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_lower_alphazero_beta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - apply_beta_lower(N, N, beta, C) - - exo_syrk_lower_alphazero_beta = self.specialize_syrk( - exo_syrk_lower_alphazero_beta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_notranspose_noalpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - syrk_upper_notranspose_noalpha(N, K, A1, A2, C) - - exo_syrk_upper_notranspose_noalpha_nobeta = self.specialize_syrk( - exo_syrk_upper_notranspose_noalpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_notranspose_alpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - syrk_upper_notranspose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_upper_notranspose_alpha_nobeta = self.specialize_syrk( - exo_syrk_upper_notranspose_alpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_notranspose_alpha_beta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - apply_beta_upper(N, N, beta, C) - syrk_upper_notranspose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_upper_notranspose_alpha_beta = self.specialize_syrk( - exo_syrk_upper_notranspose_alpha_beta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_transpose_noalpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert N == K - syrk_upper_transpose_noalpha(N, K, A1, A2, C) - - exo_syrk_upper_transpose_noalpha_nobeta = self.specialize_syrk( - exo_syrk_upper_transpose_noalpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_transpose_alpha_nobeta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert N == K - syrk_upper_transpose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_upper_transpose_alpha_nobeta = self.specialize_syrk( - exo_syrk_upper_transpose_alpha_nobeta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_transpose_alpha_beta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[K, N] @ DRAM, - A2: f32[K, N] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - assert N == K - apply_beta_upper(N, N, beta, C) - syrk_upper_transpose_alpha(N, K, A1, alpha, A2, C) - - exo_syrk_upper_transpose_alpha_beta = self.specialize_syrk( - exo_syrk_upper_transpose_alpha_beta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - @proc - def exo_syrk_upper_alphazero_beta( - N: size, - K: size, - alpha: f32[1] @ DRAM, - A1: f32[N, K] @ DRAM, - A2: f32[N, K] @ DRAM, - beta: f32[1] @ DRAM, - C: f32[N, N] @ DRAM, - ): - apply_beta_upper(N, N, beta, C) - - exo_syrk_upper_alphazero_beta = self.specialize_syrk( - exo_syrk_upper_alphazero_beta, - self.precision, - ["A1", "A2", "C", "alpha", "beta"], - ) - - self.entry_points = [ - exo_syrk_lower_notranspose_noalpha_nobeta, - exo_syrk_lower_notranspose_alpha_nobeta, - exo_syrk_lower_notranspose_alpha_beta, - exo_syrk_lower_transpose_noalpha_nobeta, - exo_syrk_lower_transpose_alpha_nobeta, - exo_syrk_lower_transpose_alpha_beta, - exo_syrk_upper_notranspose_noalpha_nobeta, - exo_syrk_upper_notranspose_alpha_nobeta, - exo_syrk_upper_notranspose_alpha_beta, - exo_syrk_upper_transpose_noalpha_nobeta, - exo_syrk_upper_transpose_alpha_nobeta, - exo_syrk_upper_transpose_alpha_beta, - exo_syrk_lower_alphazero_beta, - exo_syrk_upper_alphazero_beta, - ] - - def generate_syrk_gepp_base(self, syrk_win: Procedure): - gepp_syrk_base = rename(syrk_win, "gepp_syrk_base") - gepp_syrk_base = gepp_syrk_base.partial_eval(K=self.microkernel.K_blk) - return simplify(gepp_syrk_base) - - def generate_syrk_gepp_lower_notranspose_noalpha( - self, syrk: Procedure, diag_handler: Procedure - ): - - # assert(self.M_blk >= 128) # Temporary - - syrk = rename(syrk, "syrk_win") - syrk = set_window(syrk, "A1", True) - syrk = set_window(syrk, "A2", True) - syrk = set_window(syrk, "C", True) - - gepp_syrk_base = self.generate_syrk_gepp_base(syrk) - - gepp_syrk_scheduled = rename( - gepp_syrk_base, f"gepp_{self.prefix}syrk_scheduled" - ) - gepp_syrk_scheduled = divide_loop( - gepp_syrk_scheduled, "i", self.M_blk, ["io", "ii"], tail="cut" - ) - gepp_syrk_scheduled = cut_loop(gepp_syrk_scheduled, "for j in _:_", 1) - gepp_syrk_scheduled = shift_loop(gepp_syrk_scheduled, "for j in _:_ #1", 0) - gepp_syrk_scheduled = divide_loop( - gepp_syrk_scheduled, "j #1", self.M_blk, ["jo", "ji"], tail="cut" - ) - - gepp_syrk_scheduled = reorder_stmts( - gepp_syrk_scheduled, - gepp_syrk_scheduled.find("for j in _:_ #0").expand(0, 1), - ) - gepp_syrk_scheduled = autofission( - gepp_syrk_scheduled, - gepp_syrk_scheduled.find("for j in _:_").after(), - n_lifts=1, - ) - gepp_syrk_scheduled = autofission( - gepp_syrk_scheduled, - gepp_syrk_scheduled.find("for j in _:_").before(), - n_lifts=1, - ) - gepp_syrk_scheduled = simplify(gepp_syrk_scheduled) - - gepp_syrk_scheduled = reorder_loops(gepp_syrk_scheduled, "ii jo") - gepp_syrk_scheduled = replace( - gepp_syrk_scheduled, "for ii in _:_ #0", self.gebp_kernel.base_gebp - ) - gepp_syrk_scheduled = call_eqv( - gepp_syrk_scheduled, - f"gebp_base_{self.gebp_kernel.this_id}(_)", - self.gebp_kernel.scheduled_gebp, - ) - gepp_syrk_scheduled = simplify(gepp_syrk_scheduled) - gepp_syrk_scheduled = autofission( - gepp_syrk_scheduled, - gepp_syrk_scheduled.find("for ii in _:_ #0").before(), - n_lifts=1, - ) - - diag_syrk_base = rename(diag_handler, f"diag_handler") - diag_syrk_base = diag_syrk_base.partial_eval(K=self.K_blk, N=self.M_blk) - gepp_syrk_scheduled = replace( - gepp_syrk_scheduled, "for ii in _:_ #1", diag_syrk_base - ) - - gebp_diag_handler = GEBP_kernel( - self.microkernel, self.M_blk_small, self.M_blk_small, self.precision - ) - diag_syrk_scheduled = rename( - diag_syrk_base, f"{self.prefix}_diag_handler_scheduled" - ) - diag_syrk_scheduled = divide_loop( - diag_syrk_scheduled, "i", gebp_diag_handler.M_blk, ["io", "ii"], tail="cut" - ) - diag_syrk_scheduled = divide_loop( - diag_syrk_scheduled, "j", gebp_diag_handler.M_blk, ["jo", "ji"], tail="cut" - ) - diag_syrk_scheduled = autofission( - diag_syrk_scheduled, - diag_syrk_scheduled.find("for ji in _:_ #1").before(), - n_lifts=1, - ) - diag_syrk_scheduled = simplify(diag_syrk_scheduled) - diag_syrk_scheduled = reorder_loops(diag_syrk_scheduled, "ii jo") - diag_syrk_scheduled = replace( - diag_syrk_scheduled, "for ii in _:_ #0", gebp_diag_handler.base_gebp - ) - diag_syrk_scheduled = call_eqv( - diag_syrk_scheduled, - f"gebp_base_{gebp_diag_handler.this_id}(_)", - gebp_diag_handler.scheduled_gebp, - ) - - microkernel_diag_handler = Microkernel( - self.machine, - self.microkernel.M_r, - self.microkernel.N_r, - self.K_blk, - self.precision, - ) - diag_syrk_scheduled = divide_loop( - diag_syrk_scheduled, - "for ii in _:_", - microkernel_diag_handler.M_r, - ["iio", "iii"], - tail="cut", - ) - diag_syrk_scheduled = divide_loop( - diag_syrk_scheduled, - "for ji in _:_", - microkernel_diag_handler.N_r, - ["jio", "jii"], - tail="cut", - ) - diag_syrk_scheduled = autofission( - diag_syrk_scheduled, - diag_syrk_scheduled.find("for jii in _:_ #1").before(), - n_lifts=1, - ) - diag_syrk_scheduled = simplify(diag_syrk_scheduled) - diag_syrk_scheduled = reorder_loops(diag_syrk_scheduled, "iii jio") - diag_syrk_scheduled = replace( - diag_syrk_scheduled, - "for iii in _:_ #0", - microkernel_diag_handler.base_microkernel, - ) - - diag_syrk_scheduled = call_eqv( - diag_syrk_scheduled, - f"microkernel_{microkernel_diag_handler.this_id}(_)", - microkernel_diag_handler.scheduled_microkernel, - ) - # print(simplify(diag_syrk_scheduled)) - - # Unsafe microkernel - """ - for iii in seq(0, 4): - for jii in seq(0, (iii + 4 * iio) % 8): - for k in seq(0, 256): - C[iii + 4 * iio + 64 * io, jii + iio / 2 * 8 + 64 * - io] += A1[iii + 4 * iio + 64 * io, - k] * A2[k, jii + iio / 2 * 8 + 64 * io] - """ - - if self.precision == "f32": ##UNDER CONSTRUCTION - diag_syrk_scheduled = autofission( - diag_syrk_scheduled, - diag_syrk_scheduled.find("for iii in _:_").before(), - n_lifts=2, - ) - # diag_syrk_scheduled = divide_loop(diag_syrk_scheduled, 'for jii in _:_', microkernel_diag_handler.N_r, ['jiio', 'jiii'], tail='cut') - # diag_syrk_scheduled = autofission(diag_syrk_scheduled, diag_syrk_scheduled.find('for jiio in _:_').after(), n_lifts=2) - diag_syrk_scheduled = simplify(diag_syrk_scheduled) - # print(diag_syrk_scheduled) - - diag_syrk_scheduled, unsafe_microkernel_base = extract_subproc( - diag_syrk_scheduled, - diag_syrk_scheduled.find("for io in _:_ #1"), - "unsafe_microkernel_base", - ) - microkernel_diag_base = microkernel_diag_handler.base_microkernel - microkernel_diag_scheduled = microkernel_diag_handler.scheduled_microkernel - # print(microkernel_diag_scheduled) - - @proc - def unsafe_microkernel_scheduled( - A: [f32][128, 256], - B: [f32][256, 128], - C: [f32][128, 128], - ): - assert stride(C, 1) == 1 - assert stride(B, 1) == 1 - assert stride(A, 1) == 1 - # assert stride(C, 0) == 32 - # assert stride(B, 0) == 32 - # assert stride(A, 0) == 32 - # C[0, 0] = 0.0 - # A_vec: f32[4, 8] @ AVX2 - # B_vec: f32[2, 8] @ AVX2 - C_reg: f32[128, 128] @ DRAM - for i in seq(0, 128): - for j in seq(0, 128): - C_reg[i, j] = 0.0 - for i in seq(0, 128): - for j in seq(0, 128): - for k in seq(0, 256): - C_reg[i, j] += A[i, k] * B[k, j] - for i in seq(0, 128): - for j in seq(0, i % 16): - C[i, j + ((i / 16) * 16)] += C_reg[i, j + ((i / 16) * 16)] - - gebp_unsafe = GEBP_kernel( - self.microkernel, self.M_blk, self.M_blk, self.precision - ) - # gebp_unsafe.scheduled_gebp = inline( - # gebp_unsafe.scheduled_gebp, - # f"avx2_microkernel_4x16_{self.microkernel.this_id}(_)", - # ) - - # print(gebp_unsafe.scheduled_gebp) - - unsafe_microkernel_scheduled = replace( - unsafe_microkernel_scheduled, "for i in _:_ #1", gebp_unsafe.base_gebp - ) - unsafe_microkernel_scheduled = call_eqv( - unsafe_microkernel_scheduled, - f"gebp_base_{gebp_unsafe.this_id}(_)", - gebp_unsafe.scheduled_gebp, - ) - - # unsafe_microkernel_scheduled = divide_loop(unsafe_microkernel_scheduled, 'j #0', self.machine.f32_vec_width, ['jo', 'ji'], perfect=True) - - # unsafe_microkernel_scheduled = replace( - # unsafe_microkernel_scheduled, - # "for ji in _:_ #0", - # self.machine.set_zero_instr_f32, - # ) - - # unsafe_microkernel_scheduled = divide_loop(unsafe_microkernel_scheduled, 'j #0', self.machine.f32_vec_width, ['jo', 'ji'], tail='cut') - # unsafe_microkernel_scheduled = replace( - # unsafe_microkernel_scheduled, - # "for k in _:_ #1", - # avx2_mask_storeu_ps, - # ) - # unsafe_microkernel_scheduled = reorder_loops(unsafe_microkernel_scheduled, "io iio") - - # unsafe_microkernel_scheduled = unsafe_microkernel_scheduled.partial_eval(M=microkernel_diag_handler.M_r, N=) - - unsafe_microkernel_scheduled = set_precision( - unsafe_microkernel_scheduled, "A", self.precision - ) - unsafe_microkernel_scheduled = set_precision( - unsafe_microkernel_scheduled, "B", self.precision - ) - unsafe_microkernel_scheduled = set_precision( - unsafe_microkernel_scheduled, "C", self.precision - ) - unsafe_microkernel_scheduled = rename( - unsafe_microkernel_scheduled, - self.prefix + "_" + unsafe_microkernel_scheduled.name(), - ) - - unsafe_microkernel_base.unsafe_assert_eq(unsafe_microkernel_scheduled) - unsafe_microkernel_scheduled = simplify(unsafe_microkernel_scheduled) - - diag_syrk_scheduled = call_eqv( - diag_syrk_scheduled, - "unsafe_microkernel_base", - unsafe_microkernel_scheduled, - ) - # diag_syrk_scheduled = inline(diag_syrk_scheduled, "s_unsafe_microkernel_scheduled(_)") - - # diag_syrk_scheduled = diag_syrk_scheduled.add_assertion("stride(A1, 0)==32") - # diag_syrk_scheduled = diag_syrk_scheduled.add_assertion("stride(A2, 0)==32") - # diag_syrk_scheduled = diag_syrk_scheduled.add_assertion("stride(C, 0)==32") - # diag_syrk_scheduled.unsafe_assert_eq(diag_syrk_base) - - diag_syrk_scheduled = simplify(diag_syrk_scheduled) - gepp_syrk_scheduled = call_eqv( - gepp_syrk_scheduled, "diag_handler(_)", diag_syrk_scheduled - ) - - # gepp_syrk_scheduled = gepp_syrk_scheduled.add_assertion( - # f"stride(A1, 0) == {self.M_blk}" - # ) - # gepp_syrk_scheduled = gepp_syrk_scheduled.add_assertion(f"stride(A2, 0) == N") - # gepp_syrk_scheduled = gepp_syrk_scheduled.add_assertion(f"stride(C, 0) == N") - # gepp_syrk_base.unsafe_assert_eq(gepp_syrk_scheduled) - - # print(gepp_syrk_scheduled) - - ### Vectorize K loop - if self.precision == "f32" and False: - # k_gebp = rename(self.microkernel.sgemm_window, "gebp_k_dim") - # k_gebp = k_gebp.partial_eval(M=self.M_blk, N=1, K=self.K_blk) - # k_gebp = reorder_loops(k_gebp, 'i j') - k_microkernel_dim = self.e_reg - - gepp_syrk_scheduled = autofission( - gepp_syrk_scheduled, - gepp_syrk_scheduled.find("for ii in _:_ #0").after(), - n_lifts=1, - ) - gepp_syrk_scheduled = reorder_loops(gepp_syrk_scheduled, "ii j") - gepp_syrk_scheduled = divide_loop( - gepp_syrk_scheduled, - "ii", - k_microkernel_dim, - ["iio", "iii"], - perfect=True, - ) - gepp_syrk_scheduled = reorder_loops(gepp_syrk_scheduled, "j iio") - # print(gepp_syrk_scheduled) - - k_microkernel = rename(self.microkernel.sgemm_window, "k_microkernel") - k_microkernel = k_microkernel.partial_eval( - M=k_microkernel_dim, N=1, K=self.K_blk - ) - k_microkernel = reorder_loops(k_microkernel, "i j") - gepp_syrk_scheduled = replace( - gepp_syrk_scheduled, "for j in _:_ #0", k_microkernel - ) - - k_microkernel_scheduled = rename(k_microkernel, "k_microkernel_scheduled") - k_microkernel_scheduled = divide_loop( - k_microkernel_scheduled, - "i", - self.machine.f32_vec_width, - ["io", "ii"], - perfect=True, - ) - # print(k_microkernel_scheduled) - - c_reg_str = f"C[{self.machine.f32_vec_width}*io+ii, j]" - k_microkernel_scheduled = stage_mem( - k_microkernel_scheduled, "C[_] += _", c_reg_str, "C_reg" - ) - k_microkernel_scheduled = set_memory( - k_microkernel_scheduled, "C_reg", self.machine.mem_type - ) - k_microkernel_scheduled = expand_dim( - k_microkernel_scheduled, "C_reg", self.machine.f32_vec_width, "ii" - ) - k_microkernel_scheduled = lift_alloc( - k_microkernel_scheduled, "C_reg", n_lifts=4 - ) - k_microkernel_scheduled = autofission( - k_microkernel_scheduled, - k_microkernel_scheduled.find("C_reg[_] = _").after(), - n_lifts=4, - ) - k_microkernel_scheduled = autofission( - k_microkernel_scheduled, - k_microkernel_scheduled.find("C[_] = _").before(), - n_lifts=4, - ) - - k_microkernel_scheduled = reorder_loops(k_microkernel_scheduled, "ii k") - k_microkernel_scheduled = reorder_loops(k_microkernel_scheduled, "io k") - k_microkernel_scheduled = reorder_loops(k_microkernel_scheduled, "j k") - # print(k_microkernel_scheduled) - - # Setup A buffer in vector mem - k_microkernel_scheduled = bind_expr( - k_microkernel_scheduled, "A[_]", "A_vec" - ) - k_microkernel_scheduled = set_memory( - k_microkernel_scheduled, "A_vec", self.machine.mem_type - ) - k_microkernel_scheduled = expand_dim( - k_microkernel_scheduled, "A_vec", self.machine.f32_vec_width, "ii" - ) - k_microkernel_scheduled = expand_dim( - k_microkernel_scheduled, "A_vec", self.K_blk, "k" - ) - k_microkernel_scheduled = set_precision( - k_microkernel_scheduled, "A_vec", self.precision - ) - # print(k_microkernel_scheduled) - - # Setup B buffer in vector mem - k_microkernel_scheduled = bind_expr( - k_microkernel_scheduled, "B[_]", "B_vec" - ) - - k_microkernel_scheduled = set_memory( - k_microkernel_scheduled, "B_vec", self.machine.mem_type - ) - - k_microkernel_scheduled = expand_dim( - k_microkernel_scheduled, "B_vec", self.machine.f32_vec_width, f"ii" - ) - k_microkernel_scheduled = expand_dim( - k_microkernel_scheduled, "B_vec", 1, f"j" - ) - k_microkernel_scheduled = set_precision( - k_microkernel_scheduled, "B_vec", self.precision - ) - # print(k_microkernel_scheduled) - - # Move A_vec and B_vec into proper sites - k_microkernel_scheduled = lift_alloc( - k_microkernel_scheduled, "A_vec", n_lifts=4 - ) - k_microkernel_scheduled = autofission( - k_microkernel_scheduled, - k_microkernel_scheduled.find("A_vec[_] = _").after(), - n_lifts=4, - ) - k_microkernel_scheduled = lift_alloc( - k_microkernel_scheduled, "B_vec", n_lifts=4 - ) - k_microkernel_scheduled = autofission( - k_microkernel_scheduled, - k_microkernel_scheduled.find("B_vec[_] = _").after(), - n_lifts=4, - ) - # print(k_microkernel_scheduled) - - k_microkernel_scheduled = replace_all_stmts( - k_microkernel_scheduled, self.machine.load_instr_f32 - ) - k_microkernel_scheduled = replace_all_stmts( - k_microkernel_scheduled, self.machine.broadcast_instr_f32 - ) - k_microkernel_scheduled = replace_all_stmts( - k_microkernel_scheduled, self.machine.store_instr_f32 - ) - k_microkernel_scheduled = replace_all_stmts( - k_microkernel_scheduled, self.machine.fmadd_reduce_instr_f32 - ) - k_microkernel_scheduled = simplify(k_microkernel_scheduled) - - gepp_syrk_scheduled = call_eqv( - gepp_syrk_scheduled, "k_microkernel(_)", k_microkernel_scheduled - ) - - return simplify(gepp_syrk_scheduled), simplify(gepp_syrk_base) - - def schedule_syrk_lower_notranspose_noalpha(self, ssyrk_base: Procedure): - syrk = divide_loop( - ssyrk_base, "k", self.K_blk, ["ko", "ki"], tail="cut_and_guard" - ) - syrk = autofission(syrk, syrk.find("for ko in _:_ #0").after(), n_lifts=2) - syrk = reorder_loops(syrk, "j ko") - syrk = reorder_loops(syrk, "i ko") - syrk = replace(syrk, "for i in _:_ #0", self.gepp_syrk_base_lower_notranspose) - syrk = call_eqv( - syrk, "gepp_syrk_base(_)", self.gepp_syrk_scheduled_lower_notranspose - ) - # print(syrk) - return simplify(syrk) - - def bind(self, proc, buffer, reg, machine): - proc = bind_expr(proc, buffer, reg) - proc = expand_dim(proc, reg, machine.f32_vec_width, "ji") - proc = lift_alloc(proc, f"{reg} : _", n_lifts=2) - proc = fission(proc, proc.find(f"{reg} = _").after()) - return simplify(proc) - - def stage(self, proc, buffer, reg, machine): - proc = stage_mem( - proc, - f"{buffer}[_] = _", - f"{buffer}[i, ji + {machine.f32_vec_width}*jo]", - reg, - ) - proc = expand_dim(proc, reg, machine.f32_vec_width, f"ji") - proc = lift_alloc(proc, f"{reg} : _", n_lifts=2) - proc = fission(proc, proc.find(f"{reg}[_] = _").after()) - return simplify(proc) - - def schedule_apply_scalar( - self, - proc: Procedure, - machine: "MachineParameters", - buffer_names: list, - name: str, - apply_hack: bool, - ): - - proc = rename(proc, name) - for buffer in buffer_names + ["scalar"]: - proc = set_precision(proc, buffer, self.precision) - - proc = divide_loop( - proc, "j", machine.f32_vec_width, ["jo", "ji"], tail="cut_and_guard" - ) - proc = self.bind(proc, "scalar[_]", "scalar_vec", machine) - if len(buffer_names) > 1: - proc = self.bind( - proc, f"{buffer_names[1]}[_]", f"{buffer_names[1]}_vec", machine - ) - proc = set_precision(proc, f"{buffer_names[1]}_vec", self.precision) - proc = self.stage(proc, f"{buffer_names[0]}", f"{buffer_names[0]}_vec", machine) - - if apply_hack: - proc = fission(proc, proc.find(f"{buffer_names[0]}_vec[_] = _ #1").after()) - - for buffer_name in buffer_names: - proc = set_memory(proc, f"{buffer_name}_vec", machine.mem_type) - proc = set_memory(proc, "scalar_vec", machine.mem_type) - proc = set_precision(proc, "scalar_vec", self.precision) - - if self.precision == "f32": - instr_lst = [ - machine.load_instr_f32, - machine.broadcast_instr_f32, - machine.reg_copy_instr_f32, - machine.mul_instr_f32, - machine.store_instr_f32, - ] - else: - instr_lst = [ - machine.load_instr_f64, - machine.broadcast_instr_f64, - machine.reg_copy_instr_f64, - machine.mul_instr_f64, - machine.store_instr_f64, - ] - - if apply_hack: - proc = self.bind(proc, "P_vec[_]", "P_vec2", machine) - proc = set_memory(proc, "P_vec2", machine.mem_type) - proc = set_precision(proc, "P_vec2", self.precision) - - for instr in instr_lst: - proc = replace_all_stmts(proc, instr) - # if self.main: - # proc = rename(proc, proc.name() + "_main") - - return simplify(proc) - - def specialize_syrk(self, syrk: Procedure, precision: str, args=list[str]): - prefix = "s" if precision == "f32" else "d" - name = syrk.name().replace("exo_", "") - syrk = rename(syrk, "exo_" + prefix + name) - for arg in args: - syrk = set_precision(syrk, arg, precision) - return simplify(syrk) - +@proc +def syrk_rm_l(N: size, K: size, alpha: R, A: [R][N, K], A_alias: [R][N, K], C: [R][N, N]): + assert stride(A, 1) == 1 + assert stride(A_alias, 1) == 1 + assert stride(C, 1) == 1 -k_blk = 256 -m_blk = 128 -m_blk_small = 32 -m_reg = 4 -n_reg = 16 -e_reg = 16 + for i in seq(0, N): + for j in seq(0, i + 1): + for k in seq(0, K): + C[i, j] += alpha * (A[i, k] * A_alias[j, k]) -ssyrk = SYRK(C.Machine, "f32", k_blk, m_blk, m_blk_small, m_reg, n_reg, e_reg) +@proc +def syrk_rm_u(N: size, K: size, alpha: R, A: [R][N, K], A_alias: [R][N, K], C: [R][N, N]): + assert stride(A, 1) == 1 + assert stride(A_alias, 1) == 1 + assert stride(C, 1) == 1 -for i in range(13): - ssyrk.entry_points[i] = simplify(ssyrk.entry_points[i]) + for i in seq(0, N): + for j in seq(i, N): + for k in seq(0, K): + C[i, j] += alpha * (A[i, k] * A_alias[j, k]) -exo_ssyrk_lower_notranspose_noalpha_nobeta = ssyrk.entry_points[0] -exo_ssyrk_lower_notranspose_alpha_nobeta = ssyrk.entry_points[1] -exo_ssyrk_lower_notranspose_alpha_beta = ssyrk.entry_points[2] -exo_ssyrk_lower_transpose_noalpha_nobeta = ssyrk.entry_points[3] -exo_ssyrk_lower_transpose_alpha_nobeta = ssyrk.entry_points[4] -exo_ssyrk_lower_transpose_alpha_beta = ssyrk.entry_points[5] -exo_ssyrk_upper_notranspose_noalpha_nobeta = ssyrk.entry_points[6] -exo_ssyrk_upper_notranspose_alpha_nobeta = ssyrk.entry_points[7] -exo_ssyrk_upper_notranspose_alpha_beta = ssyrk.entry_points[8] -exo_ssyrk_upper_transpose_noalpha_nobeta = ssyrk.entry_points[9] -exo_ssyrk_upper_transpose_alpha_nobeta = ssyrk.entry_points[10] -exo_ssyrk_upper_transpose_alpha_beta = ssyrk.entry_points[11] -exo_ssyrk_lower_alphazero_beta = ssyrk.entry_points[12] -exo_ssyrk_upper_alphazero_beta = ssyrk.entry_points[13] -C.Machine.f32_vec_width //= 2 -dsyrk = SYRK(C.Machine, "f64", k_blk, m_blk, m_blk_small, m_reg, n_reg // 2, e_reg) -C.Machine.f32_vec_width *= 2 +syrk_rm_u = shift_loop(syrk_rm_u, "j", 0) -for i in range(13): - dsyrk.entry_points[i] = simplify(dsyrk.entry_points[i]) +PARAMS = {AVX2: (4, 3, 66, 3, 512), AVX512: (6, 4, 44, 1, 512), Neon: (1, 1, 1, 1, 1)} -exo_dsyrk_lower_notranspose_noalpha_nobeta = dsyrk.entry_points[0] -exo_dsyrk_lower_notranspose_alpha_nobeta = dsyrk.entry_points[1] -exo_dsyrk_lower_notranspose_alpha_beta = dsyrk.entry_points[2] -exo_dsyrk_lower_transpose_noalpha_nobeta = dsyrk.entry_points[3] -exo_dsyrk_lower_transpose_alpha_nobeta = dsyrk.entry_points[4] -exo_dsyrk_lower_transpose_alpha_beta = dsyrk.entry_points[5] -exo_dsyrk_upper_notranspose_noalpha_nobeta = dsyrk.entry_points[6] -exo_dsyrk_upper_notranspose_alpha_nobeta = dsyrk.entry_points[7] -exo_dsyrk_upper_notranspose_alpha_beta = dsyrk.entry_points[8] -exo_dsyrk_upper_transpose_noalpha_nobeta = dsyrk.entry_points[9] -exo_dsyrk_upper_transpose_alpha_nobeta = dsyrk.entry_points[10] -exo_dsyrk_upper_transpose_alpha_beta = dsyrk.entry_points[11] -exo_dsyrk_lower_alphazero_beta = dsyrk.entry_points[12] -exo_dsyrk_upper_alphazero_beta = dsyrk.entry_points[13] +m_r, n_r_fac, M_tile_fac, N_tile_fac, K_tile = PARAMS[C.Machine.mem_type] +n_r = n_r_fac * C.Machine.vec_width("f32") +M_tile = M_tile_fac * m_r +N_tile = N_tile_fac * n_r -__all__ = [p.name() for p in ssyrk.entry_points] + [ - p.name() for p in dsyrk.entry_points -] +variants_generator(identity_schedule, ("f32",), (AVX2, AVX512))( + syrk_rm_l, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() +) +variants_generator(identity_schedule, ("f32",), (AVX2, AVX512))( + syrk_rm_u, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() +) From 501c2d8112b50fc9db9db22180da91e45674d542 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Mon, 25 Mar 2024 16:14:21 -0400 Subject: [PATCH 2/9] Rewrite syrk correctness tests and wrapper --- test/level3/CMakeLists.txt | 3 +- test/level3/dsyrk/bench.cpp | 77 ---------------------- test/level3/dsyrk/correctness.cpp | 83 ------------------------ test/level3/dsyrk/exo_dsyrk.h | 99 ----------------------------- test/level3/ssyrk/bench.cpp | 85 ------------------------- test/level3/ssyrk/correctness.cpp | 83 ------------------------ test/level3/ssyrk/exo_ssyrk.h | 99 ----------------------------- test/level3/syrk/bench.cpp | 87 +++++++++++++++++++++++++ test/level3/syrk/correctness.cpp | 63 ++++++++++++++++++ test/level3/syrk/exo_syrk_wrapper.h | 39 ++++++++++++ 10 files changed, 190 insertions(+), 528 deletions(-) delete mode 100644 test/level3/dsyrk/bench.cpp delete mode 100644 test/level3/dsyrk/correctness.cpp delete mode 100644 test/level3/dsyrk/exo_dsyrk.h delete mode 100644 test/level3/ssyrk/bench.cpp delete mode 100644 test/level3/ssyrk/correctness.cpp delete mode 100644 test/level3/ssyrk/exo_ssyrk.h create mode 100644 test/level3/syrk/bench.cpp create mode 100644 test/level3/syrk/correctness.cpp create mode 100644 test/level3/syrk/exo_syrk_wrapper.h diff --git a/test/level3/CMakeLists.txt b/test/level3/CMakeLists.txt index 5551522..726343c 100644 --- a/test/level3/CMakeLists.txt +++ b/test/level3/CMakeLists.txt @@ -1,6 +1,5 @@ add_exo_blas_test(level3 gemm "") -add_exo_blas_test(level3 syrk s) -add_exo_blas_test(level3 syrk d) +add_exo_blas_test(level3 syrk "") # add_exo_blas_test(syr2k s) add_exo_blas_test(level3 trmm s) add_exo_blas_test(level3 symm s) diff --git a/test/level3/dsyrk/bench.cpp b/test/level3/dsyrk/bench.cpp deleted file mode 100644 index fd9362d..0000000 --- a/test/level3/dsyrk/bench.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "exo_dsyrk.h" -#include "generate_buffer.h" - -static void print_matrix(std::vector M, int n, int k) { - for (int i = 0; i < k; i++) { - for (int j = 0; j < n; j++) { - std::cout << M[j * k + i] << ", "; - } - std::cout << std::endl; - } - std::cout << std::endl; -} - -static std::vector transpose(std::vector V, const int m, - const int k) { - std::vector V_t(k * m); - for (int i = 0; i < m; i++) { - for (int j = 0; j < k; j++) { - V_t[j * m + i] = V[i * k + j]; - } - } - - return V_t; -} - -static void BM_DSYRK_CBLAS(benchmark::State &state) { - int n = state.range(0); - auto a = AlignedBuffer2D(n, n, 2.0, 64); - auto c = AlignedBuffer2D(n, n, 2.0, 64); - - for (auto _ : state) { - cblas_dsyrk(CblasRowMajor, CblasLower, CblasNoTrans, n, n, // M N - 1.0, // alpha - a.data(), - n, // M - 1.0, c.data(), - n // M - ); - } - - state.counters["flops"] = benchmark::Counter( - static_cast(state.iterations()) * n * n * n * 2, - benchmark::Counter::kIsRate, benchmark::Counter::kIs1000); -} - -static void BM_DSYRK_EXO(benchmark::State &state) { - int n = state.range(0); - auto a = AlignedBuffer2D(n, n, 2.0, 64); - auto c = AlignedBuffer2D(n, n, 2.0, 64); - - double alpha = 1.0f; - double beta = 1.0f; - - for (auto _ : state) { - exo_dsyrk(CblasRowMajor, CblasLower, CblasNoTrans, n, n, &alpha, a.data(), - a.data(), &beta, c.data()); - } - - state.counters["flops"] = benchmark::Counter( - static_cast(state.iterations()) * n * n * n * 2, - benchmark::Counter::kIsRate, benchmark::Counter::kIs1000); -} - -BENCHMARK(BM_DSYRK_CBLAS)->ArgNames({"n"})->RangeMultiplier(2)->Range(16, 8192); -BENCHMARK(BM_DSYRK_EXO)->ArgNames({"n"})->RangeMultiplier(2)->Range(16, 8192); diff --git a/test/level3/dsyrk/correctness.cpp b/test/level3/dsyrk/correctness.cpp deleted file mode 100644 index e219fea..0000000 --- a/test/level3/dsyrk/correctness.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include - -#include -#include - -#include "correctness_helpers.h" -#include "exo_dsyrk.h" -#include "generate_buffer.h" - -static std::vector _transpose(double *V, const int m, const int k) { - std::vector V_t(k * m); - for (int i = 0; i < m; i++) { - for (int j = 0; j < k; j++) { - V_t[j * m + i] = V[i * k + j]; - } - } - - return V_t; -} - -void test_dsyrk(const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE transpose, const int n, const int k, - const double alpha, const double beta) { - std::cout << "Running syrk test: N = " << n << ", alpha = " << alpha - << ", beta = " << beta << " uplo = " << uplo - << ", transpose = " << transpose << std::endl; - auto a = AlignedBuffer2D(n, k); - auto a2 = a; - auto c = AlignedBuffer2D(n, k, 2.0f, 64); - auto c2 = c; - - cblas_dsyrk(CblasRowMajor, uplo, transpose, n, n, // M N - alpha, // alpha - a.data(), - n, // M - beta, c.data(), - n // M - ); - if (uplo == CblasLower && transpose == CblasNoTrans && alpha == 1.0 && - beta == 1.0) { - auto at = _transpose(a.data(), n, n); - exo_dsyrk(CblasRowMajor, uplo, transpose, n, n, &alpha, a.data(), at.data(), - &beta, c2.data()); - } else { - exo_dsyrk(CblasRowMajor, uplo, transpose, n, n, &alpha, a.data(), a2.data(), - &beta, c2.data()); - } - - double epsilon = 0.01; - for (int i = 0; i < n * k; i++) { - double correct = c[i]; - double exo_out = c2[i]; - if (!check_relative_error_okay(correct, exo_out, epsilon)) { - std::cout << "Error at " << i / n << ", " << i % n - << ". Expected: " << correct << ", got: " << exo_out - << std::endl; - exit(1); - } - } - - std::cout << "Passed!" << std::endl; -} - -int main() { - std::vector dims{32, 64, 256, 513}; - std::vector uplos{CblasLower, CblasUpper}; - std::vector transposes{CblasNoTrans, CblasTrans}; - std::vector alphas{0.0, 1.0, 2.0}; - std::vector betas{0.0, 1.0, 2.0}; - - for (auto const n : dims) { - for (auto const uplo : uplos) { - for (auto const transpose : transposes) { - for (auto const alpha : alphas) { - for (auto const beta : betas) { - test_dsyrk(uplo, transpose, n, n, alpha, beta); - } - } - } - } - // test_dsyrk('L', 'N', n, n, 1.0, 1.0); - } -} diff --git a/test/level3/dsyrk/exo_dsyrk.h b/test/level3/dsyrk/exo_dsyrk.h deleted file mode 100644 index 304b832..0000000 --- a/test/level3/dsyrk/exo_dsyrk.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -#include - -#include "exo_syrk.h" - -void exo_dsyrk_lower_notranspose(const int n, const int k, const double *alpha, - const double *A1, const double *A2, - const double *beta, double *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_dsyrk_lower_notranspose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, - beta, C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_dsyrk_lower_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_dsyrk_lower_notranspose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_dsyrk_lower_notranspose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, - C); - } -} - -void exo_dsyrk_lower_transpose(const int n, const int k, const double *alpha, - const double *A1, const double *A2, - const double *beta, double *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_dsyrk_lower_transpose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_dsyrk_lower_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_dsyrk_lower_transpose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_dsyrk_lower_transpose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } -} - -void exo_dsyrk_upper_notranspose(const int n, const int k, const double *alpha, - const double *A1, const double *A2, - const double *beta, double *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_dsyrk_upper_notranspose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, - beta, C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_dsyrk_upper_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_dsyrk_upper_notranspose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_dsyrk_upper_notranspose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, - C); - } -} - -void exo_dsyrk_upper_transpose(const int n, const int k, const double *alpha, - const double *A1, const double *A2, - const double *beta, double *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_dsyrk_upper_transpose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_dsyrk_upper_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_dsyrk_upper_transpose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_dsyrk_upper_transpose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } -} - -void exo_dsyrk(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE transpose, const int n, const int k, - const double *alpha, const double *A1, const double *A2, - const double *beta, double *C) { - // TODO: other cases - if (uplo == CblasLower) { - if (transpose == CblasNoTrans) { - exo_dsyrk_lower_notranspose(n, k, alpha, A1, A2, beta, C); - } else { - exo_dsyrk_lower_transpose(n, k, alpha, A1, A2, beta, C); - } - } else { - if (transpose == CblasNoTrans) { - exo_dsyrk_upper_notranspose(n, k, alpha, A1, A2, beta, C); - } else { - exo_dsyrk_upper_transpose(n, k, alpha, A1, A2, beta, C); - } - } -} diff --git a/test/level3/ssyrk/bench.cpp b/test/level3/ssyrk/bench.cpp deleted file mode 100644 index f999616..0000000 --- a/test/level3/ssyrk/bench.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "exo_ssyrk.h" -#include "generate_buffer.h" - -static void print_matrix(std::vector M, int n, int k) { - for (int i = 0; i < k; i++) { - for (int j = 0; j < n; j++) { - std::cout << M[j * k + i] << ", "; - } - std::cout << std::endl; - } - std::cout << std::endl; -} - -static std::vector transpose(std::vector V, const int m, - const int k) { - std::vector V_t(k * m); - for (int i = 0; i < m; i++) { - for (int j = 0; j < k; j++) { - V_t[j * m + i] = V[i * k + j]; - } - } - - return V_t; -} - -static void BM_SSYRK_CBLAS(benchmark::State &state) { - int n = state.range(0); - auto a = AlignedBuffer2D(n, n, 2.0, 64); - auto c = AlignedBuffer2D(n, n, 2.0, 64); - - for (auto _ : state) { - cblas_ssyrk(CblasRowMajor, CblasLower, CblasNoTrans, n, n, // M N - 1.0, // alpha - a.data(), - n, // M - 1.0, c.data(), - n // M - ); - } - - state.counters["flops"] = benchmark::Counter( - static_cast(state.iterations()) * n * n * n, - benchmark::Counter::kIsRate, benchmark::Counter::kIs1000); -} - -static void BM_SSYRK_EXO(benchmark::State &state) { - int n = state.range(0); - auto a = AlignedBuffer2D(n, n, 2.0, 64); - auto c = AlignedBuffer2D(n, n, 2.0, 64); - - float alpha = 1.0f; - float beta = 1.0f; - - for (auto _ : state) { - exo_ssyrk(CblasRowMajor, CblasLower, CblasNoTrans, n, n, &alpha, a.data(), - a.data(), &beta, c.data()); - } - - state.counters["flops"] = benchmark::Counter( - static_cast(state.iterations()) * n * n * n, - benchmark::Counter::kIsRate, benchmark::Counter::kIs1000); -} - -// BENCHMARK(BM_SSYRK_CBLAS) -// ->ArgNames({"n", "m", "k"}) -// ->Args({8192, 8192, 8192}) -// ->Args({512, 512, 512}); -// ->ArgsProduct({benchmark::CreateRange(48, 48*100, 48)}); - -BENCHMARK(BM_SSYRK_EXO) - ->ArgNames({"n", "m", "k"}) - ->Args({8192, 8192, 8192}) - ->Args({512, 512, 512}); diff --git a/test/level3/ssyrk/correctness.cpp b/test/level3/ssyrk/correctness.cpp deleted file mode 100644 index 12cd43b..0000000 --- a/test/level3/ssyrk/correctness.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include - -#include -#include - -#include "correctness_helpers.h" -#include "exo_ssyrk.h" -#include "generate_buffer.h" - -static std::vector _transpose(float *V, const int m, const int k) { - std::vector V_t(k * m); - for (int i = 0; i < m; i++) { - for (int j = 0; j < k; j++) { - V_t[j * m + i] = V[i * k + j]; - } - } - - return V_t; -} - -void test_ssyrk(const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE transpose, const int n, const int k, - const float alpha, const float beta) { - std::cout << "Running syrk test: N = " << n << ", alpha = " << alpha - << ", beta = " << beta << " uplo = " << uplo - << ", transpose = " << transpose << std::endl; - auto a = AlignedBuffer2D(n, k); - auto a2 = a; - auto c = AlignedBuffer2D(n, k, 2.0f, 64); - auto c2 = c; - - cblas_ssyrk(CblasRowMajor, uplo, transpose, n, n, // M N - alpha, // alpha - a.data(), - n, // M - beta, c.data(), - n // M - ); - if (uplo == CblasLower && transpose == CblasNoTrans && alpha == 1.0 && - beta == 1.0) { - auto at = _transpose(a.data(), n, n); - exo_ssyrk(CblasRowMajor, uplo, transpose, n, n, &alpha, a.data(), at.data(), - &beta, c2.data()); - } else { - exo_ssyrk(CblasRowMajor, uplo, transpose, n, n, &alpha, a.data(), a2.data(), - &beta, c2.data()); - } - - double epsilon = 0.01; - for (int i = 0; i < n * k; i++) { - double correct = c[i]; - double exo_out = c2[i]; - if (!check_relative_error_okay(correct, exo_out, epsilon)) { - std::cout << "Error at " << i / n << ", " << i % n - << ". Expected: " << correct << ", got: " << exo_out - << std::endl; - exit(1); - } - } - - std::cout << "Passed!" << std::endl; -} - -int main() { - std::vector dims{32, 64, 513}; - std::vector uplos{CblasLower, CblasUpper}; - std::vector transposes{CblasNoTrans, CblasTrans}; - std::vector alphas{0.0, 1.0, 2.0}; - std::vector betas{0.0, 1.0, 2.0}; - - for (auto const n : dims) { - for (auto const uplo : uplos) { - for (auto const transpose : transposes) { - for (auto const alpha : alphas) { - for (auto const beta : betas) { - test_ssyrk(uplo, transpose, n, n, alpha, beta); - } - } - } - } - // test_ssyrk('L', 'N', n, n, 1.0, 1.0); - } -} diff --git a/test/level3/ssyrk/exo_ssyrk.h b/test/level3/ssyrk/exo_ssyrk.h deleted file mode 100644 index e36731c..0000000 --- a/test/level3/ssyrk/exo_ssyrk.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -#include - -#include "exo_syrk.h" - -void exo_ssyrk_lower_notranspose(const int n, const int k, const float *alpha, - const float *A1, const float *A2, - const float *beta, float *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_ssyrk_lower_notranspose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, - beta, C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_ssyrk_lower_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_ssyrk_lower_notranspose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_ssyrk_lower_notranspose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, - C); - } -} - -void exo_ssyrk_lower_transpose(const int n, const int k, const float *alpha, - const float *A1, const float *A2, - const float *beta, float *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_ssyrk_lower_transpose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_ssyrk_lower_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_ssyrk_lower_transpose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_ssyrk_lower_transpose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } -} - -void exo_ssyrk_upper_notranspose(const int n, const int k, const float *alpha, - const float *A1, const float *A2, - const float *beta, float *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_ssyrk_upper_notranspose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, - beta, C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_ssyrk_upper_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_ssyrk_upper_notranspose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_ssyrk_upper_notranspose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, - C); - } -} - -void exo_ssyrk_upper_transpose(const int n, const int k, const float *alpha, - const float *A1, const float *A2, - const float *beta, float *C) { - if (*alpha == 1.0 && *beta == 1.0) { - exo_ssyrk_upper_transpose_noalpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else if (*alpha == 0.0 && *beta == 1.0) { - return; - } else if (*alpha == 0.0 && *beta != 1.0) { - exo_ssyrk_upper_alphazero_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } else if (*alpha != 1.0 && *beta == 1.0) { - exo_ssyrk_upper_transpose_alpha_nobeta(nullptr, n, k, alpha, A1, A2, beta, - C); - } else { - exo_ssyrk_upper_transpose_alpha_beta(nullptr, n, k, alpha, A1, A2, beta, C); - } -} - -void exo_ssyrk(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, - const enum CBLAS_TRANSPOSE transpose, const int n, const int k, - const float *alpha, const float *A1, const float *A2, - const float *beta, float *C) { - // TODO: other cases - if (uplo == CblasLower) { - if (transpose == CblasNoTrans) { - exo_ssyrk_lower_notranspose(n, k, alpha, A1, A2, beta, C); - } else { - exo_ssyrk_lower_transpose(n, k, alpha, A1, A2, beta, C); - } - } else { - if (transpose == CblasNoTrans) { - exo_ssyrk_upper_notranspose(n, k, alpha, A1, A2, beta, C); - } else { - exo_ssyrk_upper_transpose(n, k, alpha, A1, A2, beta, C); - } - } -} diff --git a/test/level3/syrk/bench.cpp b/test/level3/syrk/bench.cpp new file mode 100644 index 0000000..855e515 --- /dev/null +++ b/test/level3/syrk/bench.cpp @@ -0,0 +1,87 @@ +#include +#include + +#include "bench_ranges.h" +#include "exo_gemm_wrapper.h" +#include "generate_buffer.h" +#include "misc.h" + +generate_wrapper(gemm); + +template +static void bench(benchmark::State &state) { + int M = state.range(0); + int N = state.range(1); + int K = state.range(2); + const enum CBLAS_ORDER order = (const enum CBLAS_ORDER)state.range(3); + const enum CBLAS_TRANSPOSE TransA = + (const enum CBLAS_TRANSPOSE)state.range(4); + const enum CBLAS_TRANSPOSE TransB = + (const enum CBLAS_TRANSPOSE)state.range(5); + const T alpha = state.range(6); + const int lda_diff = state.range(7); + const int ldb_diff = state.range(8); + const T beta = state.range(9); + const int ldc = N + state.range(10); + const int alignmentA = state.range(11); + const int alignmentB = state.range(12); + const int alignmentC = state.range(13); + + auto A_dims = get_dims(TransA, M, K, lda_diff); + const int lda = A_dims.second; + auto A = AlignedBuffer2D(A_dims.first, A_dims.second); + auto B_dims = get_dims(TransB, K, N, ldb_diff); + const int ldb = B_dims.second; + auto B = AlignedBuffer2D(B_dims.first, B_dims.second); + auto C = AlignedBuffer2D(M, N); + + for (auto _ : state) { + gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, + A.data(), lda, B.data(), ldb, beta, C.data(), ldc); + } +} + +template +static void args(benchmark::internal::Benchmark *b) { + auto add_arg = [&b](int M, int N, int K) { + return b->Args({M, + N, + K, + order, + TransA, + TransB, + 17, + 0, + 0, + 1, + 0, + 64, + 64, + 64, + {BENCH_TYPES::level_3_eq}, + {type_bits()}}); + }; + b->ArgNames({"M", "N", "K", "order", "TransA", "TransB", "alpha", "lda_diff", + "ldb_diff", "beta", "ldc_diff", "alignmentA", "alignmentB", + "alignmentC", "bench_type", "precision"}); + for (int i = 1; i <= level_3_max_N; i *= 2) { + add_arg(i, i, i); + } + for (int i = 7; i <= level_3_max_N; i *= 7) { + add_arg(i, i, i); + } +} + +#define call_gemm_bench(lib, T, order, TransA, TransB) \ + BENCHMARK(bench) \ + ->Name(level_3_kernel_name("gemm")) \ + ->Apply(args); + +#define call_gemm_bench_all(order, TransA, TransB) \ + call_gemm_bench(Exo, float, order, TransA, TransB); \ + call_gemm_bench(Cblas, float, order, TransA, TransB); \ + call_gemm_bench(Exo, double, order, TransA, TransB); \ + call_gemm_bench(Cblas, double, order, TransA, TransB); + +call_gemm_bench_all(CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans, + CBLAS_TRANSPOSE::CblasNoTrans); diff --git a/test/level3/syrk/correctness.cpp b/test/level3/syrk/correctness.cpp new file mode 100644 index 0000000..af96ce3 --- /dev/null +++ b/test/level3/syrk/correctness.cpp @@ -0,0 +1,63 @@ +#include + +#include + +#include "correctness_helpers.h" +#include "exo_syrk_wrapper.h" +#include "generate_buffer.h" +#include "misc.h" + +generate_wrapper(syrk); + +template +void test_syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const int lda, const float beta, + const int ldc) { + auto A = AlignedBuffer2D(N, lda); + auto C = AlignedBuffer2D(N, ldc); + + auto A_expected = A; + auto C_expected = C; + + syrk(Order, Uplo, Trans, N, K, alpha, A.data(), lda, beta, C.data(), + ldc); + syrk(Order, Uplo, Trans, N, K, alpha, A_expected.data(), lda, beta, + C_expected.data(), ldc); + + if (!C.check_buffer_equal(C_expected)) { + failed("syrk", "Order", Order, "Uplo", Uplo, "Trans", Trans, "N", N, "K", + K, "alpha", alpha, "lda", lda, "beta", beta, "ldc", ldc); + } +} + +template +void run() { + std::vector dims{1, 7, 32, 64, 257}; + std::vector trans{CblasNoTrans}; + std::vector uplo{CblasLower, CblasUpper}; + std::vector ld_diffs{0, 5}; + std::vector alphas{13.0}; + std::vector betas{1.0}; + for (const auto Uplo : uplo) + for (const auto N : dims) + for (const auto K : dims) + for (const auto Trans : trans) + for (const auto lda_diff : ld_diffs) + for (const auto ldc_diff : ld_diffs) + for (const auto alpha : alphas) + for (const auto beta : betas) { + auto lda = K + lda_diff; + auto ldc = N + ldc_diff; + if (Trans == CBLAS_TRANSPOSE::CblasTrans) { + lda = N + lda_diff; + } + test_syrk(CBLAS_ORDER::CblasRowMajor, Uplo, Trans, N, K, + alpha, lda, beta, ldc); + } +} + +int main() { + run(); + run(); +} diff --git a/test/level3/syrk/exo_syrk_wrapper.h b/test/level3/syrk/exo_syrk_wrapper.h new file mode 100644 index 0000000..618864c --- /dev/null +++ b/test/level3/syrk/exo_syrk_wrapper.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include "error.h" +#include "exo_syrk.h" + +#define exo_syrk(type, prefix, exo_type) \ + void exo_##prefix##syrk( \ + const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, \ + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, \ + const type alpha, const type *A, const int lda, const type beta, \ + type *C, const int ldc) { \ + if (Order != CBLAS_ORDER::CblasRowMajor) { \ + throw UnsupportedParameterException("syrk::Order must be Row Major"); \ + } \ + if (Trans != CBLAS_TRANSPOSE::CblasNoTrans) { \ + throw UnsupportedParameterException("syrk::Trans must be nonTrans"); \ + } \ + if (beta != 1.0) { \ + throw UnsupportedParameterException("syrk::beta must be 1.0"); \ + } \ + if (Uplo == CBLAS_UPLO::CblasLower) { \ + exo_##prefix##syrk_rm_l_stride_1( \ + nullptr, N, K, &alpha, \ + exo_win_2##exo_type##c{.data = A, .strides = {lda, 1}}, \ + exo_win_2##exo_type##c{.data = A, .strides = {lda, 1}}, \ + exo_win_2##exo_type{.data = C, .strides{ldc, 1}}); \ + } else { \ + exo_##prefix##syrk_rm_u_stride_1( \ + nullptr, N, K, &alpha, \ + exo_win_2##exo_type##c{.data = A, .strides = {lda, 1}}, \ + exo_win_2##exo_type##c{.data = A, .strides = {lda, 1}}, \ + exo_win_2##exo_type{.data = C, .strides{ldc, 1}}); \ + } \ + } + +exo_syrk(float, s, f32); +exo_syrk(double, d, f64); From 94df6a401b2b2f14d02979c171680816b29abfac Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Mon, 25 Mar 2024 20:07:12 -0400 Subject: [PATCH 3/9] Rewrite syrk bench tests --- test/common/misc.h | 98 +++++++++++++++++++++----------------- test/level3/gemm/bench.cpp | 6 +-- test/level3/syrk/bench.cpp | 85 +++++++++++++++------------------ 3 files changed, 94 insertions(+), 95 deletions(-) diff --git a/test/common/misc.h b/test/common/misc.h index 3c26abd..ec1ba86 100644 --- a/test/common/misc.h +++ b/test/common/misc.h @@ -15,18 +15,15 @@ std::pair get_dims(const enum CBLAS_TRANSPOSE trans, int M, int N, class BLAS_lib {}; -class Exo : public BLAS_lib {}; +class Exo : public BLAS_lib { + public: + static std::string lib_name() { return "exo"; } +}; -class Cblas : public BLAS_lib {}; - -template -std::string lib_name() { - if constexpr (std::is_same::value) { - return "exo"; - } else { - return "cblas"; - } -} +class Cblas : public BLAS_lib { + public: + static std::string lib_name() { return "cblas"; } +}; template std::string type_prefix() { @@ -46,49 +43,60 @@ int type_bits() { } } +std::string order_symbol(int Order) { + if (Order == CBLAS_ORDER::CblasRowMajor) { + return "_rm"; + } else if (Order == CBLAS_ORDER::CblasColMajor) { + return "_col"; + } else { + return ""; + } +} + +std::string trans_symbol(int Trans) { + if (Trans == CBLAS_TRANSPOSE::CblasNoTrans) { + return "n"; + } else if (Trans == CBLAS_TRANSPOSE::CblasTrans) { + return "t"; + } else { + return ""; + } +} + +std::string uplo_symbol(int Uplo) { + if (Uplo == CBLAS_UPLO::CblasLower) { + return "l"; + } else if (Uplo == CBLAS_UPLO::CblasUpper) { + return "u"; + } else { + return ""; + } +} + template std::string kernel_name(std::string kernel) { - return lib_name() + "_" + type_prefix() + kernel; + return lib::lib_name() + "_" + type_prefix() + kernel; } template std::string level_2_kernel_name(std::string kernel) { - auto name = lib_name() + "_" + type_prefix() + kernel; - name += order == CBLAS_ORDER::CblasRowMajor ? "_rm" : "_col"; - if constexpr (TransA + Uplo) { - name += "_"; - } - if constexpr (TransA == CBLAS_TRANSPOSE::CblasNoTrans) { - name += "n"; - } else if constexpr (TransA == CBLAS_TRANSPOSE::CblasTrans) { - name += "t"; - } - - if constexpr (Uplo == CBLAS_UPLO::CblasUpper) { - name += "u"; - } else if constexpr (Uplo == CBLAS_UPLO::CblasLower) { - name += "l"; - } + auto name = lib::lib_name() + "_" + type_prefix() + kernel; + name += order_symbol(order); + name += Uplo + TransA ? "_" : ""; + name += uplo_symbol(Uplo); + name += trans_symbol(TransA); return name; } -template -std::string level_3_kernel_name(std::string kernel) { - auto name = lib_name() + "_" + type_prefix() + kernel; - name += order == CBLAS_ORDER::CblasRowMajor ? "_rm" : "_col"; - if constexpr (TransA + TransB) { - name += "_"; - } - if constexpr (TransA == CBLAS_TRANSPOSE::CblasNoTrans) { - name += "n"; - } else if constexpr (TransA == CBLAS_TRANSPOSE::CblasTrans) { - name += "t"; - } - if constexpr (TransB == CBLAS_TRANSPOSE::CblasNoTrans) { - name += "n"; - } else if constexpr (TransB == CBLAS_TRANSPOSE::CblasTrans) { - name += "t"; - } +template +std::string level_3_kernel_name(std::string kernel, int Order, int Uplo, + int TransA, int TransB) { + auto name = lib::lib_name() + "_" + type_prefix() + kernel; + name += order_symbol(Order); + name += Uplo + TransA + TransB ? "_" : ""; + name += uplo_symbol(Uplo); + name += trans_symbol(TransA); + name += trans_symbol(TransB); return name; } diff --git a/test/level3/gemm/bench.cpp b/test/level3/gemm/bench.cpp index 855e515..5843f97 100644 --- a/test/level3/gemm/bench.cpp +++ b/test/level3/gemm/bench.cpp @@ -72,9 +72,9 @@ static void args(benchmark::internal::Benchmark *b) { } } -#define call_gemm_bench(lib, T, order, TransA, TransB) \ - BENCHMARK(bench) \ - ->Name(level_3_kernel_name("gemm")) \ +#define call_gemm_bench(lib, T, order, TransA, TransB) \ + BENCHMARK(bench) \ + ->Name(level_3_kernel_name("gemm", order, 0, TransA, TransB)) \ ->Apply(args); #define call_gemm_bench_all(order, TransA, TransB) \ diff --git a/test/level3/syrk/bench.cpp b/test/level3/syrk/bench.cpp index 855e515..46c5085 100644 --- a/test/level3/syrk/bench.cpp +++ b/test/level3/syrk/bench.cpp @@ -2,57 +2,48 @@ #include #include "bench_ranges.h" -#include "exo_gemm_wrapper.h" +#include "exo_syrk_wrapper.h" #include "generate_buffer.h" #include "misc.h" -generate_wrapper(gemm); +generate_wrapper(syrk); template static void bench(benchmark::State &state) { - int M = state.range(0); - int N = state.range(1); - int K = state.range(2); - const enum CBLAS_ORDER order = (const enum CBLAS_ORDER)state.range(3); - const enum CBLAS_TRANSPOSE TransA = - (const enum CBLAS_TRANSPOSE)state.range(4); - const enum CBLAS_TRANSPOSE TransB = - (const enum CBLAS_TRANSPOSE)state.range(5); - const T alpha = state.range(6); - const int lda_diff = state.range(7); - const int ldb_diff = state.range(8); - const T beta = state.range(9); - const int ldc = N + state.range(10); - const int alignmentA = state.range(11); - const int alignmentB = state.range(12); - const int alignmentC = state.range(13); + int N = state.range(0); + int K = state.range(1); + const enum CBLAS_ORDER Order = (const enum CBLAS_ORDER)state.range(2); + const enum CBLAS_UPLO Uplo = (const enum CBLAS_UPLO)state.range(3); + const enum CBLAS_TRANSPOSE Trans = (const enum CBLAS_TRANSPOSE)state.range(4); + const T alpha = state.range(5); + const int lda_diff = state.range(6); + const T beta = state.range(7); + const int ldc = N + state.range(8); + const int alignmentA = state.range(9); + const int alignmentB = state.range(10); + const int alignmentC = state.range(11); - auto A_dims = get_dims(TransA, M, K, lda_diff); + auto A_dims = get_dims(Trans, N, K, lda_diff); const int lda = A_dims.second; - auto A = AlignedBuffer2D(A_dims.first, A_dims.second); - auto B_dims = get_dims(TransB, K, N, ldb_diff); - const int ldb = B_dims.second; - auto B = AlignedBuffer2D(B_dims.first, B_dims.second); - auto C = AlignedBuffer2D(M, N); + auto A = AlignedBuffer2D(A_dims.first, A_dims.second, alignmentA); + auto C = AlignedBuffer2D(N, ldc, alignmentC); for (auto _ : state) { - gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, - A.data(), lda, B.data(), ldb, beta, C.data(), ldc); + syrk(Order, Uplo, Trans, N, K, alpha, A.data(), lda, beta, C.data(), + ldc); } } -template +template static void args(benchmark::internal::Benchmark *b) { - auto add_arg = [&b](int M, int N, int K) { - return b->Args({M, - N, + auto add_arg = [&b](int N, int K) { + return b->Args({N, K, order, - TransA, - TransB, + Uplo, + Trans, 17, 0, - 0, 1, 0, 64, @@ -61,27 +52,27 @@ static void args(benchmark::internal::Benchmark *b) { {BENCH_TYPES::level_3_eq}, {type_bits()}}); }; - b->ArgNames({"M", "N", "K", "order", "TransA", "TransB", "alpha", "lda_diff", - "ldb_diff", "beta", "ldc_diff", "alignmentA", "alignmentB", - "alignmentC", "bench_type", "precision"}); + b->ArgNames({"N", "K", "order", "Uplo", "Trans", "alpha", "lda_diff", "beta", + "ldc_diff", "alignmentA", "alignmentC", "bench_type", + "precision"}); for (int i = 1; i <= level_3_max_N; i *= 2) { - add_arg(i, i, i); + add_arg(i, i); } for (int i = 7; i <= level_3_max_N; i *= 7) { - add_arg(i, i, i); + add_arg(i, i); } } -#define call_gemm_bench(lib, T, order, TransA, TransB) \ +#define call_syrk_bench(lib, T, order, Uplo, Trans) \ BENCHMARK(bench) \ - ->Name(level_3_kernel_name("gemm")) \ - ->Apply(args); + ->Name(level_3_kernel_name("syrk", order, Uplo, Trans, 0)) \ + ->Apply(args); -#define call_gemm_bench_all(order, TransA, TransB) \ - call_gemm_bench(Exo, float, order, TransA, TransB); \ - call_gemm_bench(Cblas, float, order, TransA, TransB); \ - call_gemm_bench(Exo, double, order, TransA, TransB); \ - call_gemm_bench(Cblas, double, order, TransA, TransB); +#define call_syrk_bench_all(order, Uplo, Trans) \ + call_syrk_bench(Exo, float, order, Uplo, Trans); \ + call_syrk_bench(Cblas, float, order, Uplo, Trans); \ + call_syrk_bench(Exo, double, order, Uplo, Trans); \ + call_syrk_bench(Cblas, double, order, Uplo, Trans); -call_gemm_bench_all(CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans, +call_syrk_bench_all(CBLAS_ORDER::CblasRowMajor, CBLAS_UPLO::CblasLower, CBLAS_TRANSPOSE::CblasNoTrans); From 19f7e512b063f801f36327c1838173b28b91834f Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Mon, 25 Mar 2024 20:59:36 -0400 Subject: [PATCH 4/9] syrk macro schedule --- src/common/codegen_helpers.py | 14 ++-- src/common/stdlib.py | 5 +- src/level3/syrk.py | 119 ++++++++++++++++++++++++++++++++-- 3 files changed, 123 insertions(+), 15 deletions(-) diff --git a/src/common/codegen_helpers.py b/src/common/codegen_helpers.py index 2461ddb..83a8325 100644 --- a/src/common/codegen_helpers.py +++ b/src/common/codegen_helpers.py @@ -52,9 +52,7 @@ def generate_stride_1_proc(proc): proc = rename(proc, proc.name() + "_stride_1") for arg in proc.args(): if arg.is_tensor(): - proc = proc.add_assertion( - f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1" - ) + proc = proc.add_assertion(f"stride({arg.name()}, {len(arg.shape()) - 1}) == 1") return proc @@ -110,15 +108,14 @@ def export_perf_features(kernel_name, perf_features): json.dump(perf_features, f, sort_keys=True, indent=4, separators=(",", ": ")) -def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon)): +def variants_generator(blas_op, opt_precisions=("f32", "f64"), targets=(AVX2, Neon), stage_scalars=True): def generate(proc, loop_name, *args, globals=None, **kwargs): perf_features = {} for precision in ("f32", "f64"): proc_variant = specialize_precision(proc, precision) - proc_variant = stage_scalar_args(proc_variant) - stride_any = generate_stride_any_proc(proc_variant) + stride_any = stage_scalar_args(stride_any) stride_any = bind_builtins_args(stride_any, stride_any.body(), precision) export_exo_proc(globals, stride_any) @@ -127,9 +124,8 @@ def generate(proc, loop_name, *args, globals=None, **kwargs): algorithm = get_perf_features(stride_1) if precision in opt_precisions and C.Machine.mem_type in targets: - stride_1 = blas_op( - stride_1, loop, precision, C.Machine, *args, **kwargs - ) + stride_1 = blas_op(stride_1, loop, precision, C.Machine, *args, **kwargs) + stride_1 = stage_scalar_args(stride_1) stride_1 = bind_builtins_args(stride_1, stride_1.body(), precision) scheduled = get_perf_features(stride_1) diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 6d574e9..8e55e7d 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -670,8 +670,9 @@ def get_window(cursor): for loop in my_loops: lo_rng = eval_rng(loop.lo(), env) hi_rng = eval_rng(loop.hi(), env) - env[loop.name()] = (lo_rng[0], hi_rng[1]) - window = tuple(eval_rng(idx, env) for idx in cursor.idx()) + env[loop.name()] = (lo_rng[0], f"{hi_rng[1]} - 1") + extend_hi = lambda rng: (rng[0], f"{rng[1]} + 1") if rng[0] != rng[1] else rng + window = tuple(extend_hi(eval_rng(idx, env)) for idx in cursor.idx()) return window def window_to_str(window): diff --git a/src/level3/syrk.py b/src/level3/syrk.py index 63ee374..2beb6e5 100644 --- a/src/level3/syrk.py +++ b/src/level3/syrk.py @@ -37,16 +37,127 @@ def syrk_rm_u(N: size, K: size, alpha: R, A: [R][N, K], A_alias: [R][N, K], C: [ syrk_rm_u = shift_loop(syrk_rm_u, "j", 0) -PARAMS = {AVX2: (4, 3, 66, 3, 512), AVX512: (6, 4, 44, 1, 512), Neon: (1, 1, 1, 1, 1)} + +def schedule_compute(compute, precision, machine, m_r, n_r_fac): + vw = machine.vec_width(precision) + n_r = vw * n_r_fac + i_loop = compute.body()[0] + j_loop = get_inner_loop(compute, i_loop) + k_loop = get_inner_loop(compute, j_loop) + compute, cs = auto_stage_mem(compute, k_loop, "C", "C_tile", accum=True, rc=1) + compute = lift_reduce_constant(compute, cs.load.expand(0, 1)) + assign = compute.forward(cs.store).prev() + compute = inline_assign(compute, assign) + compute = set_memory(compute, cs.alloc, machine.mem_type) + + compute = tile_loops_bottom_up(compute, i_loop, (m_r, n_r, None), tail="guard") + + compute = repeate_n(lift_scope)(compute, k_loop, n=4) + compute = divide_dim(compute, cs.alloc, 1, vw) + init_i, cmp_i, axpy_i = compute.find_loop("ii", many=True) + init_j, cmp_j, axpy_j = compute.find_loop("ji", many=True) + compute = vectorize(compute, init_j, vw, precision, machine.mem_type, tail="perfect") + compute = unroll_loop(compute, init_j) + compute, (o_cmp_j, i_cmp_j, _) = divide_loop_(compute, cmp_j, vw, tail="perfect", rc=True) + compute = simplify(compute) + compute, cursors = auto_stage_mem(compute, cmp_i, "packed_A_alias", rc=True) + compute = set_memory(compute, cursors.alloc, machine.mem_type) + compute = simplify(compute) + compute = divide_dim(compute, cursors.alloc, 0, vw) + compute = vectorize(compute, cursors.load, vw, precision, machine.mem_type, rules=[fma_rule], tail="perfect") + compute = unroll_loop(compute, cursors.load) + compute = vectorize(compute, i_cmp_j, vw, precision, machine.mem_type, rules=[fma_rule], tail="perfect") + compute = unroll_loop(compute, i_cmp_j) + compute = unroll_loop(compute, o_cmp_j) + compute, alpah_cursors = auto_stage_mem(compute, axpy_j.body(), "alpha", rc=True) + compute = vectorize(compute, axpy_j, vw, precision, machine.mem_type, rules=[fma_rule], tail="perfect") + compute = unroll_loop(compute, axpy_j) + compute = simplify(compute) + print(compute) + + def cut(proc, loop, cond, rng): + loop = proc.forward(loop) + cut_val = FormattedExprStr(f"_ - 1", loop.hi()) + proc, (loop1, loop2) = cut_loop_(proc, loop, cut_val, rc=True) + proc = specialize(proc, loop2.body(), [f"{cond(loop2, i)} == {i}" for i in rng]) + return proc + + right_cond = lambda l, i: f"(N - {l.name()} * {n_r} + {vw - 1}) / {vw}" + compute = cut(compute, j_loop, right_cond, range(1, n_r_fac)) + compute = dce(compute) + compute = replace_all_stmts(compute, machine.get_instructions(precision)) + + compute = simplify(unroll_loops(compute)) + bottom_cond = lambda l, i: f"N - {l.name()} * {m_r}" + compute = cut(compute, i_loop, bottom_cond, range(m_r, 1, -1)) + + def rewrite(p): + try: + p = delete_pass(p) + except: + pass + p = dce(p) + return simplify(p) + + blocks = compute.find_loop("C_tile:_", many=True) + for i, tile in enumerate(blocks): + name = compute.name() + str(i) + compute = extract_and_schedule(rewrite)(compute, tile.expand(), name) + return simplify(compute) + + +def schedule_macro(mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac): + vw = machine.vec_width(precision) + n_r = vw * n_r_fac + + for var, max_var in zip(("N", "K"), (max_M, max_N, max_K)): + mk = mk.add_assertion(f"{var} <= {max_var}") + + mk_starter = mk + mk = rename(mk, mk.name() + "_mk") + i_loop = mk.body()[0] + + packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) + mk, cursors = pack_mem(mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) + mk = set_memory(mk, cursors.alloc, DRAM_STATIC) + mk, _ = extract_subproc(mk, cursors.load, mk.name() + "_A_pack") + + packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) + mk, cursors = pack_mem(mk, i_loop, "A_alias", packed_B_shape, "packed_A_alias", rc=1) + mk = set_memory(mk, cursors.alloc, DRAM_STATIC) + mk, _ = extract_subproc(mk, cursors.load, mk.name() + "_A_alias_pack") + + mk = extract_and_schedule(schedule_compute)(mk, i_loop, mk.name() + "_compute", precision, machine, m_r, n_r_fac) + return mk_starter, simplify(mk) + + +def schedule(proc, i_loop, precision, machine, m_r, n_r_fac, M_tile, N_tile, K_tile): + macro = schedule_macro(proc, precision, machine, M_tile, N_tile, K_tile, m_r, n_r_fac) + tiled = proc + k_loop = get_inner_loop(tiled, get_inner_loop(tiled, i_loop)) + tiled = repeate_n(lift_scope)(tiled, k_loop, n=2) + tiled = tile_loops_bottom_up(tiled, k_loop, [K_tile, M_tile, N_tile]) + tiled = apply(repeate_n(reorder_loops))(tiled, tiled.find_loop("ki", many=True), n=2) + tiled = replace_all_stmts(tiled, [macro]) + + macro_calls = filter_cursors(is_call)(tiled, nlr_stmts(tiled)) + tiled = simplify(apply(inline_proc_and_wins)(tiled, macro_calls)) + + tiled = apply(hoist_from_loop)(tiled, tiled.find_loop("jo", many=True)) + tiled = squash_buffers(tiled, tiled.find("packed_A : _", many=True)) + tiled = squash_buffers(tiled, tiled.find("packed_A_alias : _", many=True)) + print(tiled) + return simplify(tiled) + + +PARAMS = {AVX2: (2, 2, 66, 3, 512), AVX512: (6, 4, 44, 1, 512), Neon: (1, 1, 1, 1, 1)} m_r, n_r_fac, M_tile_fac, N_tile_fac, K_tile = PARAMS[C.Machine.mem_type] n_r = n_r_fac * C.Machine.vec_width("f32") M_tile = M_tile_fac * m_r N_tile = N_tile_fac * n_r -variants_generator(identity_schedule, ("f32",), (AVX2, AVX512))( - syrk_rm_l, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() -) +variants_generator(schedule, ("f32",), (AVX2, AVX512))(syrk_rm_l, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals()) variants_generator(identity_schedule, ("f32",), (AVX2, AVX512))( syrk_rm_u, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() ) From b068756a75cd4af7ee393e539346118c71ad1357 Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Mon, 25 Mar 2024 23:50:25 -0400 Subject: [PATCH 5/9] Cache tiling for syrk --- src/common/perf_features.py | 28 +++++------- src/common/stdlib.py | 20 +++++++-- src/level3/gemm.py | 3 +- src/level3/syrk.py | 57 ++++++++++++++++--------- test/codegen/reference/sha256/avx2.json | 2 +- 5 files changed, 66 insertions(+), 44 deletions(-) diff --git a/src/common/perf_features.py b/src/common/perf_features.py index 7477083..ebf7e10 100644 --- a/src/common/perf_features.py +++ b/src/common/perf_features.py @@ -63,9 +63,7 @@ def get_stmt_ops(cursor, syms): iter_sym = sp.Symbol(cursor.name(), integer=integers, positive=integers) new_syms[cursor.name()] = iter_sym stmts_ops = [get_stmt_ops(s, new_syms) for s in cursor.body()] - sums = [ - sp.summation(s_ops, (iter_sym, lo, hi - one)) for s_ops in stmts_ops - ] + sums = [sp.summation(s_ops, (iter_sym, lo, hi - one)) for s_ops in stmts_ops] return sum(sums) else: return zero @@ -74,14 +72,14 @@ def get_stmt_ops(cursor, syms): for arg in proc.args(): if arg.type().is_indexable(): is_size = arg.type() == ExoType.Size - syms[arg.name()] = sp.Symbol( - arg.name(), integer=integers, positive=is_size and integers - ) - - ops = get_stmt_ops(proc.body(), syms) - ops = sp.simplify(ops) - ops = sp.factor(ops) - return ops + syms[arg.name()] = sp.Symbol(arg.name(), integer=integers, positive=is_size and integers) + try: + ops = get_stmt_ops(proc.body(), syms) + ops = sp.simplify(ops) + ops = sp.factor(ops) + return ops + except: + return zero def count_flops(proc, upper=False): @@ -91,9 +89,7 @@ def count_flops(proc, upper=False): def get_expr_ops(proc, expr): expr = proc.forward(expr) if isinstance(expr, (BuiltInFunctionCursor, BinaryOpCursor, UnaryMinusCursor)): - children_ops = sum( - get_expr_ops(proc, c) for c in get_numeric_children(proc, expr) - ) + children_ops = sum(get_expr_ops(proc, c) for c in get_numeric_children(proc, expr)) return one + children_ops return zero @@ -140,9 +136,7 @@ def get_bytes_traffic(p, c): def get_expr_loads(proc, expr): expr = proc.forward(expr) if isinstance(expr, (BuiltInFunctionCursor, BinaryOpCursor, UnaryMinusCursor)): - return sum( - get_expr_loads(proc, c) for c in get_numeric_children(proc, expr) - ) + return sum(get_expr_loads(proc, c) for c in get_numeric_children(proc, expr)) elif isinstance(expr, ReadCursor): return get_bytes_traffic(proc, expr) return zero diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 8e55e7d..1501069 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -566,8 +566,7 @@ def get_depth(proc, scope): def push_scope_in(proc, scope, depth, size=None): scope = proc.forward(scope) allocs = list(filter_cursors(is_alloc)(proc, scope.body())) - proc = apply(attempt(parallelize_and_lift_alloc))(proc, allocs) - + proc = apply(parallelize_and_lift_alloc)(proc, allocs) if const_allocs and size: proc = apply(lambda p, a: resize_dim(p, a, 0, size, 0))(proc, allocs) @@ -1015,11 +1014,18 @@ def bound_loop_by_if(proc, loop): if not isinstance(if_c.orelse(), InvalidCursor): raise BLAS_SchedulingError(err) - if not isinstance(if_c.cond().lhs(), ReadCursor) or if_c.cond().lhs().name() != loop.name() or if_c.cond().op() != "<": + cond = if_c.cond() + cond_lhs = cond.lhs() + + if is_read(proc, cond_lhs, loop.name()) and cond.op() == "<": + cut_point = FormattedExprStr("_ + _", loop.lo(), cond.rhs()) + elif is_add(proc, cond_lhs) and is_read(proc, cond_lhs.lhs(), loop.name()) and cond.op() == "<": + cut_point = FormattedExprStr("_ + (_ - _)", loop.lo(), cond.rhs(), cond_lhs.rhs()) + else: raise BLAS_SchedulingError(err) if_c = loop.body()[0] - proc = cut_loop(proc, loop, FormattedExprStr("_ + _", loop.lo(), if_c.cond().rhs())) + proc = cut_loop(proc, loop, cut_point) loop1 = proc.forward(loop) loop2 = loop1.next() proc = eliminate_dead_code(proc, loop1.body()[0]) @@ -1359,3 +1365,9 @@ def add_loop_nest(loop): load = load.as_block().expand(0, diff) return proc, pack_mem_cursors(alloc, load, block, store) + + +def inline_calls(proc, block=InvalidCursor(), subproc=None): + calls = filter_cursors(is_call)(proc, nlr_stmts(proc, block), subproc) + proc = simplify(apply(inline_proc_and_wins)(proc, calls)) + return proc diff --git a/src/level3/gemm.py b/src/level3/gemm.py index 05e4a5c..c133fd7 100644 --- a/src/level3/gemm.py +++ b/src/level3/gemm.py @@ -127,8 +127,7 @@ def schedule(main_gemm, i_loop, precision, machine, m_r, n_r_fac, M_tile, N_tile gemm_tiled = apply(repeate_n(reorder_loops))(gemm_tiled, gemm_tiled.find_loop("ki", many=True), n=2) gemm_tiled = replace_all_stmts(gemm_tiled, [gemm_macro]) - macro_calls = filter_cursors(is_call)(gemm_tiled, nlr_stmts(gemm_tiled)) - gemm_tiled = simplify(apply(inline_proc_and_wins)(gemm_tiled, macro_calls)) + gemm_tiled = inline_calls(gemm_tiled, subproc=gemm_macro[1]) gemm_tiled = apply(hoist_from_loop)(gemm_tiled, gemm_tiled.find_loop("jo", many=True)) gemm_tiled = squash_buffers(gemm_tiled, gemm_tiled.find("packed_A : _", many=True)) diff --git a/src/level3/syrk.py b/src/level3/syrk.py index 2beb6e5..10c86fe 100644 --- a/src/level3/syrk.py +++ b/src/level3/syrk.py @@ -1,4 +1,5 @@ from __future__ import annotations +import math from exo import * from exo.stdlib.scheduling import * @@ -73,7 +74,6 @@ def schedule_compute(compute, precision, machine, m_r, n_r_fac): compute = vectorize(compute, axpy_j, vw, precision, machine.mem_type, rules=[fma_rule], tail="perfect") compute = unroll_loop(compute, axpy_j) compute = simplify(compute) - print(compute) def cut(proc, loop, cond, rng): loop = proc.forward(loop) @@ -106,22 +106,23 @@ def rewrite(p): return simplify(compute) -def schedule_macro(mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac): +def schedule_macro(mk, precision, machine, max_N, max_K, m_r, n_r_fac): vw = machine.vec_width(precision) n_r = vw * n_r_fac - for var, max_var in zip(("N", "K"), (max_M, max_N, max_K)): + for var, max_var in zip(("N", "K"), (max_N, max_N, max_K)): mk = mk.add_assertion(f"{var} <= {max_var}") mk_starter = mk mk = rename(mk, mk.name() + "_mk") i_loop = mk.body()[0] - packed_A_shape = ((0, max_M // m_r), (1, max_K), (0, m_r)) + packed_A_shape = ((0, max_N // m_r), (1, max_K), (0, m_r)) mk, cursors = pack_mem(mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) mk = set_memory(mk, cursors.alloc, DRAM_STATIC) mk, _ = extract_subproc(mk, cursors.load, mk.name() + "_A_pack") + # TODO: This packing step is doing more work the necessary (packing the whole matrix, not jus triangle) packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) mk, cursors = pack_mem(mk, i_loop, "A_alias", packed_B_shape, "packed_A_alias", rc=1) mk = set_memory(mk, cursors.alloc, DRAM_STATIC) @@ -131,33 +132,49 @@ def schedule_macro(mk, precision, machine, max_M, max_N, max_K, m_r, n_r_fac): return mk_starter, simplify(mk) -def schedule(proc, i_loop, precision, machine, m_r, n_r_fac, M_tile, N_tile, K_tile): - macro = schedule_macro(proc, precision, machine, M_tile, N_tile, K_tile, m_r, n_r_fac) +def schedule(proc, i_loop, precision, machine, m_r, n_r_fac, N_tile, K_tile): + macro = schedule_macro(proc, precision, machine, N_tile, K_tile, m_r, n_r_fac) tiled = proc - k_loop = get_inner_loop(tiled, get_inner_loop(tiled, i_loop)) + j_loop = get_inner_loop(tiled, i_loop) + k_loop = get_inner_loop(tiled, j_loop) tiled = repeate_n(lift_scope)(tiled, k_loop, n=2) - tiled = tile_loops_bottom_up(tiled, k_loop, [K_tile, M_tile, N_tile]) + tiled = tile_loops_bottom_up(tiled, k_loop, [K_tile, N_tile, N_tile], tail="guard") + + # TODO: This code should be a part of tiling + def rewrite(proc, loop): + loop = proc.forward(loop) + cut_point = FormattedExprStr("_ - 1", loop.hi()) + proc, (loop, loop2) = cut_loop_(proc, loop, cut_point, rc=True) + proc = simplify(shift_loop(proc, loop2, 0)) + proc = attempt(unroll_loop)(proc, loop2) + return proc + + tiled = apply(rewrite)(tiled, (j_loop, i_loop, k_loop)) + tiled = simplify(dce(tiled)) + tiled = apply(attempt(bound_loop_by_if))(tiled, tiled.find_loop("ki", many=True)) + tiled = apply(attempt(bound_loop_by_if))(tiled, tiled.find_loop("ii", many=True)) + tiled = apply(attempt(bound_loop_by_if))(tiled, tiled.find_loop("ji", many=True)) + tiled = simplify(delete_pass(tiled)) + tiled = apply(repeate_n(reorder_loops))(tiled, tiled.find_loop("ki", many=True), n=2) tiled = replace_all_stmts(tiled, [macro]) - - macro_calls = filter_cursors(is_call)(tiled, nlr_stmts(tiled)) - tiled = simplify(apply(inline_proc_and_wins)(tiled, macro_calls)) + tiled = inline_calls(tiled, subproc=macro[1]) + # TODO: Replace gemm calls here... tiled = apply(hoist_from_loop)(tiled, tiled.find_loop("jo", many=True)) tiled = squash_buffers(tiled, tiled.find("packed_A : _", many=True)) tiled = squash_buffers(tiled, tiled.find("packed_A_alias : _", many=True)) - print(tiled) return simplify(tiled) -PARAMS = {AVX2: (2, 2, 66, 3, 512), AVX512: (6, 4, 44, 1, 512), Neon: (1, 1, 1, 1, 1)} +# TODO: Figure out proper parameters +PARAMS = {AVX2: (4, 3, 32, 512), AVX512: (6, 4, 44, 512), Neon: (1, 1, 1, 1)} -m_r, n_r_fac, M_tile_fac, N_tile_fac, K_tile = PARAMS[C.Machine.mem_type] +m_r, n_r_fac, N_tile_fac, K_tile = PARAMS[C.Machine.mem_type] n_r = n_r_fac * C.Machine.vec_width("f32") -M_tile = M_tile_fac * m_r -N_tile = N_tile_fac * n_r -variants_generator(schedule, ("f32",), (AVX2, AVX512))(syrk_rm_l, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals()) -variants_generator(identity_schedule, ("f32",), (AVX2, AVX512))( - syrk_rm_u, "i", m_r, n_r_fac, M_tile, N_tile, K_tile, globals=globals() -) +lcm = (m_r * n_r) // math.gcd(m_r, n_r) +N_tile = lcm + +variants_generator(schedule, ("f32",), (AVX2, AVX512))(syrk_rm_l, "i", m_r, n_r_fac, N_tile, K_tile, globals=globals()) +variants_generator(identity_schedule, ("f32",), (AVX2, AVX512))(syrk_rm_u, "i", m_r, n_r_fac, N_tile, K_tile, globals=globals()) diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index 99af8b1..a215e37 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -19,7 +19,7 @@ "exo_symv": "f96e0661b0221c69b43d9e476c3f5ba9a3ee41168149232b6a414145dc0d77f5", "exo_syr": "3e6894a8a9003ede58e06c57b261a1638fd6249b8a80ed356e147862f06f39aa", "exo_syr2": "9285dc796c9c573cbd974f419563a0c2e3d7507aba3521ee1c89e9e707913332", - "exo_syrk": "9894ba92a502df8968c0c4e1e09cb5510a9ad10d77f4e1595570a7d1a2167b4b", + "exo_syrk": "7a788ac95fe18b1198b7ff67f64b4436e0862051a733058912a1917cdc67aa95", "exo_tbmv": "e517f633eeaf1429c2204966a2970e7013054afd1f0bb22795075cfa5e4678db", "exo_tbsv": "faeb1392d2af7dc9cdac9fb707bd7a3273e82c92bcd940e869fb7bc5be14f020", "exo_trmm": "70f7aa84d76fe3be02cbbc5db13fdc9d55b9dd481ba02616f1457a15a653a074", From ab8c4026aa8641e26a15e7abac08d3a04ba2a4ee Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 26 Mar 2024 11:28:49 -0400 Subject: [PATCH 6/9] Bug fix --- src/common/stdlib.py | 4 +--- src/level3/syrk.py | 9 ++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/common/stdlib.py b/src/common/stdlib.py index 1501069..0ad65de 100644 --- a/src/common/stdlib.py +++ b/src/common/stdlib.py @@ -1340,9 +1340,7 @@ def add_loop_nest(loop): for src_dim, _ in reversed(list(enumerate(alloc.shape()))): for dst_dim, size in reversed(divisions[src_dim][1:]): proc = divide_dim(proc, alloc, src_dim, size) - proc = apply(divide_loop_)( - proc, loop_nest[src_dim], size, tail="cut" - ) # TODO: This should be enabled but it slows down compilation + proc = apply(divide_loop_)(proc, loop_nest[src_dim], size, tail="cut") perm.append(dst_dim) perm.append(divisions[src_dim][0][0]) diff --git a/src/level3/syrk.py b/src/level3/syrk.py index 10c86fe..7b9e893 100644 --- a/src/level3/syrk.py +++ b/src/level3/syrk.py @@ -110,24 +110,22 @@ def schedule_macro(mk, precision, machine, max_N, max_K, m_r, n_r_fac): vw = machine.vec_width(precision) n_r = vw * n_r_fac - for var, max_var in zip(("N", "K"), (max_N, max_N, max_K)): + for var, max_var in zip(("N", "K"), (max_N, max_K)): mk = mk.add_assertion(f"{var} <= {max_var}") mk_starter = mk mk = rename(mk, mk.name() + "_mk") i_loop = mk.body()[0] - packed_A_shape = ((0, max_N // m_r), (1, max_K), (0, m_r)) mk, cursors = pack_mem(mk, i_loop, "A", packed_A_shape, "packed_A", rc=1) mk = set_memory(mk, cursors.alloc, DRAM_STATIC) mk, _ = extract_subproc(mk, cursors.load, mk.name() + "_A_pack") # TODO: This packing step is doing more work the necessary (packing the whole matrix, not jus triangle) - packed_B_shape = ((1, max_N // n_r), (0, max_K), (1, n_r)) - mk, cursors = pack_mem(mk, i_loop, "A_alias", packed_B_shape, "packed_A_alias", rc=1) + packed_A_alias_shape = ((0, max_N // n_r), (1, max_K), (0, n_r)) + mk, cursors = pack_mem(mk, i_loop, "A_alias", packed_A_alias_shape, "packed_A_alias", rc=1) mk = set_memory(mk, cursors.alloc, DRAM_STATIC) mk, _ = extract_subproc(mk, cursors.load, mk.name() + "_A_alias_pack") - mk = extract_and_schedule(schedule_compute)(mk, i_loop, mk.name() + "_compute", precision, machine, m_r, n_r_fac) return mk_starter, simplify(mk) @@ -164,6 +162,7 @@ def rewrite(proc, loop): tiled = apply(hoist_from_loop)(tiled, tiled.find_loop("jo", many=True)) tiled = squash_buffers(tiled, tiled.find("packed_A : _", many=True)) tiled = squash_buffers(tiled, tiled.find("packed_A_alias : _", many=True)) + return simplify(tiled) From c949d672f65f24e99375ed5794eff5c749f143bd Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 26 Mar 2024 12:00:28 -0400 Subject: [PATCH 7/9] Update syrk codegen hash --- test/codegen/reference/sha256/avx2.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index a215e37..0edcc79 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -19,7 +19,7 @@ "exo_symv": "f96e0661b0221c69b43d9e476c3f5ba9a3ee41168149232b6a414145dc0d77f5", "exo_syr": "3e6894a8a9003ede58e06c57b261a1638fd6249b8a80ed356e147862f06f39aa", "exo_syr2": "9285dc796c9c573cbd974f419563a0c2e3d7507aba3521ee1c89e9e707913332", - "exo_syrk": "7a788ac95fe18b1198b7ff67f64b4436e0862051a733058912a1917cdc67aa95", + "exo_syrk": "13da1a63a67c91ece4a53d9948ca70422b846d5cea22ff4f3e31fa7b4e08e816", "exo_tbmv": "e517f633eeaf1429c2204966a2970e7013054afd1f0bb22795075cfa5e4678db", "exo_tbsv": "faeb1392d2af7dc9cdac9fb707bd7a3273e82c92bcd940e869fb7bc5be14f020", "exo_trmm": "70f7aa84d76fe3be02cbbc5db13fdc9d55b9dd481ba02616f1457a15a653a074", From 86671e4fa93c3a7b1e92ea829ad3639bce0db74e Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 26 Mar 2024 15:20:56 -0400 Subject: [PATCH 8/9] Update gemm codegen --- test/codegen/reference/sha256/avx2.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/codegen/reference/sha256/avx2.json b/test/codegen/reference/sha256/avx2.json index 0edcc79..703589d 100644 --- a/test/codegen/reference/sha256/avx2.json +++ b/test/codegen/reference/sha256/avx2.json @@ -5,7 +5,7 @@ "exo_dot": "8e7a71353e80273839bc522cfc8c889cbe25d8701a8c82a3c2559b12d9e90f5f", "exo_dsdot": "c901be9d30928e042c35aeb9dab421a34db6d593534f31b4967f9685ecf9628b", "exo_gbmv": "cb92744337cfdf3aa97d250a18e540c2e8787380ba88bfef790a3e14aeb19f37", - "exo_gemm": "8d28b97725589e64ca79c8b93ee45d2a3b3b09d3dc165e6fcd3bcb24affdf8f4", + "exo_gemm": "7f0fafc43017ecfe5a1a6c2fa4c96fa77c5f4f737aaf3204b37ed65c691e5ec6", "exo_gemv": "80d79a5752c20874fbe3d2c94989f5f5663687d405966065a18224800c931559", "exo_ger": "38ba46c410ea9bd1d616add516e9ab82ed811e5785848bce17191d89f81752ce", "exo_iamax": "49c60714c479234683166e5651fbe95ed5a43ecd370a391732769588948cc842", From e61fd74704321538a834775146401dc8c31a215f Mon Sep 17 00:00:00 2001 From: Samir Droubi Date: Tue, 26 Mar 2024 16:04:11 -0400 Subject: [PATCH 9/9] Ignore syrk from codegen testing * Due to non-deterministic unification --- test/codegen/hash.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/test/codegen/hash.py b/test/codegen/hash.py index 5ad711d..784f25d 100644 --- a/test/codegen/hash.py +++ b/test/codegen/hash.py @@ -17,7 +17,7 @@ # Be very catious adding tests here. Ideally, we don't want any. -NON_DETERMINISTIC_TESTS = {"exo_trmv", "exo_trsv", "exo_syr"} +NON_DETERMINISTIC_TESTS = {"exo_trmv", "exo_trsv", "exo_syr", "exo_syrk"} def get_diff(file1, file2): @@ -29,9 +29,7 @@ def get_diff(file1, file2): f2_text = f2.readlines() if f1_text != f2_text: - diff = difflib.unified_diff( - f1_text, f2_text, fromfile=str(file1), tofile=str(file2), lineterm="" - ) + diff = difflib.unified_diff(f1_text, f2_text, fromfile=str(file1), tofile=str(file2), lineterm="") diff = "\n".join(diff) return diff @@ -122,10 +120,7 @@ def check_sha256(target_arch, level, kernel): reference_source = get_reference_source_filename(target_arch, kernel) if not os.path.exists(reference_source): - exit( - err - + f"Reference source was not found at {reference_source} to show the diff.\n{update_instructions}." - ) + exit(err + f"Reference source was not found at {reference_source} to show the diff.\n{update_instructions}.") reference_source_hash = get_reference_source_hash(target_arch, kernel) if reference_source_hash != reference_hash: