From 17aa1d586b2fdf51765b6ce59b19942c315e1573 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 15 Nov 2023 17:54:11 +0100 Subject: [PATCH 01/21] Add script to tune parameters [skip benchmarks] --- .gitignore | 3 +- configs/configs.jl | 29 +-- src/config.jl | 6 + src/matmul.jl | 24 ++ test/Project.toml | 1 + tuning/Project.toml | 14 ++ tuning/tune-wmma.jl | 522 ++++++++++++++++++++++++++++++++++++++++++++ tuning/tune-wmma.sh | 24 ++ 8 files changed, 609 insertions(+), 14 deletions(-) create mode 100644 tuning/Project.toml create mode 100644 tuning/tune-wmma.jl create mode 100755 tuning/tune-wmma.sh diff --git a/.gitignore b/.gitignore index c181d1f8..d9a944ad 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -test/Manifest.toml +Manifest.toml +tuning/.CondaPkg/ diff --git a/configs/configs.jl b/configs/configs.jl index c88b2d8d..d0f7b672 100644 --- a/configs/configs.jl +++ b/configs/configs.jl @@ -3,6 +3,7 @@ using GemmKernels using LinearAlgebra using ForwardDiff +using Octavian struct Configuration name # Human-readable name of the configuration. @@ -64,7 +65,8 @@ function generate_inputs(cf::Configuration) new_b_h = cf.transpose_b ? transpose(b_h) : b_h (cf.calc_reference)(c_h, new_a_h, new_b_h, cf.alpha, cf.beta) - c_h, a, b, c, d + c_ref = CuArray(c_h) + c_ref, a, b, c, d end # Run the GEMM. @@ -88,21 +90,21 @@ function run_baseline(cf::Configuration, a, b, c, d) end # Verify results. -function verify(cf::Configuration, c_h, d) - cf.verify(c_h, d) +function verify(cf::Configuration, c_ref, d) + cf.verify(c_ref, d) end -function verify_default(c_h, d) - isapprox(c_h, Array(d)) +function verify_default(c_ref, d) + isapprox(c_ref, d) end -function verify_bias(c_h, d, bias) - c_h .+ Array(bias) ≈ Array(d) +function verify_bias(c_ref, d, bias) + c_ref .+ bias ≈ d end -function verify_dual(c_h, d) - c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h) - d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d)) +function verify_dual(c_ref, d) + c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_ref) + d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, d) isapprox(c_dual, d_dual) end @@ -238,10 +240,10 @@ macro get_wmma_config() CD_type, transpose_a, transpose_b, - mul!, + Octavian.matmul!, Epilogue.Default(), verify_default, - Kernel.matmul_pipelined, + kernel, wmma_baseline) end end) end @@ -520,7 +522,8 @@ function get_configs() [2, 2, 1], [1, 1, 2], [2, 2, 2]], [[2048, 2048, 2048]]), - zero_c in [false] + zero_c in [false], + kernel in [Kernel.matmul_pipelined] push!(rv, @get_wmma_config) end diff --git a/src/config.jl b/src/config.jl index 29c26f0c..846f4c62 100644 --- a/src/config.jl +++ b/src/config.jl @@ -215,6 +215,12 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw prod(mem_b_warp) * warps_per_block ≤ block_shape.K * block_shape.N || throw(ConfigError("mem_b_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!")) prod(mem_cd_warp) * warps_per_block ≤ block_shape.M * block_shape.N || throw(ConfigError("mem_cd_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!")) + # Check sizes of tiles + check_tile_smaller(lhs, rhs, msg) = ((lhs.M ≤ rhs.M) && (lhs.N ≤ rhs.N) && (lhs.K ≤ rhs.K)) || throw(ConfigError(msg)) + + check_tile_smaller(compute_warp, block_shape, "compute_warp must be smaller than block_shape!") + check_tile_smaller(block_shape, gemm_shape, "block_shape must be smaller than gemm_shape!") + return 
Config( #= Params =# gemm_shape, diff --git a/src/matmul.jl b/src/matmul.jl index 3c4bde0d..f3dfd34f 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -34,6 +34,30 @@ function matmul(conf::Config, a, b, c, d; conf.block_shape.K ≥ 2 * conf.compute_op_shape.K || throw(ConfigError("Need at least two stages to use a pipelined kernel, i.e. BLOCK_K ≥ 2 * OPERATOR_K")) end + # Check LocalArray size limit of 32 elements. + if kernel == Kernel.matmul_singlestage + num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M + num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N + + num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) + end + + if kernel == Kernel.matmul_pipelined + num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M + num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N + + a_frag_i = (conf.block_shape.M * conf.block_shape.K) ÷ (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block) + a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷ (conf.mem_a_thread.M * conf.mem_a_thread.K * 32) + b_frag_i = (conf.block_shape.K * conf.block_shape.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block) + b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32) + + num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) + a_frag_i * a_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) + b_frag_i * b_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) + 2 * num_fragments_m < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) + 2 * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) + end + hostkernel = @cuda launch=false kernel(args...) 
attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem diff --git a/test/Project.toml b/test/Project.toml index 8828b9af..8d1be0aa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,5 +5,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" XUnit = "3e3c03f2-1a94-11e9-2981-050a4ca824ab" diff --git a/tuning/Project.toml b/tuning/Project.toml new file mode 100644 index 00000000..8fab9e13 --- /dev/null +++ b/tuning/Project.toml @@ -0,0 +1,14 @@ +[deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" +Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl new file mode 100644 index 00000000..c9420e87 --- /dev/null +++ b/tuning/tune-wmma.jl @@ -0,0 +1,522 @@ +using CUDA, GemmKernels +using DataFrames +using DataStructures +using Dates +using Logging +using LoggingExtras +using Plots +using ProgressMeter +using Serialization +using Statistics +using StatsBase + +pythonplot() + +####### + +const N_vals = 2 .^ (7:14) + +# Stop sampling when normalised 95p CI is smaller than this... +const BENCH_NORM_CI_THRESHOLD = 0.01 + +# ... or we have exceeded the time limit... +const BENCH_MAX_NUM_SECONDS = 5 + +# ... but have at least 10 samples. +const BENCH_MIN_NUM_SAMPLES = 10 + +##### + +# Stop gathering samples for plot if the "error bars" are smaller than this... +const PLOT_RATIO_MAX_UNCERTAINTY = 0.05 + +# ... or we have exceeded the time limit... +# In my experience, only N <= 2^9 requires more than a handful of samples. +# That's only 3*4 configurations, so a limit of say 600 seconds will take ~2 hours. +const PLOT_MAX_NUM_SECONDS = 600 + +# ... but have at least 10 samples. +const PLOT_MIN_NUM_SAMPLES = 10 + +const AB_type = Float16 +const CD_type = Float32 + +const zero_c = true + +const OP_M, OP_N, OP_K = 16, 16, 16 + +####### + +# Reuse inputs across iterations. +c_ref = nothing +a = nothing +b = nothing +c = nothing +d = nothing +input_transpose_a = nothing +input_transpose_b = nothing +input_N = nothing + +include("../configs/configs.jl") + +# Write logging messages to file for persistence. +timestamp_logger(logger) = TransformerLogger(logger) do log + merge(log, (; message = "$(Dates.format(now(), "yyyy-mm-dd HH:MM:SS")) $(log.message)")) +end +FileLogger("tuning/tuning.log"; append=true) |> timestamp_logger |> (x -> MinLevelLogger(x, Logging.Info)) |> global_logger + +function kernel_string_to_function(str) + Dict( + "singlestage" => Kernel.matmul_singlestage, + "pipelined" => Kernel.matmul_pipelined + )[str] +end + +get_label(transpose_a, transpose_b) = "$(transpose_a ? "T" : "N")$(transpose_b ? 
"T" : "N")" + +function generate_configs() + all_configs = DataFrame( + transpose_a=Bool[], + transpose_b=Bool[], + N=Int[], + BLOCK_M=Int[], + BLOCK_N=Int[], + BLOCK_K=Int[], + WARPS_M=Int[], + WARPS_N=Int[], + kernel_str=String[], + category=String[], + times=Vector{Any}[] + ) + + for transpose_a in [false, true], + transpose_b in [false, true], + N in N_vals, + BLOCK_M in 2 .^ (6:9), + BLOCK_N in 2 .^ (6:9), + BLOCK_K in 2 .^ (5:7), + WARPS_M in 2 .^ (0:3), + WARPS_N in 2 .^ (0:3), + kernel_str in ["singlestage", "pipelined"] + + push!(all_configs, Dict( + :transpose_a => transpose_a, + :transpose_b => transpose_b, + :N => N, + :BLOCK_M => BLOCK_M, + :BLOCK_N => BLOCK_N, + :BLOCK_K => BLOCK_K, + :WARPS_M => WARPS_M, + :WARPS_N => WARPS_N, + :kernel_str => kernel_str, + :category => "unknown", + :times => [], + )) + end + + all_configs +end + +function get_config(row) + transpose_a = row["transpose_a"] + transpose_b = row["transpose_b"] + M = N = K = row["N"] + BLOCK_M = row["BLOCK_M"] + BLOCK_N = row["BLOCK_N"] + BLOCK_K = row["BLOCK_K"] + WARPS_M = row["WARPS_M"] + WARPS_N = row["WARPS_N"] + kernel = kernel_string_to_function(row["kernel_str"]) + + @get_wmma_config +end + +function generate_inputs_if_needed(row) + global input_transpose_a, input_transpose_b, input_N, c_ref, a, b, c, d + + cf = get_config(row) + + if (input_transpose_a, input_transpose_b, input_N) != (row.transpose_a, row.transpose_b, row.N) + c_ref, a, b, c, d = generate_inputs(cf) + input_transpose_a, input_transpose_b, input_N = row.transpose_a, row.transpose_b, row.N + end +end + +function get_inputs_for_plot(input_dict, row) + if row.N ∉ keys(input_dict) + cf = get_config(row) + _, a, b, c, d = generate_inputs(cf) + input_dict[row.N] = (a, b, c, d) + end + + return input_dict[row.N] +end + +function measure_config(row) + cf = get_config(row) + + generate_inputs_if_needed(row) + + d .= 0 + + try + run_gemm(cf, a, b, c, d) + catch err + if isa(err, GemmKernels.ConfigError) + @info "Skipping configuration $(NamedTuple(row))" * sprint(Base.showerror, err) + return [Inf], "unsupported_config_post_run" + end + + if isa(err, CuError) + @error "Configuration failed: $(NamedTuple(row))" * sprint(Base.showerror, err) + rethrow() + end + + @info "Skipping configuration: $(NamedTuple(row))" * sprint(Base.showerror, err) + return [Inf], "error" + end + + if !verify(cf, c_ref, d) + @warn "Configuration produced invalid result: $(NamedTuple(row))" + + expected = c_ref + actual = d + + mad, index = findmax(abs.(expected - actual)) + + @warn "Maximum absolute deviation is $(mad) at index $(index)." + + return [Inf], "invalid_result" + end + + times = Float64[] + + # Use CUDA.@elapsed instead of CUDA.@profile, because the latter is slower. 
+ device_synchronize() + GC.gc(true) + + start_time = Dates.now() + + while true + synchronize(stream()) + time = CUDA.@elapsed run_gemm(cf, a, b, c, d) + push!(times, time) + + if length(times) >= BENCH_MIN_NUM_SAMPLES + (Dates.now() - start_time > Second(BENCH_MAX_NUM_SECONDS)) && break + (confidence_interval_95(times) / median(times) < BENCH_NORM_CI_THRESHOLD) && break + end + end + + return times, "success" +end + +confidence_interval_95(times) = 1.58 * iqr(times) / sqrt(length(times)) + +function prettytime(times) + min, q1, med, q3, max = nquantile(times, 4) + ci_95 = confidence_interval_95(times) + + # timescale + scale, unit = if med < 1e3 + 1, "ns" + elseif med < 1e6 + 1e3, "μs" + elseif med < 1e9 + 1e6, "ms" + else + 1e9, "s" + end + + rnd_min, rnd_q1, rnd_med, rnd_q3, rnd_max, rnd_ci_95 = round.([min, q1, med, q3, max, ci_95] ./ scale; sigdigits=3) + rnd_rel_ci_95 = round(100 * ci_95 / med; sigdigits=3) + + return "$rnd_med $unit ± $rnd_ci_95 $unit ($rnd_rel_ci_95%) (length: $(length(times)), 5-num summary: $rnd_min, $rnd_q1, $rnd_med, $rnd_q3, $rnd_max $unit)" +end + +perf_ratio(gemmkernels, baseline) = percentile(baseline, 0) / percentile(gemmkernels, 0) +perf_ratio_lo(gemmkernels, baseline) = percentile(baseline, 0) / percentile(gemmkernels, 75) +perf_ratio_hi(gemmkernels, baseline) = percentile(baseline, 75) / percentile(gemmkernels, 0) + +function get_uncertainty(gk, bl) + lo, mid, hi = (perf_ratio_lo(gk, bl), perf_ratio(gk, bl), perf_ratio_hi(gk, bl)) + + hi_uncertainty = abs(hi - mid) / mid + lo_uncertainty = abs(lo - mid) / mid + uncertainty = max(hi_uncertainty, lo_uncertainty) + + uncertainty, lo_uncertainty, hi_uncertainty +end + +function got_enough_samples(row) + gk, bl = row["gemmkernels_times"], row["baseline_times"] + + (length(gk) < PLOT_MIN_NUM_SAMPLES) && return false + (length(bl) < PLOT_MIN_NUM_SAMPLES) && return false + + (row["time_spent"] >= PLOT_MAX_NUM_SECONDS) && return true + + uncertainty, _, _ = get_uncertainty(gk, bl) + + uncertainty < PLOT_RATIO_MAX_UNCERTAINTY +end + +function benchmark_best_configs(configs) + best_configs = DataFrame( + transpose_a=Bool[], + transpose_b=Bool[], + N=Int[], + BLOCK_M=Int[], + BLOCK_N=Int[], + BLOCK_K=Int[], + WARPS_M=Int[], + WARPS_N=Int[], + kernel_str=String[], + category=String[], + uncertainty=Float64[], + time_spent=Float64[], + gemmkernels_times=Vector{Any}[], + baseline_times=Vector{Any}[] + ) + + for transpose_a = [false, true], + transpose_b = [false, true], + N = N_vals + + relevant_configs = configs[(@. (configs[!, "transpose_a"] == transpose_a) & (configs[!, "transpose_b"] == transpose_b) & (configs[!, "N"] == N)), :] + _, best_config_index = findmin(minimum.(relevant_configs[!, "times"], init=Inf)) + best_config = relevant_configs[best_config_index, :] + + push!(best_configs, Dict( + :transpose_a => transpose_a, + :transpose_b => transpose_b, + :N => N, + :BLOCK_M => best_config["BLOCK_M"], + :BLOCK_N => best_config["BLOCK_N"], + :BLOCK_K => best_config["BLOCK_K"], + :WARPS_M => best_config["WARPS_M"], + :WARPS_N => best_config["WARPS_N"], + :kernel_str => best_config["kernel_str"], + :category => "todo", + :uncertainty => Inf, + :time_spent => 0.0, + :gemmkernels_times => [], + :baseline_times => [], + )) + end + + # We will reuse matrix inputs across iterations. This takes about 4 GB of GPU memory for e.g. all matrix sizes for NN. + # Group runs of the same transposition together, so we don't have to keep 4 * 4 GB of inputs in memory. 
+ for transpose_a in [false, true], + transpose_b in [false, true] + + input_dict = Dict() + + p = ProgressUnknown(desc="Benchmarking (highest uncertainty)", dt=1.0) + + # Spread the samples of one configuration over time, to reduce the effect + # of time-related noise. Note that this means that the progress bar may + # make big jumps. + while true + (sum(@. (best_configs[!, "category"] == "todo") & (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)) == 0) && break + + for config_row in eachrow(best_configs) + if (config_row.category, config_row.transpose_a, config_row.transpose_b) != ("todo", transpose_a, transpose_b) + continue + end + + a, b, c, d = get_inputs_for_plot(input_dict, config_row) + cf = get_config(config_row) + + @info "Profiling configuration $(NamedTuple(config_row))..." + + start_time = Dates.now() + + prof = CUDA.@profile run_gemm(cf, a, b, c, d) + push!(config_row["gemmkernels_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) + + prof = CUDA.@profile run_baseline(cf, a, b, c, d) + push!(config_row["baseline_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) + + config_row["time_spent"] += (Dates.now() - start_time) / Second(1) + old_uncertainty = config_row["uncertainty"] + config_row["uncertainty"], _, _ = get_uncertainty(config_row["gemmkernels_times"], config_row["baseline_times"]) + + if got_enough_samples(config_row) + config_row["category"] = "done" + end + + # Update progress bar. + highest_uncertainty = best_configs[(@. (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :] + highest_uncertainty = maximum(highest_uncertainty[!, "uncertainty"]) + next!(p; showvalues = [ + (:transpose_a, transpose_a), + (:transpose_b, transpose_b), + (:N, config_row["N"]), + (:num_samples, length(config_row["gemmkernels_times"])), + (:uncertainty, "$(config_row["uncertainty"]) (Δ = $(config_row["uncertainty"] - old_uncertainty))"), + (:time_spent_in_config, config_row["time_spent"]), + (:highest_uncertainty, highest_uncertainty), + (:remaining_N, best_configs[(@. (best_configs[!, "category"] == "todo") & (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :].N), + (:remaining_configurations, sum(best_configs[!, "category"] .== "todo")) + ]) + end + end + end + + best_configs +end + +function plot_results(best_configs) + markershapes = Dict( + "NN" => :circle, + "NT" => :dtriangle, + "TN" => :diamond, + "TT" => :cross + ) + + p = plot() + title!("$AB_type x $AB_type = $CD_type ($(CUDA.name(CUDA.device())))") + xlabel!("Matrix size [-]") + ylabel!("Performance relative to cuBLAS [%]") + + for transpose_a in [false, true], + transpose_b in [false, true] + + label = get_label(transpose_a, transpose_b) + + relevant_configs = best_configs[(@. (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :] + + ratios = @. 100 * perf_ratio(relevant_configs.gemmkernels_times, relevant_configs.baseline_times) + ratios_lo = @. 100 * perf_ratio_lo(relevant_configs.gemmkernels_times, relevant_configs.baseline_times) + ratios_hi = @. 100 * perf_ratio_hi(relevant_configs.gemmkernels_times, relevant_configs.baseline_times) + + plot!(p, relevant_configs.N, ratios, ribbon=(ratios .- ratios_lo, ratios_hi .- ratios), label=label, markershape=markershapes[label], xscale=:log2) + end + + savefig(p, "tuning/plot.pdf") +end + +function main() + @info "Starting WMMA tuning script..." 
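+    # The script runs in five phases (see the numbered comments below): generate
+    # the candidate configurations, filter out the ones that are a-priori
+    # unsupported, sweep over the remaining configurations to measure their
+    # performance, benchmark the best configuration for each transposition and
+    # matrix size against the baseline, and finally plot the results.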
+ + configs = nothing + + if !isfile("tuning/configs.bin") + # (1) Generate configurations. + @info "Generating configurations..." + configs = generate_configs() + @info "Generated $(size(configs, 1)) configurations." + + # (2) Filter configurations where we can determine upfront that they are unsupported. + @info "Filtering configurations that we know are unsupported a-priori..." + + for config_row in eachrow(configs) + try + cf = get_config(config_row) + catch err + if isa(err, GemmKernels.ConfigError) + config_row["category"] = "unsupported_config_pre_run" + else + rethrow() + end + end + end + + @info "Filtered $(counter(configs[!, "category"])["unsupported_config_pre_run"]) configurations." + + open("tuning/configs.bin", "w") do io + serialize(io, configs) + end + end + + @info "Loading configurations from disk..." + configs = open("tuning/configs.bin", "r") do io + deserialize(io) + end + @info "Loaded $(size(configs, 1)) configurations." + + # (3) Measure performance of configurations. + num_unknown = counter(configs[!, "category"])["unknown"] + p = Progress(num_unknown; desc="Parameter sweep", dt=1.0, showspeed=true) + + @info "Need to perform parameter sweep over $(num_unknown) configurations." + + # Generate inputs for the first configuration. This is not strictly + # speaking necessary, but doing this outside of the loop means that the + # first iteration will not be excessively slow, which improves the "ETA" + # estimate. + first_unknown_config = findfirst(configs[!, "category"] .== "unknown") + !isnothing(first_unknown_config) && generate_inputs_if_needed(configs[first_unknown_config, :]) + + for config_row in eachrow(configs) + start_time = Dates.now() + + if config_row.category != "unknown" + continue + end + + config_row.category = "crashed" + + # Save results in case the process crashes. + open("tuning/configs.bin", "w") do io + serialize(io, configs) + end + + @info "Measuring configuration $(NamedTuple(config_row))..." + + times, category = measure_config(config_row) + + @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))" + + config_row.category = category + config_row.times = times + + counter_dict_abs = Dict(counter(configs[!, "category"])) + counter_dict_rel = Dict(k => "$(round(100 * v / sum(values(counter_dict_abs)); sigdigits=3))%" for (k, v) in counter_dict_abs) + + next!(p; showvalues=[ + (:N, config_row.N), + (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)), + (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)), + (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)), + (:kernel, config_row.kernel_str), + (:counters, counter_dict_abs), + (:counters_relative, counter_dict_rel), + (:last_result, "$(category) -- $(prettytime(times .* 1e9))"), + (:last_iteration_time, Dates.now() - start_time) + ]) + end + + # Save data for final iteration. + open("tuning/configs.bin", "w") do io + serialize(io, configs) + end + + # And load again, for good measure. + configs = open("tuning/configs.bin", "r") do io + deserialize(io) + end + + # (4) Select best configurations, and benchmark. + if !isfile("tuning/best-configs.bin") + @info "Benchmarking configurations for plot..." + best_configs = benchmark_best_configs(configs) + + open("tuning/best-configs.bin", "w") do io + serialize(io, best_configs) + end + end + + @info "Loading best configurations from disk..." 
+ best_configs = open("tuning/best-configs.bin", "r") do io + deserialize(io) + end + + # (5) Plotting results + @info "Plotting results..." + plot_results(best_configs) +end + + +isinteractive() || main() diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh new file mode 100755 index 00000000..a97b779c --- /dev/null +++ b/tuning/tune-wmma.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +set -Eeuo pipefail + +cd "$( dirname "${BASH_SOURCE[0]}" )" + +rm -f configs.bson +rm -f tuning.log + +cd .. + +until julia --project -e ' + println("--- :julia: Instantiating project") + using Pkg + Pkg.instantiate() + Pkg.activate("tuning") + Pkg.instantiate() + push!(LOAD_PATH, @__DIR__) + + println("+++ :julia: Tuning") + include("tuning/tune-wmma.jl")'; do + + echo "Tuning script crashed. Resuming after 1 second..." >&2 + sleep 1 +done From 51aba69c996b3f3c3b89665748673e903ad1fb32 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Mon, 4 Dec 2023 10:12:27 +0100 Subject: [PATCH 02/21] Small tweaks --- tuning/tune-wmma.jl | 2 +- tuning/tune-wmma.sh | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index c9420e87..8db53efc 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -376,7 +376,7 @@ function plot_results(best_configs) ) p = plot() - title!("$AB_type x $AB_type = $CD_type ($(CUDA.name(CUDA.device())))") + title!("$AB_type x $AB_type = $CD_type ($(name(device())))") xlabel!("Matrix size [-]") ylabel!("Performance relative to cuBLAS [%]") diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index a97b779c..414f8544 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -3,9 +3,6 @@ set -Eeuo pipefail cd "$( dirname "${BASH_SOURCE[0]}" )" -rm -f configs.bson -rm -f tuning.log - cd .. until julia --project -e ' From 7835c1c85cb8bbb90f3478a5271454d58a7ccb56 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Mon, 4 Dec 2023 10:13:15 +0100 Subject: [PATCH 03/21] Apply suggestions from code review Co-authored-by: Tim Besard --- tuning/tune-wmma.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index 8db53efc..b5a17686 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -394,11 +394,11 @@ function plot_results(best_configs) plot!(p, relevant_configs.N, ratios, ribbon=(ratios .- ratios_lo, ratios_hi .- ratios), label=label, markershape=markershapes[label], xscale=:log2) end - savefig(p, "tuning/plot.pdf") + savefig(p, "tuning/$(name(device())).pdf") end function main() - @info "Starting WMMA tuning script..." + @info "Starting WMMA tuning script for device $(name(device()))..." configs = nothing From f92717a3c4159664be2673f5149502db4727a351 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 4 Dec 2023 13:53:12 +0100 Subject: [PATCH 04/21] Don't re-instantiate. --- tuning/tune-wmma.sh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index 414f8544..5c6f76c6 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -5,15 +5,13 @@ cd "$( dirname "${BASH_SOURCE[0]}" )" cd .. 
-until julia --project -e ' - println("--- :julia: Instantiating project") - using Pkg - Pkg.instantiate() - Pkg.activate("tuning") - Pkg.instantiate() - push!(LOAD_PATH, @__DIR__) +echo "+++ :julia: Instantiating project" +julia --project -e 'using Pkg; Pkg.instantiate()' +julia --project=tuning -e 'using Pkg; Pkg.instantiate()' - println("+++ :julia: Tuning") +echo "+++ :julia: Tuning" +until julia --project=tuning -e ' + push!(LOAD_PATH, @__DIR__) include("tuning/tune-wmma.jl")'; do echo "Tuning script crashed. Resuming after 1 second..." >&2 From 6882b8b52d607fbd158eb55c6ba5384a91d5e650 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 4 Dec 2023 13:53:21 +0100 Subject: [PATCH 05/21] Simplify log output path. --- tuning/tune-wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index b5a17686..1d9fc20f 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -63,7 +63,7 @@ include("../configs/configs.jl") timestamp_logger(logger) = TransformerLogger(logger) do log merge(log, (; message = "$(Dates.format(now(), "yyyy-mm-dd HH:MM:SS")) $(log.message)")) end -FileLogger("tuning/tuning.log"; append=true) |> timestamp_logger |> (x -> MinLevelLogger(x, Logging.Info)) |> global_logger +FileLogger(joinpath(@__DIR__, "tuning.log"); append=true) |> timestamp_logger |> (x -> MinLevelLogger(x, Logging.Info)) |> global_logger function kernel_string_to_function(str) Dict( From 8d7843d3faa2d2f270801e19bfde7ee5780ac765 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 4 Dec 2023 14:33:37 +0100 Subject: [PATCH 06/21] Distribute tuning across multiple processes. --- tuning/tune-wmma.jl | 116 ++++++++++++++++++++++++++++---------------- tuning/tune-wmma.sh | 9 ++-- 2 files changed, 79 insertions(+), 46 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index 1d9fc20f..cf0a5739 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -2,15 +2,19 @@ using CUDA, GemmKernels using DataFrames using DataStructures using Dates +using Distributed +using FileWatching.Pidfile using Logging using LoggingExtras -using Plots using ProgressMeter using Serialization using Statistics using StatsBase -pythonplot() +if myid() == 1 + using Plots + pythonplot() +end ####### @@ -63,7 +67,14 @@ include("../configs/configs.jl") timestamp_logger(logger) = TransformerLogger(logger) do log merge(log, (; message = "$(Dates.format(now(), "yyyy-mm-dd HH:MM:SS")) $(log.message)")) end -FileLogger(joinpath(@__DIR__, "tuning.log"); append=true) |> timestamp_logger |> (x -> MinLevelLogger(x, Logging.Info)) |> global_logger +function log_filename() + path = joinpath(@__DIR__, "tuning.log") + if myid() != 1 + path = "$(path).$(myid())" + end + path +end +FileLogger(log_filename(); append=true) |> timestamp_logger |> (x -> MinLevelLogger(x, Logging.Info)) |> global_logger function kernel_string_to_function(str) Dict( @@ -195,16 +206,18 @@ function measure_config(row) device_synchronize() GC.gc(true) - start_time = Dates.now() + mkpidlock(joinpath(@__DIR__, "tuning.pid")) do + start_time = Dates.now() - while true - synchronize(stream()) - time = CUDA.@elapsed run_gemm(cf, a, b, c, d) - push!(times, time) + while true + synchronize(stream()) + time = CUDA.@elapsed run_gemm(cf, a, b, c, d) + push!(times, time) - if length(times) >= BENCH_MIN_NUM_SAMPLES - (Dates.now() - start_time > Second(BENCH_MAX_NUM_SECONDS)) && break - (confidence_interval_95(times) / median(times) < BENCH_NORM_CI_THRESHOLD) && break + if length(times) >= BENCH_MIN_NUM_SAMPLES 
+ (Dates.now() - start_time > Second(BENCH_MAX_NUM_SECONDS)) && break + (confidence_interval_95(times) / median(times) < BENCH_NORM_CI_THRESHOLD) && break + end end end @@ -398,7 +411,7 @@ function plot_results(best_configs) end function main() - @info "Starting WMMA tuning script for device $(name(device()))..." + @info "Starting WMMA tuning script for device $(name(device())) using $(nworkers()) workers..." configs = nothing @@ -449,43 +462,61 @@ function main() first_unknown_config = findfirst(configs[!, "category"] .== "unknown") !isnothing(first_unknown_config) && generate_inputs_if_needed(configs[first_unknown_config, :]) - for config_row in eachrow(configs) - start_time = Dates.now() - - if config_row.category != "unknown" - continue + channel = RemoteChannel(() -> Channel(), 1) + @sync begin + @async begin + while true + values = take!(channel) + values === nothing && break + next!(p; showvalues=values) + end end - config_row.category = "crashed" + @async begin + @info "Starting parameter sweep..." + pmap(eachrow(configs)) do config_row + @info "Got configuration $(NamedTuple(config_row))..." + start_time = Dates.now() - # Save results in case the process crashes. - open("tuning/configs.bin", "w") do io - serialize(io, configs) - end + if config_row.category != "unknown" + return config_row + end + + config_row.category = "crashed" - @info "Measuring configuration $(NamedTuple(config_row))..." + # Save results in case the process crashes. + open("tuning/configs.bin", "w") do io + serialize(io, configs) + end + + @info "Measuring configuration $(NamedTuple(config_row))..." - times, category = measure_config(config_row) + times, category = measure_config(config_row) - @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))" + @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))" - config_row.category = category - config_row.times = times + config_row.category = category + config_row.times = times - counter_dict_abs = Dict(counter(configs[!, "category"])) - counter_dict_rel = Dict(k => "$(round(100 * v / sum(values(counter_dict_abs)); sigdigits=3))%" for (k, v) in counter_dict_abs) + counter_dict_abs = Dict(counter(configs[!, "category"])) + counter_dict_rel = Dict(k => "$(round(100 * v / sum(values(counter_dict_abs)); sigdigits=3))%" for (k, v) in counter_dict_abs) - next!(p; showvalues=[ - (:N, config_row.N), - (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)), - (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)), - (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)), - (:kernel, config_row.kernel_str), - (:counters, counter_dict_abs), - (:counters_relative, counter_dict_rel), - (:last_result, "$(category) -- $(prettytime(times .* 1e9))"), - (:last_iteration_time, Dates.now() - start_time) - ]) + put!(channel, [ + (:N, config_row.N), + (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)), + (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)), + (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)), + (:kernel, config_row.kernel_str), + (:counters, counter_dict_abs), + (:counters_relative, counter_dict_rel), + (:last_result, "$(category) -- $(prettytime(times .* 1e9))"), + (:last_iteration_time, Dates.now() - start_time) + ]) + + return config_row + end + put!(channel, nothing) + end end # Save data for final iteration. 
@@ -518,5 +549,6 @@ function main() plot_results(best_configs) end - -isinteractive() || main() +if !isinteractive() && myid() == 1 + main() +end diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index 5c6f76c6..281837f8 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -6,13 +6,14 @@ cd "$( dirname "${BASH_SOURCE[0]}" )" cd .. echo "+++ :julia: Instantiating project" -julia --project -e 'using Pkg; Pkg.instantiate()' -julia --project=tuning -e 'using Pkg; Pkg.instantiate()' +julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' +julia --project=tuning -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' echo "+++ :julia: Tuning" until julia --project=tuning -e ' - push!(LOAD_PATH, @__DIR__) - include("tuning/tune-wmma.jl")'; do + using Distributed + @everywhere push!(LOAD_PATH, @__DIR__) + @everywhere include("tuning/tune-wmma.jl")' "$@"; do echo "Tuning script crashed. Resuming after 1 second..." >&2 sleep 1 From bad0c5692f29d444f60faf837c1b964904fd5f79 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 4 Dec 2023 15:38:43 +0100 Subject: [PATCH 07/21] Fix/improve distributed execution. --- tuning/tune-wmma.jl | 120 +++++++++++++++++++++++++------------------- 1 file changed, 69 insertions(+), 51 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index cf0a5739..4f7a0cb9 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -173,17 +173,20 @@ function measure_config(row) try run_gemm(cf, a, b, c, d) catch err + bt = catch_backtrace() + log = sprint(Base.showerror, err) * sprint(Base.show_backtrace, bt) + if isa(err, GemmKernels.ConfigError) - @info "Skipping configuration $(NamedTuple(row))" * sprint(Base.showerror, err) + @info "Skipping configuration $(NamedTuple(row))" * log return [Inf], "unsupported_config_post_run" end if isa(err, CuError) - @error "Configuration failed: $(NamedTuple(row))" * sprint(Base.showerror, err) + @error "Configuration failed: $(NamedTuple(row))" * log rethrow() end - @info "Skipping configuration: $(NamedTuple(row))" * sprint(Base.showerror, err) + @info "Skipping configuration: $(NamedTuple(row))" * log return [Inf], "error" end @@ -412,10 +415,24 @@ end function main() @info "Starting WMMA tuning script for device $(name(device())) using $(nworkers()) workers..." + config_path = joinpath(@__DIR__, "configs.bin") configs = nothing - if !isfile("tuning/configs.bin") + if isfile(config_path) + @info "Loading configurations from disk..." + try + configs = open(config_path, "r") do io + deserialize(io) + end + @info "Loaded $(size(configs, 1)) configurations." + catch err + @error "Error while loading configurations from disk: $(sprint(Base.showerror, err)))" + mv(config_path, "$(config_path).broken") + end + end + + if configs === nothing # (1) Generate configurations. @info "Generating configurations..." configs = generate_configs() @@ -443,80 +460,81 @@ function main() end end - @info "Loading configurations from disk..." - configs = open("tuning/configs.bin", "r") do io - deserialize(io) - end - @info "Loaded $(size(configs, 1)) configurations." - # (3) Measure performance of configurations. num_unknown = counter(configs[!, "category"])["unknown"] p = Progress(num_unknown; desc="Parameter sweep", dt=1.0, showspeed=true) @info "Need to perform parameter sweep over $(num_unknown) configurations." - # Generate inputs for the first configuration. 
This is not strictly - # speaking necessary, but doing this outside of the loop means that the - # first iteration will not be excessively slow, which improves the "ETA" - # estimate. - first_unknown_config = findfirst(configs[!, "category"] .== "unknown") - !isnothing(first_unknown_config) && generate_inputs_if_needed(configs[first_unknown_config, :]) - channel = RemoteChannel(() -> Channel(), 1) @sync begin + # measure each configuration in parallel @async begin - while true - values = take!(channel) - values === nothing && break - next!(p; showvalues=values) - end - end - - @async begin - @info "Starting parameter sweep..." - pmap(eachrow(configs)) do config_row - @info "Got configuration $(NamedTuple(config_row))..." - start_time = Dates.now() + @sync @distributed for i in 1:size(configs,1) + config_row = configs[i, :] + @info "Got configuration $i: $(NamedTuple(config_row))..." if config_row.category != "unknown" - return config_row + continue end config_row.category = "crashed" - # Save results in case the process crashes. - open("tuning/configs.bin", "w") do io - serialize(io, configs) - end - @info "Measuring configuration $(NamedTuple(config_row))..." + start_time = Dates.now() times, category = measure_config(config_row) + end_time = Dates.now() @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))" config_row.category = category config_row.times = times - counter_dict_abs = Dict(counter(configs[!, "category"])) - counter_dict_rel = Dict(k => "$(round(100 * v / sum(values(counter_dict_abs)); sigdigits=3))%" for (k, v) in counter_dict_abs) - - put!(channel, [ - (:N, config_row.N), - (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)), - (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)), - (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)), - (:kernel, config_row.kernel_str), - (:counters, counter_dict_abs), - (:counters_relative, counter_dict_rel), - (:last_result, "$(category) -- $(prettytime(times .* 1e9))"), - (:last_iteration_time, Dates.now() - start_time) - ]) - - return config_row + put!(channel, (i, start_time, end_time, config_row)) end + + @info "Done with parameter sweep." put!(channel, nothing) end + + # process the results + @async begin + while true + data = take!(channel) + data === nothing && break + + try + # Update configuration + i, start_time, end_time, config_row = data + configs[i, :] = config_row + + # Save results in case the process crashes. + open("tuning/configs.bin", "w") do io + serialize(io, configs) + end + + # Update progress bar + counter_dict_abs = Dict(counter(configs[!, "category"])) + counter_dict_rel = Dict(k => "$(round(100 * v / sum(values(counter_dict_abs)); sigdigits=3))%" for (k, v) in counter_dict_abs) + next!(p; showvalues=[ + (:N, config_row.N), + (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)), + (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)), + (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)), + (:kernel, config_row.kernel_str), + (:counters, counter_dict_abs), + (:counters_relative, counter_dict_rel), + (:last_result, "$(config_row.category) -- $(prettytime(config_row.times .* 1e9))"), + (:last_iteration_time, end_time - start_time) + ]) + catch err + bt = catch_backtrace() + log = sprint(Base.showerror, err) * sprint(Base.show_backtrace, bt) + @error "Error while updating progress bar: $log" + end + end + end end # Save data for final iteration. 
From 425ab493d058e915584997b32c3fa5e6ca807f18 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 4 Dec 2023 16:19:51 +0100 Subject: [PATCH 08/21] Tweaks. --- tuning/tune-wmma.jl | 20 +++++++++----------- tuning/tune-wmma.sh | 1 + 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index 4f7a0cb9..53348fd2 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -177,16 +177,16 @@ function measure_config(row) log = sprint(Base.showerror, err) * sprint(Base.show_backtrace, bt) if isa(err, GemmKernels.ConfigError) - @info "Skipping configuration $(NamedTuple(row))" * log + @info "Skipping configuration $(NamedTuple(row))\n" * log return [Inf], "unsupported_config_post_run" end if isa(err, CuError) - @error "Configuration failed: $(NamedTuple(row))" * log + @error "Configuration failed: $(NamedTuple(row))\n" * log rethrow() end - @info "Skipping configuration: $(NamedTuple(row))" * log + @info "Skipping configuration: $(NamedTuple(row))\n" * log return [Inf], "error" end @@ -486,12 +486,7 @@ function main() times, category = measure_config(config_row) end_time = Dates.now() - @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))" - - config_row.category = category - config_row.times = times - - put!(channel, (i, start_time, end_time, config_row)) + put!(channel, (i, start_time, end_time, category, times)) end @info "Done with parameter sweep." @@ -506,10 +501,13 @@ function main() try # Update configuration - i, start_time, end_time, config_row = data - configs[i, :] = config_row + i, start_time, end_time, category, times = data + config_row = configs[i, :] + @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))" # Save results in case the process crashes. + config_row.times = times + config_row.category = category open("tuning/configs.bin", "w") do io serialize(io, configs) end diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index 281837f8..dbf1cf12 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -11,6 +11,7 @@ julia --project=tuning -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' echo "+++ :julia: Tuning" until julia --project=tuning -e ' + ENV["JULIA_CUDA_HARD_MEMORY_LIMIT"] = "4GiB" using Distributed @everywhere push!(LOAD_PATH, @__DIR__) @everywhere include("tuning/tune-wmma.jl")' "$@"; do From 9a782d71db3aef5a7ba020bce34c8283ef369932 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 4 Dec 2023 20:46:45 +0100 Subject: [PATCH 09/21] More fine tuning. --- tuning/tune-wmma.jl | 48 +++++++++++++++++++++++++-------------------- tuning/tune-wmma.sh | 28 ++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index 53348fd2..e710708c 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -148,6 +148,11 @@ function generate_inputs_if_needed(row) cf = get_config(row) if (input_transpose_a, input_transpose_b, input_N) != (row.transpose_a, row.transpose_b, row.N) + for x in [c_ref, a, b, c, d] + if x !== nothing + CUDA.unsafe_free!(x) + end + end c_ref, a, b, c, d = generate_inputs(cf) input_transpose_a, input_transpose_b, input_N = row.transpose_a, row.transpose_b, row.N end @@ -193,13 +198,6 @@ function measure_config(row) if !verify(cf, c_ref, d) @warn "Configuration produced invalid result: $(NamedTuple(row))" - expected = c_ref - actual = d - - mad, index = findmax(abs.(expected - actual)) - - @warn "Maximum absolute deviation is $(mad) at index $(index)." 
- return [Inf], "invalid_result" end @@ -410,15 +408,15 @@ function plot_results(best_configs) plot!(p, relevant_configs.N, ratios, ribbon=(ratios .- ratios_lo, ratios_hi .- ratios), label=label, markershape=markershapes[label], xscale=:log2) end - savefig(p, "tuning/$(name(device())).pdf") + savefig(p, joinpath(@__DIR__, "$(name(device())).pdf")) end function main() @info "Starting WMMA tuning script for device $(name(device())) using $(nworkers()) workers..." - config_path = joinpath(@__DIR__, "configs.bin") + # (0) Load configurations from disk, or generate them. + config_path = joinpath(@__DIR__, "configs.bin") configs = nothing - if isfile(config_path) @info "Loading configurations from disk..." try @@ -431,7 +429,6 @@ function main() mv(config_path, "$(config_path).broken") end end - if configs === nothing # (1) Generate configurations. @info "Generating configurations..." @@ -455,7 +452,7 @@ function main() @info "Filtered $(counter(configs[!, "category"])["unsupported_config_pre_run"]) configurations." - open("tuning/configs.bin", "w") do io + open(config_path, "w") do io serialize(io, configs) end end @@ -508,7 +505,7 @@ function main() # Save results in case the process crashes. config_row.times = times config_row.category = category - open("tuning/configs.bin", "w") do io + open(config_path, "w") do io serialize(io, configs) end @@ -536,29 +533,38 @@ function main() end # Save data for final iteration. - open("tuning/configs.bin", "w") do io + open(config_path, "w") do io serialize(io, configs) end # And load again, for good measure. - configs = open("tuning/configs.bin", "r") do io + configs = open(config_path, "r") do io deserialize(io) end # (4) Select best configurations, and benchmark. - if !isfile("tuning/best-configs.bin") + best_configs_path = joinpath(@__DIR__, "best-configs.bin") + best_configs = nothing + if isfile(best_configs_path) + try + @info "Loading best configurations from disk..." + best_configs = open(best_configs_path, "r") do io + deserialize(io) + end + catch err + @error "Error while loading best configurations from disk: $(sprint(Base.showerror, err)))" + mv(best_configs_path, "$(best_configs_path).broken") + end + end + if best_configs === nothing @info "Benchmarking configurations for plot..." best_configs = benchmark_best_configs(configs) - open("tuning/best-configs.bin", "w") do io + open(best_configs_path, "w") do io serialize(io, best_configs) end end - @info "Loading best configurations from disk..." - best_configs = open("tuning/best-configs.bin", "r") do io - deserialize(io) - end # (5) Plotting results @info "Plotting results..." 
diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index dbf1cf12..dac97e2a 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -9,9 +9,33 @@ echo "+++ :julia: Instantiating project" julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' julia --project=tuning -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' -echo "+++ :julia: Tuning" until julia --project=tuning -e ' - ENV["JULIA_CUDA_HARD_MEMORY_LIMIT"] = "4GiB" + using CUDA, Distributed + + # determine how many workers to use + memory_usage = 5*2^30 + cpu_memory = Sys.free_memory() + gpu_memory = CUDA.available_memory() + workers = min( + floor(Int, cpu_memory / memory_usage), + floor(Int, gpu_memory / memory_usage), + Sys.CPU_THREADS + ) + println("+++ :julia: Tuning using $workers workers") + + # launch workers + using Distributed + env = [ + "JULIA_NUM_THREADS" => "1", + "OPENBLAS_NUM_THREADS" => "1", + "JULIA_CUDA_HARD_MEMORY_LIMIT" => string(memory_usage), + ] + exeflags = [ + "--project=$(Base.active_project())", + "--heap-size-hint=$memory_usage" + ] + addprocs(workers; exeflags, env) + using Distributed @everywhere push!(LOAD_PATH, @__DIR__) @everywhere include("tuning/tune-wmma.jl")' "$@"; do From 54bcadb1934f96b500d2cb7d777e1a7b32ba42ad Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 6 Dec 2023 11:54:34 +0100 Subject: [PATCH 10/21] Lock GPU clock speeds during tuning --- tuning/tune-wmma.sh | 80 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index dac97e2a..e384b9e2 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -1,8 +1,82 @@ #!/usr/bin/env bash set -Eeuo pipefail -cd "$( dirname "${BASH_SOURCE[0]}" )" +GPU_CLOCK=915 +MEM_CLOCK=5001 + +usage() +{ + cat <&2 +Usage: $0 [OPTIONS] + +Tune WMMA Parameters. + +Options: +-h, --help Show this help. +-gc, --gpu-clock speed Change the frequency the GPU core clock is locked to + before benchmarking, in MHz (default 915 MHz). +-mc, --memory-clock speed Change the frequency the GPU memory clock is locked to + before benchmarking, in MHz (default 5001 MHz). +EOF +} + +positional=() +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + usage; exit 0 + ;; + -gc|--gpu-clock) + shift + GPU_CLOCK=$1 + shift + ;; + -mc|--memory-clock) + shift + MEM_CLOCK=$1 + shift + ;; + -*) + echo "Unknown command-line option '$1'." + echo "Try '$0 --help' for more information." + exit 1 + ;; + *) + positional+=("$1") + shift + ;; + esac +done +set -- "${positional[@]}" + +if [[ $# -ne 0 ]]; then + echo "Expected 0 positional arguments, but got $#." + echo "Try '$0 --help' for more information." + exit 1 +fi +echo "Locking GPU clock speeds to $GPU_CLOCK MHz (GPU) / $MEM_CLOCK MHz (Mem)..." + +if ! nvidia-smi --query-supported-clocks=graphics,memory --format=csv | grep -F "$GPU_CLOCK MHz, $MEM_CLOCK MHz"; then + echo "Unsupported combination of clock speeds!" + exit 1 +fi + +# Prompt for sudo +sudo -v &>/dev/null + +# Sudo keep-alive +while true; do + sleep 300 + sudo -n true + kill -0 "$$" || exit +done &> /dev/null & + +sudo nvidia-smi -pm 1 +sudo nvidia-smi --lock-gpu-clocks=$GPU_CLOCK +sudo nvidia-smi --lock-memory-clocks=$MEM_CLOCK + +cd "$( dirname "${BASH_SOURCE[0]}" )" cd .. echo "+++ :julia: Instantiating project" @@ -43,3 +117,7 @@ until julia --project=tuning -e ' echo "Tuning script crashed. Resuming after 1 second..." >&2 sleep 1 done + +echo "Unlocking GPU clock speeds..." 
+sudo nvidia-smi --reset-gpu-clocks +sudo nvidia-smi --reset-memory-clocks From e184391b18131682557677005540ececd0d27142 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 7 Dec 2023 14:12:45 +0100 Subject: [PATCH 11/21] Extend set of WMMA operator shapes --- configs/configs.jl | 34 ++++++++++++++---- src/config.jl | 6 +++- src/operator.jl | 90 +++++++++++++++++++++++++++------------------- test/matmul.jl | 15 ++++++-- 4 files changed, 97 insertions(+), 48 deletions(-) diff --git a/configs/configs.jl b/configs/configs.jl index c88b2d8d..1aa5226c 100644 --- a/configs/configs.jl +++ b/configs/configs.jl @@ -337,7 +337,7 @@ macro get_wmma_complex_config() conf = GemmKernels.get_config( gemm_shape = (M = M, N = N, K = K), - operator = Operator.WMMAComplexOp{OP_M, OP_N, OP_K}, + operator = Operator.WMMAComplexOp{OP_M, OP_N, OP_K, AB_type, CD_type}, global_a_layout = transpose_a ? Layout.InterleavedRowMajor{Float16} : Layout.InterleavedColMajor{Float16}, global_b_layout = transpose_b ? Layout.InterleavedRowMajor{Float16} : Layout.InterleavedColMajor{Float16}, @@ -391,7 +391,7 @@ macro get_wmma_dual_config() conf = GemmKernels.get_config( gemm_shape = (M = M, N = N, K = K), - operator = Operator.WMMADualOp{OP_M, OP_N, OP_K}, + operator = Operator.WMMADualOp{OP_M, OP_N, OP_K, AB_type, CD_type}, global_a_layout = Layout.InterleavedColMajor{Float16}, global_b_layout = Layout.InterleavedColMajor{Float16}, @@ -514,7 +514,11 @@ function get_configs() transpose_b = [false, true], (BLOCK_M, BLOCK_N, BLOCK_K) in [(128, 128, 64)], (WARPS_M, WARPS_N) in [(4, 2)], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in vcat(min_dimension .* [ [1, 1, 1], [2, 2, 1], @@ -530,7 +534,11 @@ function get_configs() (Float16, Float32, 128)], transpose_a = [false, true], transpose_b = [false, true], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in vcat(min_dimension .* [ [1, 1, 1], [2, 2, 2]], [[4096, 4096, 4096]]) @@ -541,7 +549,11 @@ function get_configs() for (AB_type, CD_type, min_dimension) in [ (Float16, Float32, 128)], transpose_b = [false, true], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in vcat(min_dimension .* [ [1, 1, 1], [2, 2, 2]], [[4096, 4096, 4096]]) @@ -553,7 +565,11 @@ function get_configs() for (AB_type, CD_type) in [(Float16, Float32)], transpose_a = [false, true], transpose_b = [false, true], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in [ (128, 128, 128), (256, 256, 256), @@ -566,7 +582,11 @@ function get_configs() for (AB_type, CD_type) in [(Float16, Float32)], transpose_a = [false], transpose_b = [false], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in [ (128, 128, 128), (256, 256, 256), diff --git a/src/config.jl b/src/config.jl index bb6c2bee..59cd3f75 100644 --- a/src/config.jl +++ b/src/config.jl @@ -119,7 +119,11 @@ end function check_wmma_shape(operator::Type) op_shape = Operator.shape(operator) - if op_shape ∉ [(M=16, N=16, K=16)] + if op_shape ∉ [ + (M=16, N=16, K=16), + (M=8, N=32, K=16), + (M=32, N=8, K=16), + ] throw(ConfigError("Unsupported WMMA Operator shape $(op_shape)!")) end end diff --git a/src/operator.jl b/src/operator.jl index 861999fd..295dbe3d 100644 --- 
a/src/operator.jl +++ b/src/operator.jl @@ -156,6 +156,26 @@ struct WMMAOp{M, N, K, CT, AT} end @inline shape(::Type{WMMAOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K) +for (M, N, K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16) + ], + (layout_type, wmma_layout_type) in [ + (Layout.ColMajor, WMMA.ColMajor), + (Layout.UnsafeAlignedColMajor, WMMA.ColMajor), + (Layout.RowMajor, WMMA.RowMajor), + (Layout.UnsafeAlignedRowMajor, WMMA.RowMajor), + ] + @eval begin + # TODO: Have accessors in CUDA.jl to get the fragment sizes? + # FP16 (16, 16, 16), (8, 32, 16), and (32, 8, 16) + @inline fragtype_a(::Type{WMMAOp{$M, $N, $K, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{$M, $N, $K, 16, CT, $wmma_layout_type, WMMA.MatrixA} + @inline fragtype_b(::Type{WMMAOp{$M, $N, $K, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{$M, $N, $K, 16, CT, $wmma_layout_type, WMMA.MatrixB} + @inline fragtype_accum(::Type{WMMAOp{$M, $N, $K, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{$M, $N, $K, 8, AT, WMMA.Unspecified, WMMA.Accumulator} + end +end + # convert_index_func: function used to transpose the index in case of a row-major layout for (layout_type, wmma_layout_type, convert_index_func) in [ (Layout.ColMajor, WMMA.ColMajor, identity), @@ -164,10 +184,6 @@ for (layout_type, wmma_layout_type, convert_index_func) in [ (Layout.UnsafeAlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))), ] @eval begin - @inline fragtype_a(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixA} - @inline fragtype_b(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixB} - @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator} - @inline function load_a(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} conf = WMMA.Config{M, N, K, AT} @@ -219,46 +235,46 @@ end # WMMAComplex # ----------- -struct WMMAComplexOp{M, N, K} end +struct WMMAComplexOp{M, N, K, CT, AT} end -@inline shape(::Type{WMMAComplexOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K) +@inline shape(::Type{WMMAComplexOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K) # convert_index_func: function used to transpose the index in case of a row-major layout -for (layout_type, wmma_layout_type, convert_index_func) in [ - (Layout.SplitColMajor, WMMA.ColMajor, identity), - (Layout.SplitRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))), +for (layout_type, base_layout, wmma_layout_type, convert_index_func) in [ + (Layout.SplitColMajor, Layout.UnsafeAlignedColMajor, WMMA.ColMajor, identity), + (Layout.SplitRowMajor, Layout.UnsafeAlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))), ] @eval begin - @inline fragtype_a(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}} - @inline fragtype_b(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}} - @inline fragtype_accum(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float32}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}} + @inline 
fragtype_a(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_a(WMMAOp{M, N, K, CT, AT}, $base_layout{CT})} + @inline fragtype_b(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_b(WMMAOp{M, N, K, CT, AT}, $base_layout{CT})} + @inline fragtype_accum(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_accum(WMMAOp{M, N, K, CT, AT}, $base_layout{AT})} - @inline function load_a(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} + @inline function load_a(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) return (WMMA.load_a(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf), WMMA.load_a(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf)) end - @inline function load_b(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} + @inline function load_b(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) return (WMMA.load_b(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf), WMMA.load_b(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf)) end - @inline function load_c(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} + @inline function load_c(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) return (WMMA.load_c(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf), WMMA.load_c(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf)) end - @inline function store_d(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, frag, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} + @inline function store_d(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) WMMA.store_d(pointer(workspace, ind), frag[1], size(workspace)[1], $wmma_layout_type, conf) @@ -269,8 +285,8 @@ end using LLVM -@inline function mma(::Type{WMMAComplexOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function mma(::Type{WMMAComplexOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} c_re = c_frag[1] c_im = c_frag[2] @@ -288,48 +304,48 @@ end # WMMADual # -------- -struct WMMADualOp{M, N, K} end +struct WMMADualOp{M, N, K, CT, AT} end -@inline shape(::Type{WMMADualOp{M, N, K}}) where {M, N, 
K} = (M = M, N = N, K = K) +@inline shape(::Type{WMMADualOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K) -@inline fragtype_a(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixA}} -@inline fragtype_b(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}} -@inline fragtype_accum(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float32}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}} +@inline fragtype_a(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_a(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{CT})} +@inline fragtype_b(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_b(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{CT})} +@inline fragtype_accum(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_accum(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{AT})} -@inline function load_a(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function load_a(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) return (WMMA.load_a(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf), WMMA.load_a(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf)) end -@inline function load_b(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function load_b(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) return (WMMA.load_b(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf), WMMA.load_b(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf)) end -@inline function load_c(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float32}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} +@inline function load_c(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) return (WMMA.load_c(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf), WMMA.load_c(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf)) end -@inline function store_d(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} +@inline function store_d(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, 
(size(workspace)[1], size(workspace)[2])) WMMA.store_d(pointer(workspace, ind), frag[1], size(workspace)[1], WMMA.ColMajor, conf) WMMA.store_d(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), frag[2], size(workspace)[1], WMMA.ColMajor, conf) end -@inline function mma(::Type{WMMADualOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function mma(::Type{WMMADualOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} c_re = c_frag[1] c_du = c_frag[2] diff --git a/test/matmul.jl b/test/matmul.jl index fa625021..19f08de6 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -7,8 +7,17 @@ include("../configs/configs.jl") @testset "Matrix multiplication" begin @testcase "$( cf.name )" for cf in get_configs() - c_h, a, b, c, d = generate_inputs(cf) - run_gemm(cf, a, b, c, d) - @test verify(cf, c_h, d) + try + c_h, a, b, c, d = generate_inputs(cf) + run_gemm(cf, a, b, c, d) + @test verify(cf, c_h, d) + catch err + # Count tests with config errors as "broken". + if isa(err, GemmKernels.ConfigError) + @test true skip=true + else + rethrow() + end + end end end From 900f2a4207c4de754330ed63bdb8b1d8c4feb1fd Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 12 Dec 2023 12:56:23 +0100 Subject: [PATCH 12/21] Perform sweep over different WMMA shapes --- tuning/tune-wmma.jl | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index e710708c..b829821e 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -47,8 +47,6 @@ const CD_type = Float32 const zero_c = true -const OP_M, OP_N, OP_K = 16, 16, 16 - ####### # Reuse inputs across iterations. @@ -95,6 +93,9 @@ function generate_configs() BLOCK_K=Int[], WARPS_M=Int[], WARPS_N=Int[], + OP_M=Int[], + OP_N=Int[], + OP_K=Int[], kernel_str=String[], category=String[], times=Vector{Any}[] @@ -108,6 +109,11 @@ function generate_configs() BLOCK_K in 2 .^ (5:7), WARPS_M in 2 .^ (0:3), WARPS_N in 2 .^ (0:3), + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], kernel_str in ["singlestage", "pipelined"] push!(all_configs, Dict( @@ -119,6 +125,9 @@ function generate_configs() :BLOCK_K => BLOCK_K, :WARPS_M => WARPS_M, :WARPS_N => WARPS_N, + :OP_M => OP_M, + :OP_N => OP_N, + :OP_K => OP_K, :kernel_str => kernel_str, :category => "unknown", :times => [], @@ -137,6 +146,9 @@ function get_config(row) BLOCK_K = row["BLOCK_K"] WARPS_M = row["WARPS_M"] WARPS_N = row["WARPS_N"] + OP_M = row["OP_M"] + OP_N = row["OP_N"] + OP_K = row["OP_K"] kernel = kernel_string_to_function(row["kernel_str"]) @get_wmma_config @@ -285,6 +297,9 @@ function benchmark_best_configs(configs) BLOCK_K=Int[], WARPS_M=Int[], WARPS_N=Int[], + OP_M=Int[], + OP_N=Int[], + OP_K=Int[], kernel_str=String[], category=String[], uncertainty=Float64[], @@ -310,6 +325,9 @@ function benchmark_best_configs(configs) :BLOCK_K => best_config["BLOCK_K"], :WARPS_M => best_config["WARPS_M"], :WARPS_N => best_config["WARPS_N"], + :OP_M => best_config["OP_M"], + :OP_N => best_config["OP_N"], + :OP_K => best_config["OP_K"], :kernel_str => best_config["kernel_str"], :category => "todo", :uncertainty => Inf, @@ -517,6 +535,7 @@ function main() (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)), (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)), (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)), + (:op_shape, (config_row.OP_M, config_row.OP_N, 
config_row.OP_K)), (:kernel, config_row.kernel_str), (:counters, counter_dict_abs), (:counters_relative, counter_dict_rel), From e490dcc882fb73372d6544acf3565d82d4567975 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 12 Dec 2023 13:51:01 +0100 Subject: [PATCH 13/21] Add NVML data to benchmark dataframe --- tuning/tune-wmma.jl | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index b829821e..4f8158af 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -287,6 +287,19 @@ function got_enough_samples(row) uncertainty < PLOT_RATIO_MAX_UNCERTAINTY end +function get_nvml_data(dev) + Dict( + :clock_info => NVML.clock_info(dev), + :max_clock_info => NVML.max_clock_info(dev), + :clock_event_reasons => NVML.clock_event_reasons(dev), + :power_usage => NVML.power_usage(dev), + :energy_consumption => NVML.energy_consumption(dev), + :temperature => NVML.temperature(dev), + :memory_info => NVML.memory_info(dev), + :utilization_rates => NVML.utilization_rates(dev), + ) +end + function benchmark_best_configs(configs) best_configs = DataFrame( transpose_a=Bool[], @@ -305,9 +318,13 @@ function benchmark_best_configs(configs) uncertainty=Float64[], time_spent=Float64[], gemmkernels_times=Vector{Any}[], - baseline_times=Vector{Any}[] + baseline_times=Vector{Any}[], + gemmkernels_nvml=Vector{Any}[], + baseline_nvml=Vector{Any}[] ) + dev = NVML.Device(parent_uuid(device())) + for transpose_a = [false, true], transpose_b = [false, true], N = N_vals @@ -334,6 +351,8 @@ function benchmark_best_configs(configs) :time_spent => 0.0, :gemmkernels_times => [], :baseline_times => [], + :gemmkernels_nvml => [], + :baseline_nvml => [], )) end @@ -364,9 +383,11 @@ function benchmark_best_configs(configs) start_time = Dates.now() + push!(config_row["gemmkernels_nvml"], get_nvml_data(dev)) prof = CUDA.@profile run_gemm(cf, a, b, c, d) push!(config_row["gemmkernels_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) + push!(config_row["baseline_nvml"], get_nvml_data(dev)) prof = CUDA.@profile run_baseline(cf, a, b, c, d) push!(config_row["baseline_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) From 1bacdfe91c91ca5d49b04af37afbd4d80e9a7f43 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 12 Dec 2023 14:17:14 +0100 Subject: [PATCH 14/21] Use serial profiling --- tuning/tune-wmma.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index 4f8158af..f383121e 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -384,11 +384,11 @@ function benchmark_best_configs(configs) start_time = Dates.now() push!(config_row["gemmkernels_nvml"], get_nvml_data(dev)) - prof = CUDA.@profile run_gemm(cf, a, b, c, d) + prof = CUDA.@profile concurrent=false run_gemm(cf, a, b, c, d) push!(config_row["gemmkernels_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) push!(config_row["baseline_nvml"], get_nvml_data(dev)) - prof = CUDA.@profile run_baseline(cf, a, b, c, d) + prof = CUDA.@profile concurrent=false run_baseline(cf, a, b, c, d) push!(config_row["baseline_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) config_row["time_spent"] += (Dates.now() - start_time) / Second(1) From 97e2e7fd0e4ebd87cb1cb76229b59c7c9f24b820 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 12 Dec 2023 13:28:48 +0000 Subject: [PATCH 15/21] Allow selecting a GPU. 
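
Add an -i flag so the tuning script can target a specific GPU: the selected
index is exported as CUDA_VISIBLE_DEVICES and forwarded to every nvidia-smi
invocation. Illustrative invocation only (the device index and clock speeds
below are examples, not required values):

    ./tuning/tune-wmma.sh -i 1 --gpu-clock 915 --memory-clock 5001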
--- tuning/tune-wmma.sh | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index e384b9e2..68da4a7e 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -Eeuo pipefail +GPU_ID=0 GPU_CLOCK=915 MEM_CLOCK=5001 @@ -13,6 +14,7 @@ Tune WMMA Parameters. Options: -h, --help Show this help. +-i id Specify which GPU to target. -gc, --gpu-clock speed Change the frequency the GPU core clock is locked to before benchmarking, in MHz (default 915 MHz). -mc, --memory-clock speed Change the frequency the GPU memory clock is locked to @@ -26,6 +28,11 @@ while [[ $# -gt 0 ]]; do -h|--help) usage; exit 0 ;; + -i) + shift + GPU_ID=$1 + shift + ;; -gc|--gpu-clock) shift GPU_CLOCK=$1 @@ -55,9 +62,11 @@ if [[ $# -ne 0 ]]; then exit 1 fi -echo "Locking GPU clock speeds to $GPU_CLOCK MHz (GPU) / $MEM_CLOCK MHz (Mem)..." +export CUDA_VISIBLE_DEVICES=$GPU_ID + +echo "Locking GPU $GPU_ID clock speeds to $GPU_CLOCK MHz (GPU) / $MEM_CLOCK MHz (Mem)..." -if ! nvidia-smi --query-supported-clocks=graphics,memory --format=csv | grep -F "$GPU_CLOCK MHz, $MEM_CLOCK MHz"; then +if ! nvidia-smi -i $GPU_ID --query-supported-clocks=graphics,memory --format=csv | grep -F "$GPU_CLOCK MHz, $MEM_CLOCK MHz"; then echo "Unsupported combination of clock speeds!" exit 1 fi @@ -72,9 +81,9 @@ while true; do kill -0 "$$" || exit done &> /dev/null & -sudo nvidia-smi -pm 1 -sudo nvidia-smi --lock-gpu-clocks=$GPU_CLOCK -sudo nvidia-smi --lock-memory-clocks=$MEM_CLOCK +sudo nvidia-smi -i $GPU_ID -pm 1 +sudo nvidia-smi -i $GPU_ID --lock-gpu-clocks=$GPU_CLOCK +sudo nvidia-smi -i $GPU_ID --lock-memory-clocks=$MEM_CLOCK cd "$( dirname "${BASH_SOURCE[0]}" )" cd .. @@ -119,5 +128,5 @@ until julia --project=tuning -e ' done echo "Unlocking GPU clock speeds..." -sudo nvidia-smi --reset-gpu-clocks -sudo nvidia-smi --reset-memory-clocks +sudo nvidia-smi -i $GPU_ID --reset-gpu-clocks +sudo nvidia-smi -i $GPU_ID --reset-memory-clocks From 88da82c1b2d3fe037f1db20fcecb5fc0f94f6505 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 13 Dec 2023 07:43:19 +0000 Subject: [PATCH 16/21] Fix manifest loading. --- tuning/tune-wmma.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index 68da4a7e..310ba9d5 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -93,6 +93,7 @@ julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' julia --project=tuning -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()' until julia --project=tuning -e ' + pushfirst!(LOAD_PATH, @__DIR__) using CUDA, Distributed # determine how many workers to use @@ -120,7 +121,7 @@ until julia --project=tuning -e ' addprocs(workers; exeflags, env) using Distributed - @everywhere push!(LOAD_PATH, @__DIR__) + @everywhere pushfirst!(LOAD_PATH, @__DIR__) @everywhere include("tuning/tune-wmma.jl")' "$@"; do echo "Tuning script crashed. Resuming after 1 second..." 
>&2 From 83dccd7a17212f81b0b9dd027e422475a0dc4377 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 2 Jan 2024 12:36:32 +0100 Subject: [PATCH 17/21] Remove LocalArray size check --- src/matmul.jl | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index f3dfd34f..3c4bde0d 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -34,30 +34,6 @@ function matmul(conf::Config, a, b, c, d; conf.block_shape.K ≥ 2 * conf.compute_op_shape.K || throw(ConfigError("Need at least two stages to use a pipelined kernel, i.e. BLOCK_K ≥ 2 * OPERATOR_K")) end - # Check LocalArray size limit of 32 elements. - if kernel == Kernel.matmul_singlestage - num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M - num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N - - num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) - end - - if kernel == Kernel.matmul_pipelined - num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M - num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N - - a_frag_i = (conf.block_shape.M * conf.block_shape.K) ÷ (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block) - a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷ (conf.mem_a_thread.M * conf.mem_a_thread.K * 32) - b_frag_i = (conf.block_shape.K * conf.block_shape.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block) - b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32) - - num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) - a_frag_i * a_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) - b_frag_i * b_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) - 2 * num_fragments_m < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) - 2 * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!")) - end - hostkernel = @cuda launch=false kernel(args...) attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem From 90e307ae52cca9af8f1581fce45e37561e502c50 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 2 Jan 2024 13:02:59 +0100 Subject: [PATCH 18/21] Lock to max frequency by default --- tuning/tune-wmma.sh | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh index 310ba9d5..73341cc7 100755 --- a/tuning/tune-wmma.sh +++ b/tuning/tune-wmma.sh @@ -2,8 +2,8 @@ set -Eeuo pipefail GPU_ID=0 -GPU_CLOCK=915 -MEM_CLOCK=5001 +GPU_CLOCK=-1 +MEM_CLOCK=-1 usage() { @@ -16,9 +16,9 @@ Options: -h, --help Show this help. -i id Specify which GPU to target. -gc, --gpu-clock speed Change the frequency the GPU core clock is locked to - before benchmarking, in MHz (default 915 MHz). + before benchmarking, in MHz (default the max frequency). -mc, --memory-clock speed Change the frequency the GPU memory clock is locked to - before benchmarking, in MHz (default 5001 MHz). + before benchmarking, in MHz (default the max frequency). 
EOF } @@ -64,6 +64,14 @@ fi export CUDA_VISIBLE_DEVICES=$GPU_ID +if [[ "$GPU_CLOCK" == "-1" ]]; then + GPU_CLOCK="$(nvidia-smi -i $GPU_ID --query-supported-clocks=graphics --format=csv | sort -rn | head -1 | cut -f1 -d' ')" +fi + +if [[ "$MEM_CLOCK" == "-1" ]]; then + MEM_CLOCK="$(nvidia-smi -i $GPU_ID --query-supported-clocks=memory --format=csv | sort -rn | head -1 | cut -f1 -d' ')" +fi + echo "Locking GPU $GPU_ID clock speeds to $GPU_CLOCK MHz (GPU) / $MEM_CLOCK MHz (Mem)..." if ! nvidia-smi -i $GPU_ID --query-supported-clocks=graphics,memory --format=csv | grep -F "$GPU_CLOCK MHz, $MEM_CLOCK MHz"; then From 48467c542eb260fb204f91dcae6d297d78b1adb4 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 2 Jan 2024 13:16:09 +0100 Subject: [PATCH 19/21] Use time as stop criterion for plot Instead of the uncertainty. --- tuning/tune-wmma.jl | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index f383121e..b317f1c4 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -31,13 +31,9 @@ const BENCH_MIN_NUM_SAMPLES = 10 ##### -# Stop gathering samples for plot if the "error bars" are smaller than this... -const PLOT_RATIO_MAX_UNCERTAINTY = 0.05 - -# ... or we have exceeded the time limit... -# In my experience, only N <= 2^9 requires more than a handful of samples. -# That's only 3*4 configurations, so a limit of say 600 seconds will take ~2 hours. -const PLOT_MAX_NUM_SECONDS = 600 +# Stop gathering samples for plot if we have spent this much time... +# 60 seconds/configuration * 32 configurations = 32 minutes for plot. +const PLOT_NUM_SECONDS = 60 # ... but have at least 10 samples. const PLOT_MIN_NUM_SAMPLES = 10 @@ -280,11 +276,7 @@ function got_enough_samples(row) (length(gk) < PLOT_MIN_NUM_SAMPLES) && return false (length(bl) < PLOT_MIN_NUM_SAMPLES) && return false - (row["time_spent"] >= PLOT_MAX_NUM_SECONDS) && return true - - uncertainty, _, _ = get_uncertainty(gk, bl) - - uncertainty < PLOT_RATIO_MAX_UNCERTAINTY + row["time_spent"] >= PLOT_NUM_SECONDS end function get_nvml_data(dev) @@ -315,7 +307,6 @@ function benchmark_best_configs(configs) OP_K=Int[], kernel_str=String[], category=String[], - uncertainty=Float64[], time_spent=Float64[], gemmkernels_times=Vector{Any}[], baseline_times=Vector{Any}[], @@ -347,7 +338,6 @@ function benchmark_best_configs(configs) :OP_K => best_config["OP_K"], :kernel_str => best_config["kernel_str"], :category => "todo", - :uncertainty => Inf, :time_spent => 0.0, :gemmkernels_times => [], :baseline_times => [], @@ -363,7 +353,7 @@ function benchmark_best_configs(configs) input_dict = Dict() - p = ProgressUnknown(desc="Benchmarking (highest uncertainty)", dt=1.0) + p = ProgressUnknown(desc="Benchmarking", dt=1.0) # Spread the samples of one configuration over time, to reduce the effect # of time-related noise. Note that this means that the progress bar may @@ -392,24 +382,18 @@ function benchmark_best_configs(configs) push!(config_row["baseline_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"])) config_row["time_spent"] += (Dates.now() - start_time) / Second(1) - old_uncertainty = config_row["uncertainty"] - config_row["uncertainty"], _, _ = get_uncertainty(config_row["gemmkernels_times"], config_row["baseline_times"]) if got_enough_samples(config_row) config_row["category"] = "done" end # Update progress bar. - highest_uncertainty = best_configs[(@. 
(best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :] - highest_uncertainty = maximum(highest_uncertainty[!, "uncertainty"]) next!(p; showvalues = [ (:transpose_a, transpose_a), (:transpose_b, transpose_b), (:N, config_row["N"]), (:num_samples, length(config_row["gemmkernels_times"])), - (:uncertainty, "$(config_row["uncertainty"]) (Δ = $(config_row["uncertainty"] - old_uncertainty))"), (:time_spent_in_config, config_row["time_spent"]), - (:highest_uncertainty, highest_uncertainty), (:remaining_N, best_configs[(@. (best_configs[!, "category"] == "todo") & (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :].N), (:remaining_configurations, sum(best_configs[!, "category"] .== "todo")) ]) From abba6db17f47d3073f22f336871b38ebad4d9666 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 2 Jan 2024 14:12:52 +0100 Subject: [PATCH 20/21] Sleep if throttling is detected --- tuning/tune-wmma.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index b317f1c4..ac58c57e 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -292,6 +292,15 @@ function get_nvml_data(dev) ) end +function wait_if_throttling(dev) + cer = NVML.clock_event_reasons(dev) + + while cer.hw_power_brake || cer.sw_power_cap || cer.hw_slow || cer.sw_thermal || cer.hw_thermal + @info "Throttling detected. Sleeping for one second..." + sleep(1) + end +end + function benchmark_best_configs(configs) best_configs = DataFrame( transpose_a=Bool[], @@ -371,6 +380,8 @@ function benchmark_best_configs(configs) @info "Profiling configuration $(NamedTuple(config_row))..." + wait_if_throttling() + start_time = Dates.now() push!(config_row["gemmkernels_nvml"], get_nvml_data(dev)) From cd6a190a3379c42b6b74e9a33b53c3b7f6ce230c Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 2 Jan 2024 14:23:00 +0100 Subject: [PATCH 21/21] Run kernels in batches --- tuning/tune-wmma.jl | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl index ac58c57e..8ab2c2d8 100644 --- a/tuning/tune-wmma.jl +++ b/tuning/tune-wmma.jl @@ -35,8 +35,11 @@ const BENCH_MIN_NUM_SAMPLES = 10 # 60 seconds/configuration * 32 configurations = 32 minutes for plot. const PLOT_NUM_SECONDS = 60 -# ... but have at least 10 samples. -const PLOT_MIN_NUM_SAMPLES = 10 +# ... but have at least 100 samples. +const PLOT_MIN_NUM_SAMPLES = 100 + +# Group samples in batches of 10 samples each. +const PLOT_BATCH_SIZE = 10 const AB_type = Float16 const CD_type = Float32 @@ -380,19 +383,25 @@ function benchmark_best_configs(configs) @info "Profiling configuration $(NamedTuple(config_row))..." 
-            wait_if_throttling()
+            for measure_baseline in [false, true]
+                for i in 1:PLOT_BATCH_SIZE
+                    wait_if_throttling()
 
-            start_time = Dates.now()
+                    start_time = Dates.now()
 
-            push!(config_row["gemmkernels_nvml"], get_nvml_data(dev))
-            prof = CUDA.@profile concurrent=false run_gemm(cf, a, b, c, d)
-            push!(config_row["gemmkernels_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"]))
+                    push!(config_row[if measure_baseline "baseline_nvml" else "gemmkernels_nvml" end], get_nvml_data(dev))
 
-            push!(config_row["baseline_nvml"], get_nvml_data(dev))
-            prof = CUDA.@profile concurrent=false run_baseline(cf, a, b, c, d)
-            push!(config_row["baseline_times"], sum(prof.device[!, "stop"] - prof.device[!, "start"]))
+                    if measure_baseline
+                        prof = CUDA.@profile concurrent=false run_baseline(cf, a, b, c, d)
+                    else
+                        prof = CUDA.@profile concurrent=false run_gemm(cf, a, b, c, d)
+                    end
 
-            config_row["time_spent"] += (Dates.now() - start_time) / Second(1)
+                    push!(config_row[if measure_baseline "baseline_times" else "gemmkernels_times" end], sum(prof.device[!, "stop"] - prof.device[!, "start"]))
+
+                    config_row["time_spent"] += (Dates.now() - start_time) / Second(1)
+                end
+            end
 
             if got_enough_samples(config_row)
                 config_row["category"] = "done"
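
A possible follow-up to PATCH 20: as committed, wait_if_throttling queries
NVML.clock_event_reasons once before the loop, so once throttling is detected
the condition can never become false again, and the zero-argument calls added
in PATCH 20 and PATCH 21 do not match the one-argument definition. A minimal
sketch of a corrected version (assuming, as elsewhere in this script, that
NVML.clock_event_reasons re-queries the driver on each call and that the
default device can be derived from the active CUDA device):

    function wait_if_throttling(dev=NVML.Device(parent_uuid(device())))
        # Re-read the clock event reasons on every iteration so the loop can
        # observe the GPU recovering from throttling instead of spinning forever.
        cer = NVML.clock_event_reasons(dev)

        while cer.hw_power_brake || cer.sw_power_cap || cer.hw_slow || cer.sw_thermal || cer.hw_thermal
            @info "Throttling detected. Sleeping for one second..."
            sleep(1)
            cer = NVML.clock_event_reasons(dev)
        end
    end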