diff --git a/.gitignore b/.gitignore
index c181d1f8..d9a944ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-test/Manifest.toml
+Manifest.toml
+tuning/.CondaPkg/
diff --git a/configs/configs.jl b/configs/configs.jl
index 1097746f..1a0f2e81 100644
--- a/configs/configs.jl
+++ b/configs/configs.jl
@@ -3,6 +3,7 @@ using GemmKernels
 using LinearAlgebra
 using ForwardDiff
+using Octavian
 
 struct Configuration
     name                        # Human-readable name of the configuration.
@@ -64,7 +65,8 @@ function generate_inputs(cf::Configuration)
     new_b_h = cf.transpose_b ? transpose(b_h) : b_h
 
     (cf.calc_reference)(c_h, new_a_h, new_b_h, cf.alpha, cf.beta)
-    c_h, a, b, c, d
+    c_ref = CuArray(c_h)
+    c_ref, a, b, c, d
 end
 
 # Run the GEMM.
@@ -88,21 +90,21 @@ function run_baseline(cf::Configuration, a, b, c, d)
 end
 
 # Verify results.
-function verify(cf::Configuration, c_h, d)
-    cf.verify(c_h, d)
+function verify(cf::Configuration, c_ref, d)
+    cf.verify(c_ref, d)
 end
 
-function verify_default(c_h, d)
-    isapprox(c_h, Array(d))
+function verify_default(c_ref, d)
+    isapprox(c_ref, d)
 end
 
-function verify_bias(c_h, d, bias)
-    c_h .+ Array(bias) ≈ Array(d)
+function verify_bias(c_ref, d, bias)
+    c_ref .+ bias ≈ d
 end
 
-function verify_dual(c_h, d)
-    c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h)
-    d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d))
+function verify_dual(c_ref, d)
+    c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_ref)
+    d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, d)
     isapprox(c_dual, d_dual)
 end
@@ -238,7 +240,7 @@ macro get_wmma_config()
         CD_type,
         transpose_a,
         transpose_b,
-        mul!,
+        Octavian.matmul!,
         Epilogue.Default(),
         verify_default,
         kernel,
diff --git a/test/Project.toml b/test/Project.toml
index 8828b9af..8d1be0aa 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,5 +5,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 XUnit = "3e3c03f2-1a94-11e9-2981-050a4ca824ab"
diff --git a/tuning/Project.toml b/tuning/Project.toml
new file mode 100644
index 00000000..8fab9e13
--- /dev/null
+++ b/tuning/Project.toml
@@ -0,0 +1,14 @@
+[deps]
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
+Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
+JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
+LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
+Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
diff --git a/tuning/tune-wmma.jl b/tuning/tune-wmma.jl
new file mode 100644
index 00000000..8ab2c2d8
--- /dev/null
+++ b/tuning/tune-wmma.jl
@@ -0,0 +1,620 @@
+using CUDA, GemmKernels
+using DataFrames
+using DataStructures
+using Dates
+using Distributed
+using FileWatching.Pidfile
+using Logging
+using LoggingExtras
+using ProgressMeter
+using Serialization
+using Statistics
+using StatsBase
+
+if myid() == 1
+    using Plots
+    pythonplot()
+end
+
+#######
+
+const N_vals = 2 .^ (7:14)
+
+# Stop sampling when the normalised 95% confidence interval is smaller than this...
+const BENCH_NORM_CI_THRESHOLD = 0.01
+
+# ... or we have exceeded the time limit...
+const BENCH_MAX_NUM_SECONDS = 5
+
+# ... but have at least 10 samples.
+const BENCH_MIN_NUM_SAMPLES = 10
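+
+# For example: at a median kernel time of 1 ms, sampling stops as soon as the
+# normalised 95% confidence interval on the median drops below
+# 0.01 * 1 ms = 10 µs, unless the 5-second budget runs out first.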
+
+#####
+
+# Stop gathering samples for plot if we have spent this much time...
+# 60 seconds/configuration * 32 configurations = 32 minutes for plot.
+const PLOT_NUM_SECONDS = 60
+
+# ... but have at least 100 samples.
+const PLOT_MIN_NUM_SAMPLES = 100
+
+# Group samples in batches of 10 samples each.
+const PLOT_BATCH_SIZE = 10
+
+const AB_type = Float16
+const CD_type = Float32
+
+const zero_c = true
+
+#######
+
+# Reuse inputs across iterations.
+c_ref = nothing
+a = nothing
+b = nothing
+c = nothing
+d = nothing
+input_transpose_a = nothing
+input_transpose_b = nothing
+input_N = nothing
+
+include("../configs/configs.jl")
+
+# Write logging messages to file for persistence.
+timestamp_logger(logger) = TransformerLogger(logger) do log
+    merge(log, (; message = "$(Dates.format(now(), "yyyy-mm-dd HH:MM:SS")) $(log.message)"))
+end
+function log_filename()
+    path = joinpath(@__DIR__, "tuning.log")
+    if myid() != 1
+        path = "$(path).$(myid())"
+    end
+    path
+end
+FileLogger(log_filename(); append=true) |> timestamp_logger |> (x -> MinLevelLogger(x, Logging.Info)) |> global_logger
+
+function kernel_string_to_function(str)
+    Dict(
+        "singlestage" => Kernel.matmul_singlestage,
+        "pipelined" => Kernel.matmul_pipelined
+    )[str]
+end
+
+get_label(transpose_a, transpose_b) = "$(transpose_a ? "T" : "N")$(transpose_b ? "T" : "N")"
+
+function generate_configs()
+    all_configs = DataFrame(
+        transpose_a=Bool[],
+        transpose_b=Bool[],
+        N=Int[],
+        BLOCK_M=Int[],
+        BLOCK_N=Int[],
+        BLOCK_K=Int[],
+        WARPS_M=Int[],
+        WARPS_N=Int[],
+        OP_M=Int[],
+        OP_N=Int[],
+        OP_K=Int[],
+        kernel_str=String[],
+        category=String[],
+        times=Vector{Any}[]
+    )
+
+    for transpose_a in [false, true],
+        transpose_b in [false, true],
+        N in N_vals,
+        BLOCK_M in 2 .^ (6:9),
+        BLOCK_N in 2 .^ (6:9),
+        BLOCK_K in 2 .^ (5:7),
+        WARPS_M in 2 .^ (0:3),
+        WARPS_N in 2 .^ (0:3),
+        (OP_M, OP_N, OP_K) in [
+            (16, 16, 16),
+            (8, 32, 16),
+            (32, 8, 16),
+        ],
+        kernel_str in ["singlestage", "pipelined"]
+
+        push!(all_configs, Dict(
+            :transpose_a => transpose_a,
+            :transpose_b => transpose_b,
+            :N => N,
+            :BLOCK_M => BLOCK_M,
+            :BLOCK_N => BLOCK_N,
+            :BLOCK_K => BLOCK_K,
+            :WARPS_M => WARPS_M,
+            :WARPS_N => WARPS_N,
+            :OP_M => OP_M,
+            :OP_N => OP_N,
+            :OP_K => OP_K,
+            :kernel_str => kernel_str,
+            :category => "unknown",
+            :times => [],
+        ))
+    end
+
+    all_configs
+end
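+
+# Note: the sweep above enumerates 2 * 2 * 8 * 4 * 4 * 3 * 4 * 4 * 3 * 2 =
+# 147,456 candidate configurations; many of these are rejected up front as
+# unsupported (see the a-priori filtering in main()).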
+
+function get_config(row)
+    transpose_a = row["transpose_a"]
+    transpose_b = row["transpose_b"]
+    M = N = K = row["N"]
+    BLOCK_M = row["BLOCK_M"]
+    BLOCK_N = row["BLOCK_N"]
+    BLOCK_K = row["BLOCK_K"]
+    WARPS_M = row["WARPS_M"]
+    WARPS_N = row["WARPS_N"]
+    OP_M = row["OP_M"]
+    OP_N = row["OP_N"]
+    OP_K = row["OP_K"]
+    kernel = kernel_string_to_function(row["kernel_str"])
+
+    @get_wmma_config
+end
+
+function generate_inputs_if_needed(row)
+    global input_transpose_a, input_transpose_b, input_N, c_ref, a, b, c, d
+
+    cf = get_config(row)
+
+    if (input_transpose_a, input_transpose_b, input_N) != (row.transpose_a, row.transpose_b, row.N)
+        for x in [c_ref, a, b, c, d]
+            if x !== nothing
+                CUDA.unsafe_free!(x)
+            end
+        end
+        c_ref, a, b, c, d = generate_inputs(cf)
+        input_transpose_a, input_transpose_b, input_N = row.transpose_a, row.transpose_b, row.N
+    end
+end
+
+function get_inputs_for_plot(input_dict, row)
+    if row.N ∉ keys(input_dict)
+        cf = get_config(row)
+        _, a, b, c, d = generate_inputs(cf)
+        input_dict[row.N] = (a, b, c, d)
+    end
+
+    return input_dict[row.N]
+end
+
+function measure_config(row)
+    cf = get_config(row)
+
+    generate_inputs_if_needed(row)
+
+    d .= 0
+
+    try
+        run_gemm(cf, a, b, c, d)
+    catch err
+        bt = catch_backtrace()
+        log = sprint(Base.showerror, err) * sprint(Base.show_backtrace, bt)
+
+        if isa(err, GemmKernels.ConfigError)
+            @info "Skipping configuration $(NamedTuple(row))\n" * log
+            return [Inf], "unsupported_config_post_run"
+        end
+
+        if isa(err, CuError)
+            @error "Configuration failed: $(NamedTuple(row))\n" * log
+            rethrow()
+        end
+
+        @info "Skipping configuration: $(NamedTuple(row))\n" * log
+        return [Inf], "error"
+    end
+
+    if !verify(cf, c_ref, d)
+        @warn "Configuration produced invalid result: $(NamedTuple(row))"
+
+        return [Inf], "invalid_result"
+    end
+
+    times = Float64[]
+
+    # Use CUDA.@elapsed instead of CUDA.@profile, because the latter is slower.
+    device_synchronize()
+    GC.gc(true)
+
+    mkpidlock(joinpath(@__DIR__, "tuning.pid")) do
+        start_time = Dates.now()
+
+        while true
+            synchronize(stream())
+            time = CUDA.@elapsed run_gemm(cf, a, b, c, d)
+            push!(times, time)
+
+            if length(times) >= BENCH_MIN_NUM_SAMPLES
+                (Dates.now() - start_time > Second(BENCH_MAX_NUM_SECONDS)) && break
+                (confidence_interval_95(times) / median(times) < BENCH_NORM_CI_THRESHOLD) && break
+            end
+        end
+    end
+
+    return times, "success"
+end
+
+confidence_interval_95(times) = 1.58 * iqr(times) / sqrt(length(times))
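+
+# The 1.58 * IQR / sqrt(n) term is the half-width of a notched box plot's
+# notch, a rough 95% confidence interval around the median (cf. McGill,
+# Tukey and Larsen). A sketch of the check measure_config uses to stop
+# sampling, assuming `times` holds seconds:
+#
+#   confidence_interval_95(times) / median(times) < BENCH_NORM_CI_THRESHOLD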
+
+function prettytime(times)
+    min, q1, med, q3, max = nquantile(times, 4)
+    ci_95 = confidence_interval_95(times)
+
+    # timescale
+    scale, unit = if med < 1e3
+        1, "ns"
+    elseif med < 1e6
+        1e3, "μs"
+    elseif med < 1e9
+        1e6, "ms"
+    else
+        1e9, "s"
+    end
+
+    rnd_min, rnd_q1, rnd_med, rnd_q3, rnd_max, rnd_ci_95 = round.([min, q1, med, q3, max, ci_95] ./ scale; sigdigits=3)
+    rnd_rel_ci_95 = round(100 * ci_95 / med; sigdigits=3)
+
+    return "$rnd_med $unit ± $rnd_ci_95 $unit ($rnd_rel_ci_95%) (length: $(length(times)), 5-num summary: $rnd_min, $rnd_q1, $rnd_med, $rnd_q3, $rnd_max $unit)"
+end
+
+perf_ratio(gemmkernels, baseline) = percentile(baseline, 0) / percentile(gemmkernels, 0)
+perf_ratio_lo(gemmkernels, baseline) = percentile(baseline, 0) / percentile(gemmkernels, 75)
+perf_ratio_hi(gemmkernels, baseline) = percentile(baseline, 75) / percentile(gemmkernels, 0)
+
+function get_uncertainty(gk, bl)
+    lo, mid, hi = (perf_ratio_lo(gk, bl), perf_ratio(gk, bl), perf_ratio_hi(gk, bl))
+
+    hi_uncertainty = abs(hi - mid) / mid
+    lo_uncertainty = abs(lo - mid) / mid
+    uncertainty = max(hi_uncertainty, lo_uncertainty)
+
+    uncertainty, lo_uncertainty, hi_uncertainty
+end
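+
+# perf_ratio compares best-case times (the 0th percentile is the minimum);
+# the lo/hi variants pit one side's minimum against the other side's 75th
+# percentile, yielding pessimistic and optimistic bounds. plot_results()
+# draws these bounds as the ribbon around each curve.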
+
+function got_enough_samples(row)
+    gk, bl = row["gemmkernels_times"], row["baseline_times"]
+
+    (length(gk) < PLOT_MIN_NUM_SAMPLES) && return false
+    (length(bl) < PLOT_MIN_NUM_SAMPLES) && return false
+
+    row["time_spent"] >= PLOT_NUM_SECONDS
+end
+
+function get_nvml_data(dev)
+    Dict(
+        :clock_info => NVML.clock_info(dev),
+        :max_clock_info => NVML.max_clock_info(dev),
+        :clock_event_reasons => NVML.clock_event_reasons(dev),
+        :power_usage => NVML.power_usage(dev),
+        :energy_consumption => NVML.energy_consumption(dev),
+        :temperature => NVML.temperature(dev),
+        :memory_info => NVML.memory_info(dev),
+        :utilization_rates => NVML.utilization_rates(dev),
+    )
+end
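+
+# These NVML snapshots are recorded alongside every batch of samples,
+# presumably so that timing outliers can be correlated after the fact with
+# throttling, temperature, or memory pressure at measurement time.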
+
+function wait_if_throttling(dev)
+    cer = NVML.clock_event_reasons(dev)
+
+    while cer.hw_power_brake || cer.sw_power_cap || cer.hw_slow || cer.sw_thermal || cer.hw_thermal
+        @info "Throttling detected. Sleeping for one second..."
+        sleep(1)
+        cer = NVML.clock_event_reasons(dev)
+    end
+end
+
+function benchmark_best_configs(configs)
+    best_configs = DataFrame(
+        transpose_a=Bool[],
+        transpose_b=Bool[],
+        N=Int[],
+        BLOCK_M=Int[],
+        BLOCK_N=Int[],
+        BLOCK_K=Int[],
+        WARPS_M=Int[],
+        WARPS_N=Int[],
+        OP_M=Int[],
+        OP_N=Int[],
+        OP_K=Int[],
+        kernel_str=String[],
+        category=String[],
+        time_spent=Float64[],
+        gemmkernels_times=Vector{Any}[],
+        baseline_times=Vector{Any}[],
+        gemmkernels_nvml=Vector{Any}[],
+        baseline_nvml=Vector{Any}[]
+    )
+
+    dev = NVML.Device(parent_uuid(device()))
+
+    for transpose_a = [false, true],
+        transpose_b = [false, true],
+        N = N_vals
+
+        relevant_configs = configs[(@. (configs[!, "transpose_a"] == transpose_a) & (configs[!, "transpose_b"] == transpose_b) & (configs[!, "N"] == N)), :]
+        _, best_config_index = findmin(minimum.(relevant_configs[!, "times"], init=Inf))
+        best_config = relevant_configs[best_config_index, :]
+
+        push!(best_configs, Dict(
+            :transpose_a => transpose_a,
+            :transpose_b => transpose_b,
+            :N => N,
+            :BLOCK_M => best_config["BLOCK_M"],
+            :BLOCK_N => best_config["BLOCK_N"],
+            :BLOCK_K => best_config["BLOCK_K"],
+            :WARPS_M => best_config["WARPS_M"],
+            :WARPS_N => best_config["WARPS_N"],
+            :OP_M => best_config["OP_M"],
+            :OP_N => best_config["OP_N"],
+            :OP_K => best_config["OP_K"],
+            :kernel_str => best_config["kernel_str"],
+            :category => "todo",
+            :time_spent => 0.0,
+            :gemmkernels_times => [],
+            :baseline_times => [],
+            :gemmkernels_nvml => [],
+            :baseline_nvml => [],
+        ))
+    end
+
+    # We will reuse matrix inputs across iterations. This takes about 4 GB of GPU memory for e.g. all matrix sizes for NN.
+    # Group runs of the same transposition together, so we don't have to keep 4 * 4 GB of inputs in memory.
+    for transpose_a in [false, true],
+        transpose_b in [false, true]
+
+        input_dict = Dict()
+
+        p = ProgressUnknown(desc="Benchmarking", dt=1.0)
+
+        # Spread the samples of one configuration over time, to reduce the effect
+        # of time-related noise. Note that this means that the progress bar may
+        # make big jumps.
+        while true
+            (sum(@. (best_configs[!, "category"] == "todo") & (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)) == 0) && break
+
+            for config_row in eachrow(best_configs)
+                if (config_row.category, config_row.transpose_a, config_row.transpose_b) != ("todo", transpose_a, transpose_b)
+                    continue
+                end
+
+                a, b, c, d = get_inputs_for_plot(input_dict, config_row)
+                cf = get_config(config_row)
+
+                @info "Profiling configuration $(NamedTuple(config_row))..."
+
+                for use_baseline in [false, true]
+                    for i in 1:PLOT_BATCH_SIZE
+                        wait_if_throttling(dev)
+
+                        start_time = Dates.now()
+
+                        push!(config_row[if use_baseline "baseline_nvml" else "gemmkernels_nvml" end], get_nvml_data(dev))
+
+                        if use_baseline
+                            prof = CUDA.@profile concurrent=false run_baseline(cf, a, b, c, d)
+                        else
+                            prof = CUDA.@profile concurrent=false run_gemm(cf, a, b, c, d)
+                        end
+
+                        push!(config_row[if use_baseline "baseline_times" else "gemmkernels_times" end], sum(prof.device[!, "stop"] - prof.device[!, "start"]))
+
+                        config_row["time_spent"] += (Dates.now() - start_time) / Second(1)
+                    end
+                end
+
+                if got_enough_samples(config_row)
+                    config_row["category"] = "done"
+                end
+
+                # Update progress bar.
+                next!(p; showvalues = [
+                    (:transpose_a, transpose_a),
+                    (:transpose_b, transpose_b),
+                    (:N, config_row["N"]),
+                    (:num_samples, length(config_row["gemmkernels_times"])),
+                    (:time_spent_in_config, config_row["time_spent"]),
+                    (:remaining_N, best_configs[(@. (best_configs[!, "category"] == "todo") & (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :].N),
+                    (:remaining_configurations, sum(best_configs[!, "category"] .== "todo"))
+                ])
+            end
+        end
+    end
+
+    best_configs
+end
+
+function plot_results(best_configs)
+    markershapes = Dict(
+        "NN" => :circle,
+        "NT" => :dtriangle,
+        "TN" => :diamond,
+        "TT" => :cross
+    )
+
+    p = plot()
+    title!("$AB_type x $AB_type = $CD_type ($(name(device())))")
+    xlabel!("Matrix size [-]")
+    ylabel!("Performance relative to cuBLAS [%]")
+
+    for transpose_a in [false, true],
+        transpose_b in [false, true]
+
+        label = get_label(transpose_a, transpose_b)
+
+        relevant_configs = best_configs[(@. (best_configs[!, "transpose_a"] == transpose_a) & (best_configs[!, "transpose_b"] == transpose_b)), :]
+
+        ratios = @. 100 * perf_ratio(relevant_configs.gemmkernels_times, relevant_configs.baseline_times)
+        ratios_lo = @. 100 * perf_ratio_lo(relevant_configs.gemmkernels_times, relevant_configs.baseline_times)
+        ratios_hi = @. 100 * perf_ratio_hi(relevant_configs.gemmkernels_times, relevant_configs.baseline_times)
+
+        plot!(p, relevant_configs.N, ratios, ribbon=(ratios .- ratios_lo, ratios_hi .- ratios), label=label, markershape=markershapes[label], xscale=:log2)
+    end
+
+    savefig(p, joinpath(@__DIR__, "$(name(device())).pdf"))
+end
+
+function main()
+    @info "Starting WMMA tuning script for device $(name(device())) using $(nworkers()) workers..."
+
+    # (0) Load configurations from disk, or generate them.
+    config_path = joinpath(@__DIR__, "configs.bin")
+    configs = nothing
+    if isfile(config_path)
+        @info "Loading configurations from disk..."
+        try
+            configs = open(config_path, "r") do io
+                deserialize(io)
+            end
+            @info "Loaded $(size(configs, 1)) configurations."
+        catch err
+            @error "Error while loading configurations from disk: $(sprint(Base.showerror, err))"
+            mv(config_path, "$(config_path).broken")
+        end
+    end
+    if configs === nothing
+        # (1) Generate configurations.
+        @info "Generating configurations..."
+        configs = generate_configs()
+        @info "Generated $(size(configs, 1)) configurations."
+
+        # (2) Filter configurations where we can determine upfront that they are unsupported.
+        @info "Filtering configurations that we know are unsupported a-priori..."
+
+        for config_row in eachrow(configs)
+            try
+                cf = get_config(config_row)
+            catch err
+                if isa(err, GemmKernels.ConfigError)
+                    config_row["category"] = "unsupported_config_pre_run"
+                else
+                    rethrow()
+                end
+            end
+        end
+
+        @info "Filtered $(counter(configs[!, "category"])["unsupported_config_pre_run"]) configurations."
+
+        open(config_path, "w") do io
+            serialize(io, configs)
+        end
+    end
+
+    # (3) Measure performance of configurations.
+    num_unknown = counter(configs[!, "category"])["unknown"]
+    p = Progress(num_unknown; desc="Parameter sweep", dt=1.0, showspeed=true)
+
+    @info "Need to perform parameter sweep over $(num_unknown) configurations."
+
+    channel = RemoteChannel(() -> Channel(), 1)
+    @sync begin
+        # measure each configuration in parallel
+        @async begin
+            @sync @distributed for i in 1:size(configs, 1)
+                config_row = configs[i, :]
+                @info "Got configuration $i: $(NamedTuple(config_row))..."
+
+                if config_row.category != "unknown"
+                    continue
+                end
+
+                config_row.category = "crashed"
+
+                @info "Measuring configuration $(NamedTuple(config_row))..."
+
+                start_time = Dates.now()
+                times, category = measure_config(config_row)
+                end_time = Dates.now()
+
+                put!(channel, (i, start_time, end_time, category, times))
+            end
+
+            @info "Done with parameter sweep."
+            put!(channel, nothing)
+        end
+
+        # process the results
+        @async begin
+            while true
+                data = take!(channel)
+                data === nothing && break
+
+                try
+                    # Update configuration
+                    i, start_time, end_time, category, times = data
+                    config_row = configs[i, :]
+                    @info "Result for $(NamedTuple(config_row)): $(category) -- $(prettytime(times .* 1e9))"
+
+                    # Save results in case the process crashes.
+                    config_row.times = times
+                    config_row.category = category
+                    open(config_path, "w") do io
+                        serialize(io, configs)
+                    end
+
+                    # Update progress bar
+                    counter_dict_abs = Dict(counter(configs[!, "category"]))
+                    counter_dict_rel = Dict(k => "$(round(100 * v / sum(values(counter_dict_abs)); sigdigits=3))%" for (k, v) in counter_dict_abs)
+                    next!(p; showvalues=[
+                        (:N, config_row.N),
+                        (:transpose, get_label(config_row.transpose_a, config_row.transpose_b)),
+                        (:block_shape, (config_row.BLOCK_M, config_row.BLOCK_N, config_row.BLOCK_K)),
+                        (:num_warps, (config_row.WARPS_M, config_row.WARPS_N)),
+                        (:op_shape, (config_row.OP_M, config_row.OP_N, config_row.OP_K)),
+                        (:kernel, config_row.kernel_str),
+                        (:counters, counter_dict_abs),
+                        (:counters_relative, counter_dict_rel),
+                        (:last_result, "$(config_row.category) -- $(prettytime(config_row.times .* 1e9))"),
+                        (:last_iteration_time, end_time - start_time)
+                    ])
+                catch err
+                    bt = catch_backtrace()
+                    log = sprint(Base.showerror, err) * sprint(Base.show_backtrace, bt)
+                    @error "Error while updating progress bar: $log"
+                end
+            end
+        end
+    end
+
+    # Save data for final iteration.
+    open(config_path, "w") do io
+        serialize(io, configs)
+    end
+
+    # And load again, for good measure.
+    configs = open(config_path, "r") do io
+        deserialize(io)
+    end
+
+    # (4) Select best configurations, and benchmark.
+    best_configs_path = joinpath(@__DIR__, "best-configs.bin")
+    best_configs = nothing
+    if isfile(best_configs_path)
+        try
+            @info "Loading best configurations from disk..."
+            best_configs = open(best_configs_path, "r") do io
+                deserialize(io)
+            end
+        catch err
+            @error "Error while loading best configurations from disk: $(sprint(Base.showerror, err))"
+            mv(best_configs_path, "$(best_configs_path).broken")
+        end
+    end
+    if best_configs === nothing
+        @info "Benchmarking configurations for plot..."
+        best_configs = benchmark_best_configs(configs)
+
+        open(best_configs_path, "w") do io
+            serialize(io, best_configs)
+        end
+    end
+
+    # (5) Plot results.
+    @info "Plotting results..."
+    plot_results(best_configs)
+end
+
+if !isinteractive() && myid() == 1
+    main()
+end
diff --git a/tuning/tune-wmma.sh b/tuning/tune-wmma.sh
new file mode 100755
index 00000000..73341cc7
--- /dev/null
+++ b/tuning/tune-wmma.sh
@@ -0,0 +1,141 @@
+#!/usr/bin/env bash
+set -Eeuo pipefail
+
+GPU_ID=0
+GPU_CLOCK=-1
+MEM_CLOCK=-1
+
+usage()
+{
+    cat <<EOF >&2
+Usage: $0 [OPTIONS]
+
+Tune WMMA Parameters.
+
+Options:
+-h, --help                  Show this help.
+-i id                       Specify which GPU to target.
+-gc, --gpu-clock speed      Change the frequency the GPU core clock is locked to
+                            before benchmarking, in MHz (default the max frequency).
+-mc, --memory-clock speed   Change the frequency the GPU memory clock is locked to
+                            before benchmarking, in MHz (default the max frequency).
+EOF
+}
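+
+# Example invocation (the clock speeds are illustrative; pick values that
+# nvidia-smi --query-supported-clocks reports for your GPU):
+#   ./tuning/tune-wmma.sh -i 0 --gpu-clock 1410 --memory-clock 5001
+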
+positional=()
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -h|--help)
+            usage; exit 0
+            ;;
+        -i)
+            shift
+            GPU_ID=$1
+            shift
+            ;;
+        -gc|--gpu-clock)
+            shift
+            GPU_CLOCK=$1
+            shift
+            ;;
+        -mc|--memory-clock)
+            shift
+            MEM_CLOCK=$1
+            shift
+            ;;
+        -*)
+            echo "Unknown command-line option '$1'."
+            echo "Try '$0 --help' for more information."
+            exit 1
+            ;;
+        *)
+            positional+=("$1")
+            shift
+            ;;
+    esac
+done
+set -- "${positional[@]}"
+
+if [[ $# -ne 0 ]]; then
+    echo "Expected 0 positional arguments, but got $#."
+    echo "Try '$0 --help' for more information."
+    exit 1
+fi
+
+export CUDA_VISIBLE_DEVICES=$GPU_ID
+
+if [[ "$GPU_CLOCK" == "-1" ]]; then
+    GPU_CLOCK="$(nvidia-smi -i $GPU_ID --query-supported-clocks=graphics --format=csv | sort -rn | head -1 | cut -f1 -d' ')"
+fi
+
+if [[ "$MEM_CLOCK" == "-1" ]]; then
+    MEM_CLOCK="$(nvidia-smi -i $GPU_ID --query-supported-clocks=memory --format=csv | sort -rn | head -1 | cut -f1 -d' ')"
+fi
+
+echo "Locking GPU $GPU_ID clock speeds to $GPU_CLOCK MHz (GPU) / $MEM_CLOCK MHz (Mem)..."
+
+if ! nvidia-smi -i $GPU_ID --query-supported-clocks=graphics,memory --format=csv | grep -F "$GPU_CLOCK MHz, $MEM_CLOCK MHz"; then
+    echo "Unsupported combination of clock speeds!"
+    exit 1
+fi
+
+# Prompt for sudo
+sudo -v &>/dev/null
+
+# Sudo keep-alive
+while true; do
+    sleep 300
+    sudo -n true
+    kill -0 "$$" || exit
+done &> /dev/null &
+
+sudo nvidia-smi -i $GPU_ID -pm 1
+sudo nvidia-smi -i $GPU_ID --lock-gpu-clocks=$GPU_CLOCK
+sudo nvidia-smi -i $GPU_ID --lock-memory-clocks=$MEM_CLOCK
+
+cd "$( dirname "${BASH_SOURCE[0]}" )"
+cd ..
+
+echo "+++ :julia: Instantiating project"
+julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()'
+julia --project=tuning -e 'using Pkg; Pkg.instantiate(); Pkg.precompile()'
+
+until julia --project=tuning -e '
+    pushfirst!(LOAD_PATH, @__DIR__)
+    using CUDA, Distributed
+
+    # determine how many workers to use
+    memory_usage = 5*2^30
+    cpu_memory = Sys.free_memory()
+    gpu_memory = CUDA.available_memory()
+    workers = min(
+        floor(Int, cpu_memory / memory_usage),
+        floor(Int, gpu_memory / memory_usage),
+        Sys.CPU_THREADS
+    )
+    println("+++ :julia: Tuning using $workers workers")
+
+    # launch workers
+    env = [
+        "JULIA_NUM_THREADS" => "1",
+        "OPENBLAS_NUM_THREADS" => "1",
+        "JULIA_CUDA_HARD_MEMORY_LIMIT" => string(memory_usage),
+    ]
+    exeflags = [
+        "--project=$(Base.active_project())",
+        "--heap-size-hint=$memory_usage"
+    ]
+    addprocs(workers; exeflags, env)
+
+    @everywhere pushfirst!(LOAD_PATH, @__DIR__)
+    @everywhere include("tuning/tune-wmma.jl")' "$@"; do
+
+    echo "Tuning script crashed. Resuming after 1 second..." >&2
+    sleep 1
+done
+
+echo "Unlocking GPU clock speeds..."
+sudo nvidia-smi -i $GPU_ID --reset-gpu-clocks
+sudo nvidia-smi -i $GPU_ID --reset-memory-clocks