Skip to content

Commit

Permalink
benchmark: for gemm and attention
Browse files Browse the repository at this point in the history
  • Loading branch information
skyleaworlder committed Jul 7, 2023
1 parent 4947342 commit 083e8bd
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions benchmark/benchmark/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,58 @@ for rank in (2,), N in (128, 512, 2048,)
end
SUITE["upsample"]["nearest"]["$(rank+2)-N($N)"] = et_suite
end

########## gemm ############
# Register batched GEMM benchmarks, one sub-suite per element type.
SUITE["gemm"] = BenchmarkGroup()
for et in (Float32, Float64)
    et_suite = BenchmarkGroup(
        "gemm!" => BenchmarkGroup(),
        "batched_gemm!" => BenchmarkGroup())
    SUITE["gemm"][string(et)] = et_suite

    # Only the non-transposed ('N', 'N') case is covered: transposed operands
    # are not among the main variants.
    # The "gemm!" group is registered but intentionally left empty, because
    # gemm! ran into a memory problem and is not benchmarked here.
    #
    # Each item is (transA, transB, M, N, K, alpha, beta).
    input_items = [
        ('N', 'N', 80, 40, 100, et(1.0), et(0.0)),
        ('N', 'N', 512, 512, 128, et(0.5), et(1.0)),
        ('N', 'N', 1024, 1024, 1024, et(0.5), et(0.0)),
    ]
    for (transA_ch, transB_ch, M, N, K, alpha, beta) in input_items
        # Operands carry a trailing batch dimension of size 1:
        # A is M×N×1, B is N×K×1 and the output C is M×K×1.
        bA = ones(et, M, N, 1)
        bB = ones(et, N, K, 1)
        bC = zeros(et, M, K, 1)
        # NOTE(review): bC is interpolated once and reused by every evaluation,
        # so with beta != 0 its contents presumably accumulate between runs;
        # this should not affect timing — confirm if exact values ever matter.
        key = "trans($transA_ch,$transB_ch)-M($M)-N($N)-K($K)-alpha($alpha)-beta($beta)"
        et_suite["batched_gemm!"][key] = @benchmarkable NNlib.batched_gemm!(
            $transA_ch, $transB_ch,
            $alpha, $bA, $bB, $beta, $bC)
    end
end


########## attention ############
# Register attention benchmarks, one sub-suite per element type. Each sub-suite
# holds an "attention" group (dot_product_attention) and a "score" group
# (dot_product_attention_scores).
SUITE["attention"] = BenchmarkGroup()
for et in (Float16, Float64)
    et_suite = BenchmarkGroup(
        "attention" => BenchmarkGroup(), "score" => BenchmarkGroup())
    SUITE["attention"][string(et)] = et_suite
    attention_group = et_suite["attention"]
    score_group = et_suite["score"]

    # Each item is (q shape, k shape, v shape, bias/mask shape or nothing, nheads).
    shape_items = [
        ((8,6,1), (8,10,1), (4,10,1), nothing, 1),
        ((64,64,16), (64,64,16), (64,64,16), (64,64), 4),
        ((16,128,8), (16,512,8), (32,512,8), (512,128), 4),
    ]
    for (q_shape, k_shape, v_shape, bias_shape, nheads) in shape_items
        # Inputs for the full attention call use the shapes as given; inputs
        # for the score call get an extra leading dimension of size 8.
        q = rand(et, q_shape...)
        q_score = rand(et, 8, q_shape...)
        k = rand(et, k_shape...)
        k_score = rand(et, 8, k_shape...)
        v = rand(et, v_shape...)

        # Bias and mask share one shape; both are omitted when no shape is given.
        if isnothing(bias_shape)
            bias = nothing
            mask = nothing
        else
            bias = rand(et, bias_shape...)
            mask = rand(Bool, bias_shape...)
        end

        attention_key =
            "q($q_shape)-k($k_shape)-v($v_shape)-bias($bias_shape)-nheads($nheads)"
        attention_group[attention_key] =
            @benchmarkable dot_product_attention($q, $k, $v, $bias; nheads = $nheads)

        # nheads appears in this key only as a label; it is not passed to the
        # score call below.
        score_key =
            "q(8, $q_shape)-k(8, $k_shape)-bias($bias_shape)-nheads($nheads)"
        score_group[score_key] =
            @benchmarkable dot_product_attention_scores($q_score, $k_score, $bias; mask = $mask)
    end
end

0 comments on commit 083e8bd

Please sign in to comment.