Commit 30e6e26

Add script to tune parameters [skip benchmarks]
thomasfaingnaert committed Dec 2, 2023
1 parent 6a8e8cb commit 30e6e26
Showing 10 changed files with 605 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1 +1,2 @@
test/Manifest.toml
Manifest.toml
tuning/.CondaPkg/
29 changes: 16 additions & 13 deletions configs/configs.jl
@@ -3,6 +3,7 @@
using GemmKernels
using LinearAlgebra
using ForwardDiff
using Octavian

struct Configuration
name # Human-readable name of the configuration.
@@ -64,7 +65,8 @@ function generate_inputs(cf::Configuration)
new_b_h = cf.transpose_b ? transpose(b_h) : b_h

(cf.calc_reference)(c_h, new_a_h, new_b_h, cf.alpha, cf.beta)
c_h, a, b, c, d
c_ref = CuArray(c_h)
c_ref, a, b, c, d
end

# Run the GEMM.
@@ -88,21 +90,21 @@ function run_baseline(cf::Configuration, a, b, c, d)
end

# Verify results.
function verify(cf::Configuration, c_h, d)
cf.verify(c_h, d)
function verify(cf::Configuration, c_ref, d)
cf.verify(c_ref, d)
end

function verify_default(c_h, d)
isapprox(c_h, Array(d))
function verify_default(c_ref, d)
isapprox(c_ref, d)
end

function verify_bias(c_h, d, bias)
c_h .+ Array(bias) ≈ Array(d)
function verify_bias(c_ref, d, bias)
c_ref .+ bias ≈ d
end

function verify_dual(c_h, d)
c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h)
d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d))
function verify_dual(c_ref, d)
c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_ref)
d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, d)
isapprox(c_dual, d_dual)
end

@@ -238,10 +240,10 @@ macro get_wmma_config()
CD_type,
transpose_a,
transpose_b,
mul!,
Octavian.matmul!,
Epilogue.Default(),
verify_default,
Kernel.matmul_pipelined,
kernel,
wmma_baseline)
end end)
end
@@ -520,7 +522,8 @@ function get_configs()
[2, 2, 1],
[1, 1, 2],
[2, 2, 2]], [[2048, 2048, 2048]]),
zero_c in [false]
zero_c in [false],
kernel in [Kernel.matmul_pipelined]

push!(rv, @get_wmma_config)
end
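
For context on the configs.jl changes above: the reference result now stays on the GPU (`c_ref = CuArray(c_h)`), the reference GEMM is computed with `Octavian.matmul!` rather than `mul!`, and the verify functions compare two device arrays directly. Below is a minimal sketch of that pattern outside the test harness, assuming a CUDA-capable device; the matrix sizes and the CUBLAS product standing in for the kernel under test are illustrative, not part of this commit.

```julia
using CUDA, Octavian

# Host inputs and a device-side result to check (CUBLAS stands in for GemmKernels here).
a_h, b_h = rand(Float32, 128, 64), rand(Float32, 64, 96)
d = CuArray(a_h) * CuArray(b_h)

# Reference GEMM on the CPU via Octavian (alpha = 1, beta = 0), then uploaded
# so the comparison runs device-to-device, mirroring verify_default(c_ref, d).
c_h = zeros(Float32, 128, 96)
Octavian.matmul!(c_h, a_h, b_h, 1f0, 0f0)
c_ref = CuArray(c_h)

@assert isapprox(c_ref, d)
```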
6 changes: 6 additions & 0 deletions src/config.jl
@@ -215,6 +215,12 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
prod(mem_b_warp) * warps_per_block ≤ block_shape.K * block_shape.N || throw(ConfigError("mem_b_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))
prod(mem_cd_warp) * warps_per_block ≤ block_shape.M * block_shape.N || throw(ConfigError("mem_cd_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))

# Check sizes of tiles
check_tile_smaller(lhs, rhs, msg) = ((lhs.M ≤ rhs.M) && (lhs.N ≤ rhs.N) && (lhs.K ≤ rhs.K)) || throw(ConfigError(msg))

check_tile_smaller(compute_warp, block_shape, "compute_warp must be smaller than block_shape!")
check_tile_smaller(block_shape, gemm_shape, "block_shape must be smaller than gemm_shape!")

return Config(
#= Params =#
gemm_shape,
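
The tile checks added above reduce to an element-wise `≤` over the `(M, N, K)` fields of two tiles. Here is a self-contained sketch of the same predicate using plain named tuples instead of GemmKernels' own tile and config types; the shape values and the local `ConfigError` definition are illustrative only.

```julia
# Illustrative stand-in for GemmKernels' ConfigError.
struct ConfigError <: Exception
    msg::String
end

check_tile_smaller(lhs, rhs, msg) =
    ((lhs.M <= rhs.M) && (lhs.N <= rhs.N) && (lhs.K <= rhs.K)) || throw(ConfigError(msg))

compute_warp = (M = 32,   N = 32,   K = 16)
block_shape  = (M = 128,  N = 128,  K = 64)
gemm_shape   = (M = 4096, N = 4096, K = 4096)

check_tile_smaller(compute_warp, block_shape, "compute_warp must be smaller than block_shape!")
check_tile_smaller(block_shape, gemm_shape, "block_shape must be smaller than gemm_shape!")

# A block shape larger than the problem size is now rejected up front:
# check_tile_smaller((M = 8192, N = 128, K = 64), gemm_shape, "block_shape must be smaller than gemm_shape!")
```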
24 changes: 24 additions & 0 deletions src/matmul.jl
@@ -34,6 +34,30 @@ function matmul(conf::Config, a, b, c, d;
conf.block_shape.K ≥ 2 * conf.compute_op_shape.K || throw(ConfigError("Need at least two stages to use a pipelined kernel, i.e. BLOCK_K ≥ 2 * OPERATOR_K"))
end

# Check LocalArray size limit of 32 elements.
if kernel == Kernel.matmul_singlestage
num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N

num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
end

if kernel == Kernel.matmul_pipelined
num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N

a_frag_i = (conf.block_shape.M * conf.block_shape.K) ÷ (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block)
a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷ (conf.mem_a_thread.M * conf.mem_a_thread.K * 32)
b_frag_i = (conf.block_shape.K * conf.block_shape.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block)
b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32)

num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
a_frag_i * a_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
b_frag_i * b_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
2 * num_fragments_m < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
2 * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
end

hostkernel = @cuda launch=false kernel(args...)
attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem

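The new guards in matmul.jl above bound the number of per-thread fragments so that no staging array exceeds the 32-element LocalArray limit named in the error messages. Below is a worked example of that arithmetic for the pipelined kernel; only the formulas mirror the hunk, while the configuration values are made up for illustration.

```julia
# Illustrative configuration; the formulas below are the ones guarded in matmul.jl.
conf = (
    block_shape      = (M = 128, N = 128, K = 32),
    compute_warp     = (M = 64,  N = 32),
    compute_op_shape = (M = 16,  N = 16,  K = 16),
    mem_a_warp       = (M = 128, K = 2),
    mem_a_thread     = (M = 4,   K = 1),
    mem_b_warp       = (K = 32,  N = 8),
    mem_b_thread     = (K = 4,   N = 1),
    warps_per_block  = 8,
)

num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M        # 4
num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N        # 2

a_frag_i = (conf.block_shape.M * conf.block_shape.K) ÷
           (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block)  # 2
a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷
           (conf.mem_a_thread.M * conf.mem_a_thread.K * 32)                # 2
b_frag_i = (conf.block_shape.K * conf.block_shape.N) ÷
           (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block)  # 2
b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷
           (conf.mem_b_thread.K * conf.mem_b_thread.N * 32)                # 2

# Every per-thread staging array must stay below 32 elements:
@assert num_fragments_m * num_fragments_n < 32   # 8
@assert a_frag_i * a_frag_j < 32                 # 4
@assert b_frag_i * b_frag_j < 32                 # 4
@assert 2 * num_fragments_m < 32                 # 8
@assert 2 * num_fragments_n < 32                 # 4
```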
1 change: 1 addition & 0 deletions test/Project.toml
@@ -5,5 +5,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XUnit = "3e3c03f2-1a94-11e9-2981-050a4ca824ab"
14 changes: 14 additions & 0 deletions tuning/Project.toml
@@ -0,0 +1,14 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Binary file added tuning/best-configs.bin
Binary file not shown.
Binary file added tuning/configs.bin
Binary file not shown.