Skip to content

Commit

Permalink
Add a check for the block shape in the K dimension (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
wardvermeulen authored Nov 3, 2023
1 parent 221d911 commit 41466e4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw

op_shape = Operator.shape(operator)

if block_shape.M < 2 * op_shape.M || block_shape.N < 2 * op_shape.N
if block_shape.M < 2 * op_shape.M || block_shape.N < 2 * op_shape.N || block_shape.K < op_shape.K
# TODO: Find out why this is.
throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N."))
throw(ConfigError("There is a mismatch between the block shape and the operator shape. Their dimensions must adhere to the following constraints: BLOCK_M ≥ 2 * OPERATOR_M, BLOCK_N ≥ 2 * OPERATOR_N, BLOCK_K ≥ OPERATOR_K."))
end

check_operator_config(operator)
Expand Down

0 comments on commit 41466e4

Please sign in to comment.