From 41466e4809d60524e280ec1ef816aca4fb0308c4 Mon Sep 17 00:00:00 2001 From: Ward Vermeulen <37931310+wardvermeulen@users.noreply.github.com> Date: Fri, 3 Nov 2023 13:02:30 +0100 Subject: [PATCH] Add a check for the block shape in the K dimension (#167) --- src/config.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/config.jl b/src/config.jl index 436675b5..ccccc11a 100644 --- a/src/config.jl +++ b/src/config.jl @@ -157,9 +157,9 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw op_shape = Operator.shape(operator) - if block_shape.M < 2 * op_shape.M || block_shape.N < 2 * op_shape.N + if block_shape.M < 2 * op_shape.M || block_shape.N < 2 * op_shape.N || block_shape.K < op_shape.K # TODO: Find out why this is. - throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N.")) + throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N, BLOCK_K ≥ OPERATOR_K.")) end check_operator_config(operator)